The tree library is used to construct classification and regression trees. We use classification trees to analyse the Carseats data. We first recode the continuous variable Sales into a binary factor, High:
Carseats %>%
as_tibble() %>%
mutate(High = as.factor( ifelse(Sales <= 8, 'No', 'Yes') )) -> carseats
We now use tree()
to fit a classification tree.
carseat_tree <- tree(High ~ . -Sales, data = carseats)
summary(carseat_tree)
##
## Classification tree:
## tree(formula = High ~ . - Sales, data = carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc" "Price" "Income" "CompPrice" "Population"
## [6] "Advertising" "Age" "US"
## Number of terminal nodes: 27
## Residual mean deviance: 0.4575 = 170.7 / 373
## Misclassification error rate: 0.09 = 36 / 400
We see the training error rate is 9%. For classification trees, the deviance reported by summary()
is given by:
\[ -2 \sum_m \sum_k n_{mk} \log \hat{p}_{mk} \]
where \(n_{mk}\) is the number of observations in the \(m\)th terminal node that belong to the \(k\)th class. The residual mean deviance is the deviance divided by \(n - |T_0|\), where \(|T_0|\) is the number of terminal nodes.
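As a sanity check, both numbers reported by summary() can be reproduced from this formula. Below is a minimal sketch, assuming the carseats tibble and carseat_tree from above are still in the workspace; it uses the class-probability matrix returned by predict() and reads the terminal-node count off the tree's frame component.
# Fitted probabilities p-hat_mk for the terminal node each training observation falls into
probs <- predict(carseat_tree, newdata = carseats, type = 'vector')
# Probability assigned to each observation's true class
idx <- cbind(seq_len(nrow(probs)), match(as.character(carseats$High), colnames(probs)))
p_true <- probs[idx]
# Deviance: -2 * sum of log p-hat over the true classes
deviance <- -2 * sum(log(p_true))
# Residual mean deviance: deviance / (n - |T_0|)
n_leaves <- sum(carseat_tree$frame$var == '<leaf>')
c(deviance, deviance / (nrow(carseats) - n_leaves))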
We can plot the tree:
plot(carseat_tree)
text(carseat_tree, pretty = 0, cex = .7)
The most important factor appears to be shelving location, since the first branch separates Good shelving locations from Bad and Medium ones.
The object’s print function outputs the branches:
carseat_tree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 400 541.500 No ( 0.59000 0.41000 )
## 2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )
## 4) Price < 92.5 46 56.530 Yes ( 0.30435 0.69565 )
## 8) Income < 57 10 12.220 No ( 0.70000 0.30000 )
## 16) CompPrice < 110.5 5 0.000 No ( 1.00000 0.00000 ) *
## 17) CompPrice > 110.5 5 6.730 Yes ( 0.40000 0.60000 ) *
## 9) Income > 57 36 35.470 Yes ( 0.19444 0.80556 )
## 18) Population < 207.5 16 21.170 Yes ( 0.37500 0.62500 ) *
## 19) Population > 207.5 20 7.941 Yes ( 0.05000 0.95000 ) *
## 5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )
## 10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )
## 20) CompPrice < 124.5 96 44.890 No ( 0.93750 0.06250 )
## 40) Price < 106.5 38 33.150 No ( 0.84211 0.15789 )
## 80) Population < 177 12 16.300 No ( 0.58333 0.41667 )
## 160) Income < 60.5 6 0.000 No ( 1.00000 0.00000 ) *
## 161) Income > 60.5 6 5.407 Yes ( 0.16667 0.83333 ) *
## 81) Population > 177 26 8.477 No ( 0.96154 0.03846 ) *
## 41) Price > 106.5 58 0.000 No ( 1.00000 0.00000 ) *
## 21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )
## 42) Price < 122.5 51 70.680 Yes ( 0.49020 0.50980 )
## 84) ShelveLoc: Bad 11 6.702 No ( 0.90909 0.09091 ) *
## 85) ShelveLoc: Medium 40 52.930 Yes ( 0.37500 0.62500 )
## 170) Price < 109.5 16 7.481 Yes ( 0.06250 0.93750 ) *
## 171) Price > 109.5 24 32.600 No ( 0.58333 0.41667 )
## 342) Age < 49.5 13 16.050 Yes ( 0.30769 0.69231 ) *
## 343) Age > 49.5 11 6.702 No ( 0.90909 0.09091 ) *
## 43) Price > 122.5 77 55.540 No ( 0.88312 0.11688 )
## 86) CompPrice < 147.5 58 17.400 No ( 0.96552 0.03448 ) *
## 87) CompPrice > 147.5 19 25.010 No ( 0.63158 0.36842 )
## 174) Price < 147 12 16.300 Yes ( 0.41667 0.58333 )
## 348) CompPrice < 152.5 7 5.742 Yes ( 0.14286 0.85714 ) *
## 349) CompPrice > 152.5 5 5.004 No ( 0.80000 0.20000 ) *
## 175) Price > 147 7 0.000 No ( 1.00000 0.00000 ) *
## 11) Advertising > 13.5 45 61.830 Yes ( 0.44444 0.55556 )
## 22) Age < 54.5 25 25.020 Yes ( 0.20000 0.80000 )
## 44) CompPrice < 130.5 14 18.250 Yes ( 0.35714 0.64286 )
## 88) Income < 100 9 12.370 No ( 0.55556 0.44444 ) *
## 89) Income > 100 5 0.000 Yes ( 0.00000 1.00000 ) *
## 45) CompPrice > 130.5 11 0.000 Yes ( 0.00000 1.00000 ) *
## 23) Age > 54.5 20 22.490 No ( 0.75000 0.25000 )
## 46) CompPrice < 122.5 10 0.000 No ( 1.00000 0.00000 ) *
## 47) CompPrice > 122.5 10 13.860 No ( 0.50000 0.50000 )
## 94) Price < 125 5 0.000 Yes ( 0.00000 1.00000 ) *
## 95) Price > 125 5 0.000 No ( 1.00000 0.00000 ) *
## 3) ShelveLoc: Good 85 90.330 Yes ( 0.22353 0.77647 )
## 6) Price < 135 68 49.260 Yes ( 0.11765 0.88235 )
## 12) US: No 17 22.070 Yes ( 0.35294 0.64706 )
## 24) Price < 109 8 0.000 Yes ( 0.00000 1.00000 ) *
## 25) Price > 109 9 11.460 No ( 0.66667 0.33333 ) *
## 13) US: Yes 51 16.880 Yes ( 0.03922 0.96078 ) *
## 7) Price > 135 17 22.070 No ( 0.64706 0.35294 )
## 14) Income < 46 6 0.000 No ( 1.00000 0.00000 ) *
## 15) Income > 46 11 15.160 Yes ( 0.45455 0.54545 ) *
Each split shows the split criterion, the number of observations in that branch, the deviance, the overall prediction for the branch, and the fraction of observations in the branch that take on the values No and Yes. Branches leading to terminal nodes are indicated with an asterisk.
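The same information is available programmatically in the tree object's frame component; a quick look (a sketch relying on the documented structure of tree objects):
# One row per node: splitting variable ('<leaf>' for terminal nodes), node size,
# deviance, predicted class and class probabilities
head(carseat_tree$frame)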
Let’s split the data into a training and test set to gauge the predictive power of the tree. We also create a function to format our output table.
print_table <- function(x) {
x %>% kable(align = 'l') %>% kable_styling()
}
set.seed(1)
carseat_smpl <- carseats %>% resample_partition(c(train = .5, test = .5))
carseat_tree <- tree(High ~ .-Sales, carseat_smpl$train)
carseat_smpl$test %>%
as_tibble() %>%
mutate(High_prime = predict(carseat_tree, newdata = ., type = 'class')) %>%
summarise('Error Rate' = mean(High != High_prime) * 100) %>%
print_table()
Error Rate |
---|
29.35323 |
We now test whether pruning the tree enhances its predictive capabilities. The cv.tree()
function performs cross-validation. We pass the argument FUN = prune.misclass
so that the classification error rate, rather than the default (deviance), guides the cross-validation and pruning process.
carseat_cv <- cv.tree(carseat_tree, FUN = prune.misclass)
carseat_cv
## $size
## [1] 19 15 14 12 9 6 4 2 1
##
## $dev
## [1] 50 50 49 54 51 53 52 56 81
##
## $k
## [1] -Inf 0.000000 1.000000 1.500000 2.000000 2.333333 2.500000
## [8] 8.500000 24.000000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
The size
attribute shows the number of terminal nodes of each tree considered, dev
is the error rate (in this case cross-validation error), and k
is the cost-complexity parameter \(\alpha\).
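Rather than reading it off the printout, the best size can be extracted programmatically; a small sketch using the carseat_cv object above:
# Number of terminal nodes with the lowest cross-validation error
carseat_cv$size[which.min(carseat_cv$dev)]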
The tree with 14 terminal nodes has the lowest CV error. We plot size
against dev:
tibble(
size = carseat_cv$size,
cv_error = carseat_cv$dev
) %>%
ggplot(aes(size, cv_error)) +
geom_point() +
geom_line() +
labs(x = 'Terminal Nodes', y = 'Cross-Validation Error Rate')
We apply the prune.misclass()
function to prune the tree; here we choose a 9-node tree (best = 9), whose cross-validation error is close to the minimum.
carseat_prune <- prune.misclass(carseat_tree, best = 9)
plot(carseat_prune)
text(carseat_prune, cex = .7)
Let’s see how this performs on the test data.
carseat_smpl$test %>%
as_tibble() %>%
mutate(High_prime = predict(carseat_prune, newdata = ., type = 'class')) %>%
summarise('Error Rate' = mean(High != High_prime) * 100) %>%
print_table()
Error Rate |
---|
27.8607 |
The error rate has decreased from 29.35% to 27.86%.
We fit a regression tree on the Boston
data set. We first fit the tree to the training data.
set.seed(20)
Boston %>%
as_tibble() %>%
resample_partition(c(train = .5, test = .5)) -> boston_smpl
boston_tree <- tree(medv ~ ., data = boston_smpl$train)
summary(boston_tree)
##
## Regression tree:
## tree(formula = medv ~ ., data = boston_smpl$train)
## Variables actually used in tree construction:
## [1] "rm" "lstat" "rad" "ptratio"
## Number of terminal nodes: 8
## Residual mean deviance: 15.12 = 3689 / 244
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -23.04000 -1.69400 0.05849 0.00000 2.23200 20.78000
plot(boston_tree)
text(boston_tree, cex = .7, pretty = 0)
Note that only 4 of the variables (rm, lstat, rad and ptratio) have been used in constructing the tree.
We use cv.tree()
to see if pruning the tree will improve performance.
boston_cv <- cv.tree(boston_tree)
tibble(size = as.integer(boston_cv$size), deviance = boston_cv$dev) %>%
ggplot(aes(size, deviance)) +
geom_point() +
geom_line()
We prune to 5 terminal nodes, which is where the knee of the graph appears to be.
boston_prune <- prune.tree(boston_tree, best = 5)
plot(boston_prune)
text(boston_prune, cex = .7, pretty = 0)
boston_smpl$test %>%
as_tibble() %>%
mutate(medv_prime = predict(boston_prune, newdata = .)) %>%
summarise('MSE' = mean((medv - medv_prime)^2)) %>%
print_table()
MSE |
---|
28.95809 |
The test MSE is around 29, so the root mean squared error is \(\sqrt{28.96} \approx 5.4\). This model therefore leads to test predictions that are within around $5,400 of the true median house value, medv.
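Equivalently, the test-set calculation can be expressed directly as an RMSE; a small sketch reusing boston_prune and boston_smpl from above:
boston_smpl$test %>%
as_tibble() %>%
mutate(medv_prime = predict(boston_prune, newdata = .)) %>%
summarise(RMSE = sqrt(mean((medv - medv_prime)^2))) %>%
print_table()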
The randomForest
library is used to perform random forests and bagging. We recall that bagging is simply a random forest with \(m = p\), so the randomForest()
function can be used for both scenarios.
boston_bag <- randomForest(medv ~ ., data = boston_smpl$train, mtry = 13, importance = T)
boston_bag
##
## Call:
## randomForest(formula = medv ~ ., data = boston_smpl$train, mtry = 13, importance = T)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 13
##
## Mean of squared residuals: 15.13463
## % Var explained: 81.2
The argument mtry = 13
indicates that all 13 predictors should be considered at each split of the tree; in other words, bagging is performed.
Let’s take a look at how it performs:
boston_smpl$test %>%
as_tibble() %>%
mutate(medv_prime = predict(boston_bag, newdata = .)) %>%
summarise(MSE = mean((medv - medv_prime)^2)) %>%
print_table()
MSE |
---|
15.46192 |
The test MSE is roughly half that of the pruned tree.
The number of trees grown can be changed with the ntree
argument.
boston_bag <- randomForest(medv ~ ., data = boston_smpl$train, mtry = 13, ntree = 25)
boston_smpl$test %>%
as_tibble() %>%
mutate(medv_prime = predict(boston_bag, newdata = .)) %>%
summarise(MSE = mean((medv - medv_prime)^2)) %>%
print_table()
MSE |
---|
16.14903 |
Growing a random forest is exactly the same, except a smaller value of mtry
is used. By default randomForest()
uses:

* \(p/3\) variables when building a random forest of regression trees
* \(\sqrt{p}\) variables when building a random forest of classification trees
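The default actually used can be read off a fitted forest; a brief sketch (with 13 predictors the regression default should be \(\lfloor 13/3 \rfloor = 4\)):
# Fit without specifying mtry and inspect the value randomForest() chose
boston_default <- randomForest(medv ~ ., data = boston_smpl$train)
boston_default$mtry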
boston_frst <- randomForest(medv ~ ., data = boston_smpl$train, mtry = 6, importance = T)
boston_smpl$test %>%
as_tibble() %>%
mutate(medv_prime = predict(boston_frst, newdata = .)) %>%
summarise(MSE = mean((medv - medv_prime)^2)) %>%
print_table()
MSE |
---|
15.92335 |
The random forest gives a slightly higher test MSE than bagging here.
The importance()
function shows us the importance of each variable:
importance(boston_frst)
## %IncMSE IncNodePurity
## crim 10.4666047 543.58661
## zn 2.8049562 62.72601
## indus 8.4061104 864.37024
## chas 0.4863885 40.69838
## nox 12.9571009 623.95515
## rm 33.8397112 7298.94203
## age 14.2347805 670.81392
## dis 8.7153657 1086.90617
## rad 4.5419058 122.74522
## tax 8.7075252 337.11199
## ptratio 15.6394880 1338.60228
## black 8.2978958 360.42838
## lstat 26.9012662 6468.64836
The two measures of importance are:

* %IncMSE - the mean decrease of accuracy in predictions on the out-of-bag samples when a given variable is excluded from the model.
* IncNodePurity - a measure of the total decrease in node impurity that results from splits over that variable, averaged over all trees.
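For a quick ranking, the matrix returned by importance() can simply be sorted; a small sketch ordering the predictors by %IncMSE:
imp <- importance(boston_frst)
# Largest decrease in out-of-bag accuracy first
imp[order(imp[, '%IncMSE'], decreasing = TRUE), ]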
The varImpPlot()
function can be used to plot these importance measures:
varImpPlot(boston_frst)
The results show that across all trees considered in the random forest, the wealth of the community (lstat) and house size (rm) are the two most important variables.
We use the gbm
package and the gbm()
function to fit boosted regression trees to the Boston
data set. We use the distribution = 'gaussian'
argument as this is a regression problem. If it were a binary classification problem we would use distribution = 'bernoulli'
. The argument n.trees = 5000
indicates we want 5000 trees, and interaction.depth = 4
limits the depth of each tree.
boston_boost <- gbm(
medv ~ .,
data = boston_smpl$train,
distribution = 'gaussian',
n.trees = 5000,
interaction.depth = 4
)
The summary()
function produces a relative influence plot along with the relative influence statistics:
summary(boston_boost)
## var rel.inf
## rm rm 37.51782445
## lstat lstat 28.08569531
## dis dis 8.12962345
## nox nox 5.56629537
## crim crim 5.42904772
## ptratio ptratio 4.04181395
## black black 3.96800915
## age age 3.55343182
## tax tax 1.27903671
## chas chas 1.20998688
## rad rad 0.58315384
## indus indus 0.55746738
## zn zn 0.07861397
Again we see that lstat
and rm
are the most important variables.
We can produce partial dependence plots for these two variables. This illustrates the marginal effect of these two variables on the response after integrating out the other variables.
par(mfrow = c(1,2))
plot(boston_boost, i ="rm")
plot(boston_boost, i ="lstat")
We use the boosted model to predict medv
.
boston_smpl$test %>%
as_tibble() %>%
mutate(medv_prime = predict(boston_boost, n.trees = 5000, newdata = .)) %>%
summarise('MSE' = mean( (medv - medv_prime)^2 )) %>%
print_table()
MSE |
---|
13.83959 |
The test MSE is lower than that of both bagging and the random forest. We can boost with a different shrinkage parameter \(\lambda\). Let's try \(\lambda = 0.2\).
boston_boost <- gbm(
medv ~ .,
data = boston_smpl$train,
distribution = 'gaussian',
n.trees = 5000,
interaction.depth = 4,
shrinkage = 0.2,
verbose = F
)
boston_smpl$test %>%
as_tibble() %>%
mutate(medv_prime = predict(boston_boost, n.trees = 5000, newdata = .)) %>%
summarise('MSE' = mean( (medv - medv_prime)^2 )) %>%
print_table()
MSE |
---|
15.79923 |
In this instance it raises our test MSE; however, we could use cross-validation to find the best shrinkage parameter.
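As a rough illustration of that idea, here is a minimal sketch of a grid search using gbm's built-in cv.folds argument; the grid of \(\lambda\) values and the 5-fold choice are arbitrary, and with 5000 trees per fit this is slow:
lambdas <- c(0.001, 0.01, 0.1, 0.2)
cv_error <- sapply(lambdas, function(lambda) {
  fit <- gbm(
    medv ~ .,
    data = as_tibble(boston_smpl$train),
    distribution = 'gaussian',
    n.trees = 5000,
    interaction.depth = 4,
    shrinkage = lambda,
    cv.folds = 5,
    verbose = FALSE
  )
  # cv.error holds the cross-validated squared-error loss per iteration
  min(fit$cv.error)
})
tibble(shrinkage = lambdas, cv_mse = cv_error)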