8.3.1 - Fitting Classification Trees

The tree library is used to construct classification and regression trees. We use classification trees to analyse the Carseats data, recoding the continuous variable Sales into a binary factor High, which takes the value 'Yes' when Sales exceeds 8 and 'No' otherwise.

Carseats %>%
    as_tibble() %>%
    mutate(High = as.factor( ifelse(Sales <= 8, 'No', 'Yes') )) -> carseats

We now use tree() to fit a classification tree.

carseat_tree <- tree(High ~ . -Sales, data = carseats)

summary(carseat_tree)
## 
## Classification tree:
## tree(formula = High ~ . - Sales, data = carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "Income"      "CompPrice"   "Population" 
## [6] "Advertising" "Age"         "US"         
## Number of terminal nodes:  27 
## Residual mean deviance:  0.4575 = 170.7 / 373 
## Misclassification error rate: 0.09 = 36 / 400

We see the training error rate is 9%. For classification trees, the deviance reported by summary() is given by:

\[ -2 \sum_m \sum_k n_{mk} \log \hat{p}_{mk} \]

where \(n_{mk}\) is the number of observations in the \(m\)th terminal node that belong to the \(k\)th class. The residual mean deviance is the deviance divided by \(n - \mid T_0 \mid\), where \(\mid T_0 \mid\) is the number of terminal nodes.
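
As a sanity check, this deviance can be recomputed by hand; a minimal sketch that assumes the carseat_tree fit above and uses the per-node counts and class proportions stored in the tree's frame component:

# Recompute -2 * sum_m sum_k n_mk * log(p_mk) over the terminal nodes
leaves <- carseat_tree$frame[carseat_tree$frame$var == '<leaf>', ]
n_mk   <- leaves$n * leaves$yprob      # class counts within each terminal node
log_p  <- log(leaves$yprob)
log_p[leaves$yprob == 0] <- 0          # treat 0 * log(0) as 0
-2 * sum(n_mk * log_p)                 # compare with the deviance numerator (170.7) in summary()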

We can plot the tree:

plot(carseat_tree)
text(carseat_tree, pretty = 0, cex = .7)

The most important predictor appears to be shelving location, since the first split separates Good locations from Bad and Medium ones.

The object’s print function outputs the branches:

carseat_tree
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 400 541.500 No ( 0.59000 0.41000 )  
##     2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )  
##       4) Price < 92.5 46  56.530 Yes ( 0.30435 0.69565 )  
##         8) Income < 57 10  12.220 No ( 0.70000 0.30000 )  
##          16) CompPrice < 110.5 5   0.000 No ( 1.00000 0.00000 ) *
##          17) CompPrice > 110.5 5   6.730 Yes ( 0.40000 0.60000 ) *
##         9) Income > 57 36  35.470 Yes ( 0.19444 0.80556 )  
##          18) Population < 207.5 16  21.170 Yes ( 0.37500 0.62500 ) *
##          19) Population > 207.5 20   7.941 Yes ( 0.05000 0.95000 ) *
##       5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )  
##        10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )  
##          20) CompPrice < 124.5 96  44.890 No ( 0.93750 0.06250 )  
##            40) Price < 106.5 38  33.150 No ( 0.84211 0.15789 )  
##              80) Population < 177 12  16.300 No ( 0.58333 0.41667 )  
##               160) Income < 60.5 6   0.000 No ( 1.00000 0.00000 ) *
##               161) Income > 60.5 6   5.407 Yes ( 0.16667 0.83333 ) *
##              81) Population > 177 26   8.477 No ( 0.96154 0.03846 ) *
##            41) Price > 106.5 58   0.000 No ( 1.00000 0.00000 ) *
##          21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )  
##            42) Price < 122.5 51  70.680 Yes ( 0.49020 0.50980 )  
##              84) ShelveLoc: Bad 11   6.702 No ( 0.90909 0.09091 ) *
##              85) ShelveLoc: Medium 40  52.930 Yes ( 0.37500 0.62500 )  
##               170) Price < 109.5 16   7.481 Yes ( 0.06250 0.93750 ) *
##               171) Price > 109.5 24  32.600 No ( 0.58333 0.41667 )  
##                 342) Age < 49.5 13  16.050 Yes ( 0.30769 0.69231 ) *
##                 343) Age > 49.5 11   6.702 No ( 0.90909 0.09091 ) *
##            43) Price > 122.5 77  55.540 No ( 0.88312 0.11688 )  
##              86) CompPrice < 147.5 58  17.400 No ( 0.96552 0.03448 ) *
##              87) CompPrice > 147.5 19  25.010 No ( 0.63158 0.36842 )  
##               174) Price < 147 12  16.300 Yes ( 0.41667 0.58333 )  
##                 348) CompPrice < 152.5 7   5.742 Yes ( 0.14286 0.85714 ) *
##                 349) CompPrice > 152.5 5   5.004 No ( 0.80000 0.20000 ) *
##               175) Price > 147 7   0.000 No ( 1.00000 0.00000 ) *
##        11) Advertising > 13.5 45  61.830 Yes ( 0.44444 0.55556 )  
##          22) Age < 54.5 25  25.020 Yes ( 0.20000 0.80000 )  
##            44) CompPrice < 130.5 14  18.250 Yes ( 0.35714 0.64286 )  
##              88) Income < 100 9  12.370 No ( 0.55556 0.44444 ) *
##              89) Income > 100 5   0.000 Yes ( 0.00000 1.00000 ) *
##            45) CompPrice > 130.5 11   0.000 Yes ( 0.00000 1.00000 ) *
##          23) Age > 54.5 20  22.490 No ( 0.75000 0.25000 )  
##            46) CompPrice < 122.5 10   0.000 No ( 1.00000 0.00000 ) *
##            47) CompPrice > 122.5 10  13.860 No ( 0.50000 0.50000 )  
##              94) Price < 125 5   0.000 Yes ( 0.00000 1.00000 ) *
##              95) Price > 125 5   0.000 No ( 1.00000 0.00000 ) *
##     3) ShelveLoc: Good 85  90.330 Yes ( 0.22353 0.77647 )  
##       6) Price < 135 68  49.260 Yes ( 0.11765 0.88235 )  
##        12) US: No 17  22.070 Yes ( 0.35294 0.64706 )  
##          24) Price < 109 8   0.000 Yes ( 0.00000 1.00000 ) *
##          25) Price > 109 9  11.460 No ( 0.66667 0.33333 ) *
##        13) US: Yes 51  16.880 Yes ( 0.03922 0.96078 ) *
##       7) Price > 135 17  22.070 No ( 0.64706 0.35294 )  
##        14) Income < 46 6   0.000 No ( 1.00000 0.00000 ) *
##        15) Income > 46 11  15.160 Yes ( 0.45455 0.54545 ) *

Each split shows the split criterion, the number of observations in the branch, the deviance, the overall prediction for the branch, and the fraction of observations in the branch that take on values of Yes and No. Terminal nodes are indicated with an asterisk.

Let’s split the data into a training and test set to gauge the predictive power of the tree. We also create a function to format our output table.

# Helper to format output tables (kable() is from knitr, kable_styling() from kableExtra)
print_table <- function(x) {
    x %>% kable(align = 'l') %>% kable_styling()
}

set.seed(1)
carseat_smpl <- carseats %>% resample_partition(c(train = .5, test = .5))

carseat_tree <- tree(High ~ .-Sales, carseat_smpl$train)

carseat_smpl$test %>%
    as_tibble() %>%
    mutate(High_prime = predict(carseat_tree, newdata = ., type = 'class')) %>%
    summarise('Error Rate' = mean(High != High_prime) * 100) %>%
    print_table()
Error Rate
29.35323

We now test whether pruning the tree enhances its predictive capabilities. The function cv.tree() performs cross-validation. We pass FUN = prune.misclass so that the classification error rate guides the cross-validation and pruning process, rather than the default, which is deviance.

carseat_cv <- cv.tree(carseat_tree, FUN = prune.misclass)
carseat_cv
## $size
## [1] 19 15 14 12  9  6  4  2  1
## 
## $dev
## [1] 50 50 49 54 51 53 52 56 81
## 
## $k
## [1]      -Inf  0.000000  1.000000  1.500000  2.000000  2.333333  2.500000
## [8]  8.500000 24.000000
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

The size attribute shows the number of terminal nodes of each tree considered, dev is the error rate (in this case cross-validation error), and k is the cost-complexity parameter \(\alpha\).
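
The best size can also be pulled out programmatically; a one-line sketch using the carseat_cv object above:

# Number of terminal nodes with the lowest cross-validation error
carseat_cv$size[which.min(carseat_cv$dev)]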

The tree with 14 terminal nodes has the lowest CV error. We plot size against dev:

tibble(
    size = carseat_cv$size,
    cv_error = carseat_cv$dev
) %>%
ggplot(aes(size, cv_error)) +
    geom_point() +
    geom_line() +
    labs(x = 'Terminal Nodes', y = 'Cross-Validation Error Rate')

We apply the prune.misclass() function to prune the tree. Here we prune to nine terminal nodes, a much smaller tree whose CV error is close to the minimum.

carseat_prune <- prune.misclass(carseat_tree, best = 9)
plot(carseat_prune)
text(carseat_prune, cex = .7)

Let’s see how this performs on the test data.

carseat_smpl$test %>%
    as_tibble() %>%
    mutate(High_prime = predict(carseat_prune, newdata = ., type = 'class')) %>%
    summarise('Error Rate' = mean(High != High_prime) * 100) %>%
    print_table()
Error Rate
27.8607

The error rate has decreased from 29.35% to 27.86%.

8.3.2 - Fitting Regression Trees

We now fit a regression tree to the Boston data set, splitting the data in half and fitting the tree to the training portion.

set.seed(20)
Boston %>%
    as_tibble() %>%
    resample_partition(c(train = .5, test = .5)) -> boston_smpl

boston_tree <- tree(medv ~ ., data = boston_smpl$train)
summary(boston_tree)
## 
## Regression tree:
## tree(formula = medv ~ ., data = boston_smpl$train)
## Variables actually used in tree construction:
## [1] "rm"      "lstat"   "rad"     "ptratio"
## Number of terminal nodes:  8 
## Residual mean deviance:  15.12 = 3689 / 244 
## Distribution of residuals:
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## -23.04000  -1.69400   0.05849   0.00000   2.23200  20.78000
plot(boston_tree)
text(boston_tree, cex = .7, pretty = 0)

Note that only 4 of the variables have been used in constructing the tree.
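
If needed, the split variables can be extracted programmatically rather than read off the summary; a small sketch using the frame component of boston_tree:

# Variables actually used in splits (every non-leaf row of the tree's frame)
setdiff(as.character(unique(boston_tree$frame$var)), '<leaf>')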

We use cv.tree() to see if pruning the tree will improve performance.

boston_cv <- cv.tree(boston_tree)

tibble(size = as.integer(boston_cv$size), deviance = boston_cv$dev) %>%
    ggplot(aes(size, deviance)) +
    geom_point() +
    geom_line()

We prune to five terminal nodes, which is where the knee of the curve appears to be.

boston_prune <- prune.tree(boston_tree, best = 5)

plot(boston_prune)
text(boston_prune, cex = .7, pretty = 0)

boston_smpl$test %>%
    as_tibble() %>%
    mutate(medv_prime = predict(boston_prune, newdata = .)) %>%
    summarise('MSE' = mean((medv - medv_prime)^2)) %>%
    print_table()
MSE
28.95809

The test MSE is around 29, so the root mean squared error is \(\sqrt{29} \approx 5.4\); this model therefore leads to test predictions that are, on average, within roughly $5,400 of the true median home value (medv).
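
The RMSE can also be computed directly on the test set; a minimal sketch reusing the pruned tree, the test split, and the print_table() helper from above:

boston_smpl$test %>%
    as_tibble() %>%
    mutate(medv_prime = predict(boston_prune, newdata = .)) %>%
    summarise(RMSE = sqrt(mean((medv - medv_prime)^2))) %>%
    print_table()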

8.3.3 - Bagging and Random Forests

The randomForest library is used to perform random forests and bagging. Recall that bagging is simply a random forest with \(m = p\), so the randomForest() function can be used for both scenarios.

boston_bag <- randomForest(medv ~ ., data = boston_smpl$train, mtry = 13, importance = T)
boston_bag
## 
## Call:
##  randomForest(formula = medv ~ ., data = boston_smpl$train, mtry = 13,      importance = T) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 13
## 
##           Mean of squared residuals: 15.13463
##                     % Var explained: 81.2

The argument mtry = 13 indicates that all 13 predictors should be considered for each split of the tree, i.e. that bagging should be done.
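
Rather than hard-coding 13, the number of predictors can be counted from the training data. A small sketch (boston_bag2 is just an illustrative name, not an object used later):

# Count the predictors: every column in the training data except the response medv
p <- ncol(as_tibble(boston_smpl$train)) - 1
p
# Equivalent bagging call without the hard-coded value
boston_bag2 <- randomForest(medv ~ ., data = boston_smpl$train, mtry = p, importance = TRUE)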

Let’s take a look at how it performs:

boston_smpl$test %>%
    as_tibble() %>%
    mutate(medv_prime = predict(boston_bag,  newdata = .)) %>%
    summarise(MSE = mean((medv - medv_prime)^2)) %>%
    print_table()
MSE
15.46192

The test MSE is roughly half that of the pruned single tree.

The number of trees grown can be changed with the ntree argument.

boston_bag <- randomForest(medv ~ ., data = boston_smpl$train, mtry = 13, ntree = 25)
boston_smpl$test %>%
    as_tibble() %>%
    mutate(medv_prime = predict(boston_bag,  newdata = .)) %>%
    summarise(MSE = mean((medv - medv_prime)^2)) %>%
    print_table()
MSE
16.14903

Growing a random forest proceeds in exactly the same way, except that a smaller value of mtry is used. By default, randomForest() uses:

* \(p/3\) variables when building a random forest of regression trees;
* \(\sqrt{p}\) variables when building a random forest of classification trees.
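
As a quick numeric check of those defaults for the Boston data, where p = 13 predictors (note that the fit below uses mtry = 6 rather than the regression default):

# Default mtry values for p = 13 predictors
p <- 13
floor(p / 3)    # regression forests:     4
floor(sqrt(p))  # classification forests: 3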

boston_frst <- randomForest(medv ~ ., data = boston_smpl$train, mtry = 6, importance = T)
boston_smpl$test %>%
    as_tibble() %>%
    mutate(medv_prime = predict(boston_frst,  newdata = .)) %>%
    summarise(MSE = mean((medv - medv_prime)^2)) %>%
    print_table()
MSE
15.92335

In this case the random forest gives a slightly higher test MSE than bagging.

The importance() function shows us the importance of each variable:

importance(boston_frst)
##            %IncMSE IncNodePurity
## crim    10.4666047     543.58661
## zn       2.8049562      62.72601
## indus    8.4061104     864.37024
## chas     0.4863885      40.69838
## nox     12.9571009     623.95515
## rm      33.8397112    7298.94203
## age     14.2347805     670.81392
## dis      8.7153657    1086.90617
## rad      4.5419058     122.74522
## tax      8.7075252     337.11199
## ptratio 15.6394880    1338.60228
## black    8.2978958     360.42838
## lstat   26.9012662    6468.64836

The two measures of importance are:

* %IncMSE - the mean decrease of accuracy in predictions on the out-of-bag samples when a given variable is excluded from the model.
* IncNodePurity - a measure of the total decrease in node impurity that results from splits over that variable, averaged over all trees.
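
To rank the predictors rather than eyeball the matrix, the importance output can be sorted; a small sketch assuming the tidyverse is loaded as in the earlier chunks:

# Rank variables by %IncMSE, largest first
importance(boston_frst) %>%
    as.data.frame() %>%
    rownames_to_column('variable') %>%
    arrange(desc(`%IncMSE`)) %>%
    head()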

The varImpPlot() function can be used to plot these importance measures:

varImpPlot(boston_frst)

The results show that across all trees considered in the random forest, wealth of the community (lstat) and house size (rm) are the two most important variables.

8.3.4 - Boosting

We use the gbm package and the gbm() function to fit boosted regression trees to the Boston data set. We use the distribution = 'gaussian' argument as this is a regression problem. If it were a binary classification problem we would use distribution = 'bernoulli'. The argument n.trees = 5000 indicates we want 5000 trees, and interaction.depth = 4 limits the depth of each tree.

boston_boost <- gbm(
    medv ~ .,
    data = boston_smpl$train,
    distribution = 'gaussian',
    n.trees = 5000,
    interaction.depth = 4
)

The summary() function produces a relative influence plot and statistics.

summary(boston_boost)

##             var     rel.inf
## rm           rm 37.51782445
## lstat     lstat 28.08569531
## dis         dis  8.12962345
## nox         nox  5.56629537
## crim       crim  5.42904772
## ptratio ptratio  4.04181395
## black     black  3.96800915
## age         age  3.55343182
## tax         tax  1.27903671
## chas       chas  1.20998688
## rad         rad  0.58315384
## indus     indus  0.55746738
## zn           zn  0.07861397

Again we see that lstat and rm are the most important variables.

We can produce partial dependence plots for these two variables. This illustrates the marginal effect of these two variables on the response after integrating out the other variables.

par(mfrow = c(1,2))
plot(boston_boost, i ="rm")

plot(boston_boost, i ="lstat")

We use the boosted model to predict medv.

boston_smpl$test %>%
    as_tibble() %>%
    mutate(medv_prime = predict(boston_boost, n.trees = 5000, newdata = .)) %>%
    summarise('MSE' = mean( (medv - medv_prime)^2 )) %>%
    print_table()
MSE
13.83959

The test MSE is lower than that obtained with bagging and the random forest. We can boost with a different value of the shrinkage parameter \(\lambda\). Let's try \(\lambda = 0.2\).

boston_boost <- gbm(
    medv ~ .,
    data = boston_smpl$train,
    distribution = 'gaussian',
    n.trees = 5000,
    interaction.depth = 4,
    shrinkage = 0.2,
    verbose = F
)

boston_smpl$test %>%
    as_tibble() %>%
    mutate(medv_prime = predict(boston_boost, n.trees = 5000, newdata = .)) %>%
    summarise('MSE' = mean( (medv - medv_prime)^2 )) %>%
    print_table()
MSE
15.79923

In this instance it raises our test MSE; however, we could use cross-validation to find the best shrinkage value.
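
A rough sketch of what that could look like, using gbm's cv.folds argument and the cv.error component of the fitted object; the shrinkage grid below is arbitrary and the object names are purely illustrative:

# Compare a grid of shrinkage values by 5-fold cross-validated error
shrinkages <- c(0.001, 0.01, 0.1, 0.2)
cv_errors <- sapply(shrinkages, function(lambda) {
    fit <- gbm(
        medv ~ .,
        data = as_tibble(boston_smpl$train),
        distribution = 'gaussian',
        n.trees = 5000,
        interaction.depth = 4,
        shrinkage = lambda,
        cv.folds = 5,
        verbose = FALSE
    )
    min(fit$cv.error)   # best CV error over the tree sequence for this shrinkage
})
tibble(shrinkage = shrinkages, cv_error = cv_errors)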