Decision Trees
The basics of a decision tree
Tree structure
Predictions from a decision tree
Loss function
Growing a decision tree model
Pruning a decision tree model
The rpart() function
Review of Kaggle notebooks for building Decision Tree models
    require(here)
    readability_sub <- read.csv(here('data/readability_sub.csv'), header = TRUE)
    readability_sub[, c('V220', 'V166', 'target')]
           V220     V166   target
    1  -0.13908  0.19028 -2.06282
    2   0.21764  0.07101  0.58259
    3   0.05812  0.03993 -1.65313
    4   0.02526  0.18846 -0.87391
    5   0.22431  0.06201 -1.74049
    6  -0.07795  0.10754 -3.63994
    7   0.43401  0.12202 -0.62284
    8  -0.24365  0.02455 -0.34427
    9   0.15894  0.10422 -1.12299
    10  0.14496  0.02340 -0.99857
    11  0.34223  0.22065 -0.87657
    12  0.25219  0.10865 -0.03305
    13  0.03533  0.07549 -0.49530
    14  0.36411  0.18676  0.12454
    15  0.29989  0.11618  0.09678
    16  0.19837  0.08273  0.38422
    17  0.07807  0.10235 -0.58143
    18  0.07936  0.11619 -0.34325
    19  0.57001 -0.02385 -0.39054
    20  0.34523  0.09300 -0.67548
Let's imagine a simple tree model to predict readability scores from Feature 220.
This model splits the sample into two pieces using a split point of 0.2 for the predictor variable (Feature 220).
There are 11 observations with Feature 220 less than 0.2 and nine with Feature 220 larger than 0.2. The top of the tree model is called the root node, and the R1 and R2 in this model are called terminal nodes.
This model has two terminal nodes.
There is a number assigned to each terminal node. These numbers are the average values for the target outcome (readability score) for those observations in that specific node. It can be symbolically shown as
$$\bar{Y}_t = \frac{1}{n_t} \sum_{i \in R_t} Y_i,$$
where $n_t$ is the number of observations in a terminal node, and $R_t$ represents the set of observations in the $t$th node.
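As a minimal sketch of this formula, the node averages for the one-split tree can be computed directly from the toy data. The values are copied from the table above; labeling the node with Feature 220 below 0.2 as R1 is an assumption made for illustration.

```r
# Feature 220 and target values copied from the toy dataset (N = 20)
V220 <- c(-0.13908, 0.21764, 0.05812, 0.02526, 0.22431, -0.07795, 0.43401,
          -0.24365, 0.15894, 0.14496, 0.34223, 0.25219, 0.03533, 0.36411,
          0.29989, 0.19837, 0.07807, 0.07936, 0.57001, 0.34523)
target <- c(-2.06282, 0.58259, -1.65313, -0.87391, -1.74049, -3.63994,
            -0.62284, -0.34427, -1.12299, -0.99857, -0.87657, -0.03305,
            -0.49530, 0.12454, 0.09678, 0.38422, -0.58143, -0.34325,
            -0.39054, -0.67548)

# Assign each observation to a terminal node using the split point 0.2
# (assumption: R1 holds the observations with Feature 220 < 0.2)
node <- ifelse(V220 < 0.2, "R1", "R2")

# n_t: number of observations in each terminal node
n_t <- table(node)

# Ybar_t: average target outcome within each terminal node
node_means <- tapply(target, node, mean)

print(n_t)
print(round(node_means, 3))
```

The two averages are the numbers that would be attached to the terminal nodes R1 and R2.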
There is also a concept of depth of a tree.
The root node is counted as depth 0, and each split increases the depth of the tree by one. In this case, we can say that this tree model has a depth of one.
We can increase the complexity of our tree model by splitting the first node into two more nodes using a split point of 0.
Now, our model has a depth of two and a total of three terminal nodes.
Each terminal node is assigned a score by computing the average outcome for those observations in that node.
A tree model can also split on more than one variable.
For instance, the model below first splits the observations based on Feature 220, then based on Feature 166, again yielding a tree model with three terminal nodes and a depth of two.
This tree model's complexity is the same as the previous one's; the only difference is that we have nodes from two predictors instead of one.
A final example is a more complex tree model with a depth of three and four terminal nodes.
It first splits observations based on whether or not Feature 220 is less than 0.2, then splits observations based on whether or not Feature 220 is less than 0, and finally splits observations based on whether or not Feature 166 is less than 0.1.
Suppose you developed a tree model and decided to use this model to make predictions for new observations.
Let's assume our model is the one below.
How do we use this model to make predictions for new observations?
Suppose that there is a new reading passage:
The value for Feature 220 is -0.5.
The value for Feature 166 is 0.
What is the predicted readability score for this passage based on this tree model?
You can trace a path starting from the root node (top of the tree) and see where this reading passage will end.
Suppose you have another new reading passage:
The value for Feature 220 is 0.1
The value for Feature 166 is 0.
What is the predicted readability score for this passage based on this tree model?
When we fit a tree model, the algorithm decides the best split that minimizes the sum of squared errors. The sum of squared error from a tree model can be shown as
$$SSE = \sum_{t=1}^{T} \sum_{i \in R_t} (Y_i - \hat{Y}_{R_t})^2,$$
where $T$ is the total number of terminal nodes in the tree model, and $\hat{Y}_{R_t}$ is the prediction for the observations in the $t$th node (the average target outcome for those observations in the $t$th node).
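The double sum can be sketched in a few lines of R. The helper name tree_sse and the four made-up outcome values are hypothetical and chosen only so that the arithmetic is easy to check by hand.

```r
# SSE of a tree, given the outcome and the terminal node of each observation
# (tree_sse is a hypothetical helper, not part of any package)
tree_sse <- function(y, node) {
  y_hat <- ave(y, node)   # node mean, repeated for every observation in the node
  sum((y - y_hat)^2)      # squared errors summed over all nodes
}

# Made-up example: two terminal nodes with means 2 and 12
y    <- c(1, 3, 10, 14)
node <- c("R1", "R1", "R2", "R2")

tree_sse(y, node)               # (1 + 1) + (4 + 4) = 10
tree_sse(y, rep("root", 4))     # no split, mean 7: 36 + 16 + 9 + 49 = 110
```

Comparing the two calls shows how a single split can reduce the SSE relative to a model with no split.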
Deciding on a root node and then growing a tree model from that root node can become computationally expensive depending on the size of the dataset and the number of variables.
The decision tree algorithm
searches all variables designated as predictors in the dataset at all possible split points for these variables,
calculates the SSE for all possible splits,
and then finds the split that would reduce the prediction error the most.
The search continues by growing the tree model sequentially until there is no more split left that would give better predictions.
Let's demonstrate the logic of this search process with the toy dataset (N=20) and two predictor variables: Feature 220 and Feature 166.
Before we start our search process, we should come up with a baseline SSE to decide whether any future split will improve our predictions.
We can imagine that a model with no split and using the mean of all 20 observations to predict the target outcome is the simplest baseline model.
SSE from this baseline model is equal to 17.73.
    # average outcome
    mu <- mean(readability_sub$target)
    # SSE for baseline model
    sum((readability_sub$target - mu)^2)
[1] 17.73
The first step in building the tree model is to find the root node.
In this case, we have two candidates for the root node: Feature 220 and Feature 166.
We want to know which predictor should be the root node and what value we should use to split to improve the baseline SSE the most.
The following is what the process would look like:
1. Pick a split point.
2. Divide the 20 observations into two nodes based on the split point.
3. Calculate the average target outcome within each node as the prediction for the observations in that node.
4. Calculate the SSE within each node using the predictions from Step 3, and sum them across the two nodes.
5. Repeat Steps 1 - 4 for every possible split point, and find the best split point with the minimum SSE.
6. Repeat Steps 1 - 5 for every possible predictor.
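The steps above can be sketched as a small exhaustive search in base R. The helper names node_sse and best_split are hypothetical, and using the midpoints between adjacent sorted values as candidate split points is an assumption; the data frame repeats the toy dataset so the block runs on its own.

```r
# Toy dataset (N = 20), values copied from the table shown earlier
readability_sub <- data.frame(
  V220 = c(-0.13908, 0.21764, 0.05812, 0.02526, 0.22431, -0.07795, 0.43401,
           -0.24365, 0.15894, 0.14496, 0.34223, 0.25219, 0.03533, 0.36411,
           0.29989, 0.19837, 0.07807, 0.07936, 0.57001, 0.34523),
  V166 = c(0.19028, 0.07101, 0.03993, 0.18846, 0.06201, 0.10754, 0.12202,
           0.02455, 0.10422, 0.02340, 0.22065, 0.10865, 0.07549, 0.18676,
           0.11618, 0.08273, 0.10235, 0.11619, -0.02385, 0.09300),
  target = c(-2.06282, 0.58259, -1.65313, -0.87391, -1.74049, -3.63994,
             -0.62284, -0.34427, -1.12299, -0.99857, -0.87657, -0.03305,
             -0.49530, 0.12454, 0.09678, 0.38422, -0.58143, -0.34325,
             -0.39054, -0.67548)
)

# SSE of one node when the node mean is used as the prediction (Steps 3-4)
node_sse <- function(y) sum((y - mean(y))^2)

# Steps 1-5 for one predictor: try every candidate split point and keep
# the one with the smallest total SSE across the two resulting nodes
best_split <- function(x, y) {
  v <- sort(unique(x))
  candidates <- (head(v, -1) + tail(v, -1)) / 2   # midpoints (assumption)
  sse <- vapply(candidates, function(s) {
    node_sse(y[x < s]) + node_sse(y[x >= s])
  }, numeric(1))
  data.frame(split = candidates[which.min(sse)], sse = min(sse))
}

# Baseline: no split at all
baseline <- node_sse(readability_sub$target)

# Step 6: repeat the search for every candidate predictor
root <- do.call(rbind, lapply(c("V220", "V166"), function(p) {
  data.frame(predictor = p,
             best_split(readability_sub[[p]], readability_sub$target))
}))

print(round(baseline, 2))
print(root)
```

The predictor with the smallest SSE in the printed table is the candidate for the root node.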
The search process indicates that the best split point for Feature 220 is -0.026.
If we divide the observations into two nodes based on Feature 220 using the split point -0.026, SSE would be equal to 12.20, a substantial improvement over the baseline model.
We can repeat the same process (Steps 1 - 5) for Feature 166.
The search process indicates that the best split point for Feature 166 is 0.189.
If we divide the observations into two nodes based on Feature 166 using the split point 0.189, SSE would be equal to 16.62, also an improvement over the baseline model.
We have two choices for the root node (because we have two predictors in this demonstration):
Feature 220, Best split point = -0.026, SSE = 12.20
Feature 166, Best split point = 0.189, SSE = 16.62
Both options improve our predictions over the baseline (SSE = 17.73), but Feature 220 yields the smaller SSE. Therefore, our final decision to start growing our tree is to pick Feature 220 as our root node and split it at -0.026.
Our tree model starts growing!
Now, we have to decide if we should add another split to either one of these two nodes. There are now four possible split scenarios.
For the first terminal node on the left, we can split the observations based on Feature 220.
For the first terminal node on the left, we can split the observations based on Feature 166.
For the second terminal node on the right, we can split the observations based on Feature 220.
For the second terminal node on the right, we can split the observations based on Feature 166.
For each of these scenarios, we can implement Step 1 - Step 4, identifying the best split point and what SSE that split yields. The code below implements these steps for all four scenarios.
Scenario 1: Split the first terminal node based on Feature 220
Scenario 2: Split the first terminal node based on Feature 166
Scenario 3: Split the second terminal node based on Feature 220
Scenario 4: Split the second terminal node based on Feature 166
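A minimal sketch of the four scenarios is given below, assuming the root split on Feature 220 at -0.026 found above. The helpers node_sse and best_split are hypothetical names, and midpoints between adjacent sorted values are assumed as candidate split points. The total SSE of each scenario adds the SSE of the node that is left unsplit.

```r
# Toy dataset (N = 20), values copied from the table shown earlier
readability_sub <- data.frame(
  V220 = c(-0.13908, 0.21764, 0.05812, 0.02526, 0.22431, -0.07795, 0.43401,
           -0.24365, 0.15894, 0.14496, 0.34223, 0.25219, 0.03533, 0.36411,
           0.29989, 0.19837, 0.07807, 0.07936, 0.57001, 0.34523),
  V166 = c(0.19028, 0.07101, 0.03993, 0.18846, 0.06201, 0.10754, 0.12202,
           0.02455, 0.10422, 0.02340, 0.22065, 0.10865, 0.07549, 0.18676,
           0.11618, 0.08273, 0.10235, 0.11619, -0.02385, 0.09300),
  target = c(-2.06282, 0.58259, -1.65313, -0.87391, -1.74049, -3.63994,
             -0.62284, -0.34427, -1.12299, -0.99857, -0.87657, -0.03305,
             -0.49530, 0.12454, 0.09678, 0.38422, -0.58143, -0.34325,
             -0.39054, -0.67548)
)

node_sse <- function(y) sum((y - mean(y))^2)

best_split <- function(x, y) {
  v <- sort(unique(x))
  candidates <- (head(v, -1) + tail(v, -1)) / 2   # midpoints (assumption)
  sse <- vapply(candidates, function(s) {
    node_sse(y[x < s]) + node_sse(y[x >= s])
  }, numeric(1))
  data.frame(split = candidates[which.min(sse)], sse = min(sse))
}

# The two terminal nodes created by the root split on Feature 220
nodes <- list(
  first  = subset(readability_sub, V220 <  -0.026),
  second = subset(readability_sub, V220 >= -0.026)
)

# For each node and each predictor, find the best split; the total SSE is
# the SSE of the split node plus the SSE of the other, unsplit node
scenarios <- do.call(rbind, lapply(names(nodes), function(n) {
  other <- setdiff(names(nodes), n)
  do.call(rbind, lapply(c("V220", "V166"), function(p) {
    b <- best_split(nodes[[n]][[p]], nodes[[n]]$target)
    data.frame(node = n, predictor = p, split = b$split,
               total_sse = b$sse + node_sse(nodes[[other]]$target))
  }))
}))

print(scenarios)
```

Each row of the printed table corresponds to one of Scenarios 1 to 4, in that order.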
Based on our search, the following splits yielded the smallest SSE in each scenario:
1st terminal node, Split variable: Feature 220, Split Point: -0.19, SSE = 8.01
1st terminal node, Split variable: Feature 166, Split Point: 0.066, SSE = 8.01
2nd terminal node, Split variable: Feature 220, Split Point: 0.178, SSE = 10.94
2nd terminal node, Split variable: Feature 166, Split Point: 0.067, SSE = 9.96
We can continue with either Scenario 1 or Scenario 2 because they yield the same minimum SSE. Assuming we choose Scenario 1, our tree model now looks like this.
The search process continues until a specific criterion is met to stop the algorithm.
There may be several conditions where we may constrain the growth and force the algorithm to stop searching for additional growth in the tree model.
Some of these are listed below:
Minimizing SSE: the algorithm stops when there is no potential split in any of the existing nodes that would reduce the sum of squared errors.
The minimum number of observations to split: the algorithm does not attempt to split a node unless there is a certain number of observations in the node.
Maximum depth: The algorithm stops searching when the tree reaches a certain depth.
Overfitting and underfitting also occur as one develops a tree model.
When the depth of a tree becomes unnecessarily large, there is a risk of overfitting (increased variance in model predictions across samples, less generalizable).
When the depth of a tree is small, there is a risk of underfitting (increased bias in model predictions, underperforming model).
To balance the model variance and model bias, we can add a penalty term to the loss function:
$$SSE_{\alpha} = \sum_{t=1}^{T} \sum_{i \in R_t} (Y_i - \hat{Y}_{R_t})^2 + \alpha T$$
The parameter $\alpha$ in the penalty term is known as the cost complexity parameter.
The product term αT increases as the number of terminal nodes increases in the model, so this term penalizes the increasing complexity of the tree model.
By fine-tuning the value of α through cross-validation, we can find a balance between model variance and bias, as we did for regression models.
The process is called pruning because the terminal nodes from the tree are eliminated in a nested way for increasing levels of α.
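The effect of the penalty can be sketched with the three nested trees from the manual search, whose SSE values were 17.73 (no split), 12.20 (one split), and 8.01 (two splits). The α values below are arbitrary choices for illustration, and preferred_size is a hypothetical helper name.

```r
# Nested trees from the manual search: number of terminal nodes and SSE
n_terminal <- c(1, 2, 3)
sse        <- c(17.73, 12.20, 8.01)

# Penalized loss SSE + alpha * T for every tree; return the tree size
# (number of terminal nodes) that minimizes it
preferred_size <- function(alpha) {
  loss <- sse + alpha * n_terminal
  n_terminal[which.min(loss)]
}

# Arbitrary illustrative values of the cost complexity parameter
for (alpha in c(0, 5, 10)) {
  cat("alpha =", alpha, "-> preferred number of terminal nodes:",
      preferred_size(alpha), "\n")
}
```

With no penalty the largest tree is preferred; as α grows, the smaller nested trees win in turn, which is the nested elimination of terminal nodes described above.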
The rpart() function
Let's first replicate our manual search from the previous slides by building a tree model that predicts the readability scores from Feature 220 and Feature 166 in the toy dataset.
    require(rpart)
    require(rattle)
    dt <- rpart(formula = target ~ V220 + V166,
                data    = readability_sub,
                method  = "anova",
                control = list(minsplit  = 1,
                               cp        = 0,
                               minbucket = 1,
                               maxdepth  = 2))
The formula argument works similarly as in the regression models.
The data argument provides the name of the dataset.
The method = "anova" argument indicates that the outcome is a continuous variable and this is a regression problem.
The control argument is a list object with several settings to be used during tree-building.
minsplit = 1 tells the algorithm not to attempt to split a node unless it contains at least one observation.
minbucket = 1 requires every terminal node to contain at least one observation.
maxdepth = 2 stops the algorithm when the depth of the tree reaches 2.
cp = 0 indicates that we do not want to apply any penalty term (α = 0) during the model-building process.
    # ?prp # check for a lot of settings for modifying this plot
    fancyRpartPlot(dt, type = 2, sub = '')
You can also ask for more specific information about the model-building process by running the summary() function.
    summary(dt)
    Call:
    rpart(formula = target ~ V220 + V166, data = readability_sub,
        method = "anova", control = list(minsplit = 1, cp = 0,
            minbucket = 1, maxdepth = 2))
      n= 20

          CP nsplit rel error xerror   xstd
    1 0.3122      0    1.0000  1.079 0.4959
    2 0.2363      1    0.6878  1.518 0.5707
    3 0.1259      2    0.4515  1.518 0.5707
    4 0.0000      3    0.3256  1.402 0.4620

    Variable importance
    V220 V166
      81   19

    Node number 1: 20 observations,    complexity param=0.3122
      mean=-0.7633, MSE=0.8867
      left son=2 (3 obs) right son=3 (17 obs)
      Primary splits:
          V220 < -0.02634 to the left,  improve=0.31220, (0 missing)
          V166 < 0.1894   to the right, improve=0.06253, (0 missing)

    Node number 2: 3 observations,    complexity param=0.2363
      mean=-2.016, MSE=1.811
      left son=4 (2 obs) right son=5 (1 obs)
      Primary splits:
          V220 < -0.1914  to the right, improve=0.7711, (0 missing)
          V166 < 0.06604  to the right, improve=0.7711, (0 missing)

    Node number 3: 17 observations,    complexity param=0.1259
      mean=-0.5423, MSE=0.3979
      left son=6 (4 obs) right son=7 (13 obs)
      Primary splits:
          V166 < 0.06651  to the left,  improve=0.3301, (0 missing)
          V220 < 0.1787   to the left,  improve=0.1854, (0 missing)

    Node number 4: 2 observations
      mean=-2.851, MSE=0.6218

    Node number 5: 1 observations
      mean=-0.3443, MSE=0

    Node number 6: 4 observations
      mean=-1.196, MSE=0.2983

    Node number 7: 13 observations
      mean=-0.3413, MSE=0.2567
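To connect this output back to making predictions, the fitted tree can be traced by hand. The sketch below hard-codes the split points and node means reported by summary(dt) and applies them to the two new reading passages introduced earlier. The function name predict_readability is hypothetical, and this tree is the one fitted by rpart(), not necessarily the one drawn on the prediction slides.

```r
# Trace a path from the root node to a terminal node, using the splits
# and node means reported in the summary() output
predict_readability <- function(V220, V166) {
  if (V220 < -0.02634) {
    # Node 2: "V220 < -0.1914 to the right" sends those observations to node 5
    if (V220 < -0.1914) -0.3443 else -2.851
  } else {
    # Node 3: "V166 < 0.06651 to the left" sends those observations to node 6
    if (V166 < 0.06651) -1.196 else -0.3413
  }
}

# New passage 1: Feature 220 = -0.5, Feature 166 = 0 (ends in node 5)
predict_readability(V220 = -0.5, V166 = 0)   # -0.3443

# New passage 2: Feature 220 = 0.1, Feature 166 = 0 (ends in node 6)
predict_readability(V220 = 0.1, V166 = 0)    # -1.196
```

With the fitted object available, predict(dt, newdata = data.frame(V220 = c(-0.5, 0.1), V166 = c(0, 0))) should return the same values up to rounding.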
We will now expand the model by adding more predictors to be considered for the tree model.
Suppose that we now have six different predictors to be considered for the tree model.
We will keep the same settings with six predictors, except that the maximum depth is raised to 5 and the complexity parameter varies.
We will fit three models by setting the complexity parameter at 0, 0.05, and 0.1 to see what happens to the tree model as we increase the complexity parameter.
    # Model 1: no complexity penalty (cp = 0)
    dt <- rpart(formula = target ~ V78 + V166 + V220 + V375 + V562 + V568,
                data    = readability_sub,
                method  = "anova",
                control = list(minsplit = 1, cp = 0, minbucket = 1, maxdepth = 5))
    fancyRpartPlot(dt, type = 2, sub = '')

    # Model 2: cp = 0.05
    dt <- rpart(formula = target ~ V78 + V166 + V220 + V375 + V562 + V568,
                data    = readability_sub,
                method  = "anova",
                control = list(minsplit = 1, cp = 0.05, minbucket = 1, maxdepth = 5))
    fancyRpartPlot(dt, type = 2, sub = '')

    # Model 3: cp = 0.1
    dt <- rpart(formula = target ~ V78 + V166 + V220 + V375 + V562 + V568,
                data    = readability_sub,
                method  = "anova",
                control = list(minsplit = 1, cp = 0.1, minbucket = 1, maxdepth = 5))
    fancyRpartPlot(dt, type = 2, sub = '')
Performance Comparison of Different Algorithms
Model | R-square | MAE | RMSE |
---|---|---|---|
Linear Regression | 0.658 | 0.499 | 0.620 |
Ridge Regression | 0.727 | 0.432 | 0.536 |
Lasso Regression | 0.721 | 0.433 | 0.542 |
Elastic Net | 0.726 | 0.433 | 0.539 |
KNN | 0.611 | 0.519 | 0.648 |
Decision Tree | 0.499 | 0.574 | 0.724 |
Performance Comparison of Different Algorithms
Model | -LL | AUC | ACC | TPR | TNR | FPR | PRE |
---|---|---|---|---|---|---|---|
Logistic Regression | 0.510 | 0.719 | 0.755 | 0.142 | 0.949 | 0.051 | 0.471 |
Logistic Regression with Ridge Penalty | 0.511 | 0.718 | 0.754 | 0.123 | 0.954 | 0.046 | 0.461 |
Logistic Regression with Lasso Penalty | 0.509 | 0.720 | 0.754 | 0.127 | 0.952 | 0.048 | 0.458 |
Logistic Regression with Elastic Net | 0.509 | 0.720 | 0.753 | 0.127 | 0.952 | 0.048 | 0.456 |
KNN | ? | ? | ? | ? | ? | ? | ? |
Decision Tree | 0.558 | 0.603 | 0.757 | 0.031 | 0.986 | 0.014 | 0.423 |