08 Tree Classification
Agenda
Regression versus Classification trees
Interpretation of classification trees
How to grow a classification tree?
Classification error rate

E = 1 - \max_k(\hat{p}_{mk})    (1)

where \hat{p}_{mk} is the proportion of training observations in the mth region that are from the kth class.
The Gini index
G = \sum_{k=1}^{K} \hat{p}_{mk} (1 - \hat{p}_{mk})    (2)
The Gini index
The Gini index takes on a small value if all of the p̂mk are close to 0 or 1.
Because of this, the Gini index is called a measure of node purity: a small value of the Gini index indicates that a node contains predominantly observations from a single class.
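As a quick illustration (not from the slides), the Gini index of a single node can be computed directly from its class proportions; gini_index below is a hypothetical helper, not part of the lab code.

# hypothetical helper: Gini index of one node, given its class proportions p̂mk
gini_index <- function(p) sum(p * (1 - p))

gini_index(c(0.5, 0.5))   # maximally impure two-class node: 0.5
gini_index(c(0.9, 0.1))   # nearly pure node: 0.18
gini_index(c(1, 0))       # pure node: 0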
The cross entropy
D = -\sum_{k=1}^{K} \hat{p}_{mk} \log(\hat{p}_{mk})    (3)
Since 0 ≤ p̂mk ≤ 1, it follows that 0 ≤ −p̂mk log(p̂mk).
Exercise: show that the cross-entropy will take on a value near 0 if the p̂mk’s are all near 0 or 1.
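A matching sketch for the cross-entropy (cross_entropy is again a hypothetical helper; zero proportions are dropped, since 0 · log(0) is taken to be 0):

# hypothetical helper: cross-entropy of one node, given its class proportions p̂mk
cross_entropy <- function(p) {
  p <- p[p > 0]              # drop zeros so that log(0) never occurs
  -sum(p * log(p))
}

cross_entropy(c(0.5, 0.5))   # maximally impure two-class node: log(2), about 0.69
cross_entropy(c(0.9, 0.1))   # nearly pure node: about 0.33
cross_entropy(c(1, 0))       # pure node: 0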
Gini index, cross entropy, and classification error rate
Like the Gini index, the cross-entropy will take on a small value if node m is pure.
In fact, it turns out that the Gini index and the cross-entropy are quite similar numerically.
When building a classification tree, either the Gini index or the cross-entropy is typically used to evaluate the quality of a particular split, since these two measures are more sensitive to node purity than the classification error rate is.
Any of these three approaches might be used when pruning the tree, but the classification error rate is preferable if prediction accuracy of the final pruned tree is the goal.
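The numerical similarity is easy to see in the two-class case, where each measure is a function of the proportion p of class 1 alone. The following sketch (an illustration of mine, not slide code) plots all three:

# impurity of a two-class node as a function of p = p̂m1
p <- seq(0.001, 0.999, length.out = 200)
gini      <- 2 * p * (1 - p)
entropy   <- -(p * log(p) + (1 - p) * log(1 - p))
class_err <- 1 - pmax(p, 1 - p)

matplot(p, cbind(gini, entropy, class_err), type = "l", lty = 1,
        xlab = "proportion of class 1", ylab = "impurity")
legend("top", legend = c("Gini", "cross-entropy", "class. error"),
       lty = 1, col = 1:3)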
Deviance
-2 \sum_{m} \sum_{k} n_{mk} \log(\hat{p}_{mk})    (4)

where n_{mk} is the number of observations in the mth terminal node that belong to the kth class.
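As a hedged illustration, one terminal node's contribution to (4) can be computed from its class counts; the counts below are made up. For a tree fitted with the tree package, summing these contributions over all terminal nodes gives the deviance that summary() uses when reporting the residual mean deviance.

# contribution of one terminal node m to (4), with hypothetical counts n_mk
n_mk <- c(Yes = 30, No = 10)       # class counts in node m (made up)
p_mk <- n_mk / sum(n_mk)           # fitted proportions p̂mk
-2 * sum(n_mk * log(p_mk))         # about 45; summing over all nodes gives (4)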
Application to Carseats Dataset
Task 1
Solution Task 1 (a,b,c)
library(ISLR)
library(tree)
attach(Carseats)
# create a binary response: "Yes" if Sales > 8, "No" otherwise;
# storing it as a factor makes tree() fit a classification tree
High <- factor(ifelse(Sales <= 8, "No", "Yes"))
# merge High with the rest of Carseats
Carseats <- data.frame(Carseats, High)
# fit a classification tree, excluding Sales from the predictors
tree.carseats <- tree(High ~ . - Sales, Carseats)
Task 2
How well does the tree fit? Plot your tree and explain your interpretation.
Solution Task 2
summary(tree.carseats)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
Solution Task 2 (continued)

[Tree plot omitted: the top split is on ShelveLoc (Bad, Medium vs. Good); terminal nodes are labeled Yes or No.]
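A minimal sketch of how such a plot is typically produced with the tree package, assuming tree.carseats from Task 1:

# draw the tree; text() adds node labels, and pretty = 0 keeps full category names
plot(tree.carseats)
text(tree.carseats, pretty = 0)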
Solution to Task 3
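The Task 3 statement and its solution are not reproduced in these slides. Assuming it follows the corresponding ISLR lab step (estimate the test error by fitting the tree on a training half and predicting on a held-out half), a sketch could look like this; the seed and the 200/200 split are assumptions:

set.seed(2)
train <- sample(1:nrow(Carseats), 200)       # assumed training indices
Carseats.test <- Carseats[-train, ]
High.test <- High[-train]
# refit on the training observations only
tree.carseats <- tree(High ~ . - Sales, Carseats, subset = train)
tree.pred <- predict(tree.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)                  # test-set confusion matrix
mean(tree.pred == High.test)                 # test-set accuracy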
Task 4
Given Task 3, now prune the tree to see whether you get improved results.
Note: use the argument FUN = prune.misclass to indicate that the classification error rate should guide the cross-validation and pruning process, rather than the deviance, which is the default for the cv.tree() function.
Solution to Task 4
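The extracted slide is blank here. A minimal sketch that follows the note in Task 4, assuming the training-set tree and the test split from Task 3:

set.seed(3)
# cross-validate, guided by the misclassification rate instead of the deviance
cv.carseats <- cv.tree(tree.carseats, FUN = prune.misclass)
cv.carseats                # $dev holds CV misclassification counts for each $size
# prune to the size with the smallest cross-validated error
best <- cv.carseats$size[which.min(cv.carseats$dev)]
prune.carseats <- prune.misclass(tree.carseats, best = best)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
# re-evaluate on the test set from Task 3
tree.pred <- predict(prune.carseats, Carseats.test, type = "class")
mean(tree.pred == High.test)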