
Decision Tree in R Programming


In this article, we’ll explore how to implement decision trees in R, covering key concepts, step-by-step examples, and tuning strategies.

A decision tree is a flowchart-like model where each internal node represents a decision based on a feature, each branch represents an outcome of that decision, and each leaf node represents a final prediction. The algorithm recursively splits the data into subsets based on feature values to maximize homogeneity in the resulting groups.

Key concepts include:

  1. Root Node: The topmost decision node.
  2. Splitting Criteria: Metrics like Gini impurity (classification) or variance reduction (regression) determine how to split the data (see the short sketch after this list).
  3. Pruning: Techniques to trim branches that overfit the training data.
  4. Leaf Nodes: Terminal nodes providing final predictions.
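
Before building a full model, it can help to see the splitting criterion in action. Below is a minimal sketch (a hypothetical helper, not part of rpart) that computes the Gini impurity of a vector of class labels; a value of 0 means the node is perfectly pure.

R
# Gini impurity: 1 - sum(p_k^2), where p_k are the class proportions in a node
gini_impurity <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

gini_impurity(iris$Species)                            # mixed node with all 3 species: ~0.667
gini_impurity(iris$Species[iris$Petal.Length <= 2.5])  # all setosa: 0 (perfectly pure)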

Building Decision Trees in R

Let’s build a decision tree to classify iris flowers into species (setosa, versicolor, or virginica).

1. Load Data and Split into Train/Test Sets

Before building a decision tree, we need to load the dataset and prepare it for training.

R
library(rpart)
library(rpart.plot)
library(caret)

data(iris)
set.seed(123)
train_index <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train_data <- iris[train_index, ]
test_data <- iris[-train_index, ]
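
As a quick optional check, you can verify that createDataPartition produced a stratified split, i.e. that each species keeps roughly an 80/20 train/test ratio:

R
# Class counts in each partition (about 40 training / 10 test rows per species)
table(train_data$Species)
table(test_data$Species)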


2. Train the Model

Once the data is ready, we can build the Decision Tree model.

R
tree_model <- rpart(Species ~ ., 
                    data = train_data, 
                    method = "class",  # For classification
                    control = rpart.control(minsplit = 10, cp = 0.01))

By default, rpart selects splits using the Gini index; information gain (entropy) can be used instead.
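
For reference, here is a sketch of the same call switched to entropy-based splitting via the parms argument (on this dataset the resulting tree is usually very similar):

R
tree_model_info <- rpart(Species ~ .,
                         data = train_data,
                         method = "class",
                         parms = list(split = "information"),  # use information gain instead of Gini
                         control = rpart.control(minsplit = 10, cp = 0.01))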

3. Visualize the Tree

R
rpart.plot(tree_model, box.palette = "auto", nn = TRUE)

Output:

Decision Tree

This generates a tree diagram showing decision rules (e.g., petal length ≤ 2.5 for setosa). The tree uses Petal Length and Petal Width to classify flowers into three species: Setosa, Versicolor, and Virginica. The structure visually explains how the classification happens based on different feature thresholds.
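
If you prefer the decision rules as plain text rather than a diagram, the rpart.plot package also provides rpart.rules(), which prints one row per leaf:

R
rpart.rules(tree_model)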

4. Make Predictions and Evaluate

To check model performance, we predict species for the test dataset.

R
predictions <- predict(tree_model, test_data, type = "class")
confusionMatrix(predictions, test_data$Species)

Output:


The confusion matrix reports overall accuracy along with per-class statistics such as sensitivity (recall) and positive predictive value (precision). Accuracy is high for all three species.
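
If you need the individual metrics programmatically, the result of confusionMatrix() can be stored and inspected (a small sketch reusing the predictions above):

R
cm <- confusionMatrix(predictions, test_data$Species)
cm$overall["Accuracy"]   # overall accuracy on the test set
cm$byClass               # per-class sensitivity, specificity, precision, etc.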

Decision trees are easy to interpret and visualize, and we can fine-tune them further by pruning or adjusting hyperparameters.

5. Prune the Tree

Avoid overfitting by pruning using the complexity parameter (cp):

R
printcp(tree_model)  # Identify optimal cp (xerror is minimized)
optimal_cp <- tree_model$cptable[which.min(tree_model$cptable[, "xerror"]), "CP"]
pruned_tree <- prune(tree_model, cp = optimal_cp)
rpart.plot(pruned_tree)

Output:

The cp table from printcp() and a plot of the pruned tree
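
To confirm that pruning did not hurt predictive performance, the pruned tree can be evaluated on the same test set as before (a sketch reusing the objects defined earlier):

R
pruned_predictions <- predict(pruned_tree, test_data, type = "class")
confusionMatrix(pruned_predictions, test_data$Species)$overall["Accuracy"]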

Tuning Hyperparameters

Use caret with cross-validation to tune the complexity parameter cp (caret's "rpart" method tunes cp; settings such as minsplit can still be passed through rpart.control):

R
control <- trainControl(method = "cv", number = 10)  # 10-fold cross-validation
tuned_tree <- train(Species ~ ., 
                    data = iris, 
                    method = "rpart",
                    trControl = control,
                    tuneGrid = expand.grid(cp = seq(0.01, 0.1, 0.01)))
print(tuned_tree)

Output:

Cross-validated accuracy for each candidate cp value and the cp selected for the final model
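
The object returned by train() can be used directly for prediction; the sketch below assumes the train/test split from earlier (note that the tuned model above was fit on the full iris data, so this only illustrates the calls):

R
tuned_tree$bestTune                              # cp value selected by cross-validation
head(predict(tuned_tree, newdata = test_data))   # predicted species for the test rows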

Decision trees in R are a versatile tool for predictive modeling. The rpart and caret packages simplify implementation, while pruning and cross-validation ensure robustness. For complex datasets, consider ensemble methods built on decision tree principles, such as random forests (randomForest package) or gradient boosting (e.g., gbm or xgboost).
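
As a pointer in that direction, a minimal random forest sketch with the randomForest package (assuming it is installed) could look like this:

R
library(randomForest)

set.seed(123)
rf_model <- randomForest(Species ~ ., data = train_data, ntree = 500)  # an ensemble of 500 trees
rf_predictions <- predict(rf_model, test_data)
confusionMatrix(rf_predictions, test_data$Species)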


