How a Decision Tree Works
Decision trees are widely used and the idea behind them is simple. However, I always tend to forget how they work, precisely because they are so simple! A decision tree uses a quantity called information gain as its objective: at each node, it picks the split that maximizes the information gain.
Entropy
Before talking about information gain, we need to know entropy. For a node whose samples belong to classes with proportions $p_i$, the entropy is

$$H = -\sum_i p_i \log_2 p_i$$

For a binary label with class proportion $p$, this becomes $H(p) = -p \log_2 p - (1-p) \log_2 (1-p)$, whose graph is an inverted U over $p \in [0, 1]$. Notice that the entropy has its maximum $H = 1$ at $p = 0.5$, where the two classes are perfectly mixed, and drops to $0$ when the node is pure.
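As a quick sanity check, here is a minimal Python sketch of this formula (the `entropy` helper is my own name, not a library function):

```python
import math

def entropy(proportions):
    """H = -sum(p * log2(p)) over the nonzero class proportions."""
    return -sum(p * math.log2(p) for p in proportions if p > 0)

print(entropy([0.5, 0.5]))  # 1.0 -- the maximum for two classes
print(entropy([1.0]))       # 0.0 -- a pure node
```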
Example
Grade | Bumpiness | Speed limit | Speed |
---|---|---|---|
Steep | Bumpy | Yes | Slow |
Steep | Smooth | Yes | Slow |
Flat | Bumpy | No | Fast |
Steep | Smooth | No | Fast |
Suppose we want to predict the speed, which is either slow or fast.

The base entropy is 1 because the four examples split evenly into 2 slow and 2 fast:

$$H_{\text{parent}} = -\frac{2}{4}\log_2\frac{2}{4} - \frac{2}{4}\log_2\frac{2}{4} = 1$$
Information Gain
Finally, we can talk about information gain. Information gain is the difference between the parent's entropy and the weighted average of the children's entropies:

$$IG = H_{\text{parent}} - \sum_{\text{child}} \frac{n_{\text{child}}}{n_{\text{parent}}} H_{\text{child}}$$
Example
We have 3 features (Grade, Bumpiness, Speed Limit). For each feature, we compute the weighted entropy of the children it produces and take the difference from the parent entropy.
In the case of Grade, looking at the table again,
Grade | Bumpiness | Speed limit | Speed |
---|---|---|---|
Steep | Bumpy | Yes | Slow |
Steep | Smooth | Yes | Slow |
Flat | Bumpy | No | Fast |
Steep | Smooth | No | Fast |
we have two children: the examples where the grade is steep and the one where it is flat.
When the grade is steep

There are 2 slow examples and 1 fast example. Therefore,

$$H_{\text{steep}} = -\frac{2}{3}\log_2\frac{2}{3} - \frac{1}{3}\log_2\frac{1}{3} \approx 0.9183$$
When the grade is flat

This one is simple because the node is "pure" (only one class, fast, is left), so its entropy is 0.
Information Gain
Therefore, the weighted average of the children's entropy is

$$\frac{3}{4} \times 0.9183 + \frac{1}{4} \times 0 \approx 0.6887$$

and the information gain from splitting on Grade is

$$IG_{\text{Grade}} = 1 - 0.6887 \approx 0.3113$$
If we repeat this for the other features, Bumpiness gives $IG_{\text{Bumpiness}} = 1 - 1 = 0$ (both children stay perfectly mixed) and Speed Limit gives $IG_{\text{Speed Limit}} = 1 - 0 = 1$ (both children are pure).

Therefore, the next split will occur on Speed Limit because it has the greatest information gain.
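To double-check the arithmetic, here is a short Python sketch that computes the information gain of all three features from the table above; the helper names and the list encoding of the table are my own:

```python
import math
from collections import Counter, defaultdict

def entropy(labels):
    """H = -sum(p * log2(p)) over the class proportions in `labels`."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())

def information_gain(feature, labels):
    """Parent entropy minus the weighted average of the children's entropies."""
    children = defaultdict(list)
    for value, label in zip(feature, labels):
        children[value].append(label)
    weighted = sum(len(c) / len(labels) * entropy(c) for c in children.values())
    return entropy(labels) - weighted

# The four examples from the table above.
grade     = ["Steep", "Steep", "Flat", "Steep"]
bumpiness = ["Bumpy", "Smooth", "Bumpy", "Smooth"]
limit     = ["Yes", "Yes", "No", "No"]
speed     = ["Slow", "Slow", "Fast", "Fast"]

for name, feature in [("Grade", grade), ("Bumpiness", bumpiness), ("Speed Limit", limit)]:
    print(name, round(information_gain(feature, speed), 4))
# Grade 0.3113, Bumpiness 0.0, Speed Limit 1.0
```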
Min Split
The decision tree repeats the above process recursively, but the question is when it will stop splitting. There are two popular hyperparameters that control this. One is the max depth and the other is the min split.

Max depth refers to the maximum depth of the tree: splitting stops once the current depth reaches it.

Min split refers to the minimum number of samples required to split a node. For example, if the min split is 3 and only 2 samples are left at a node, that node will not be split any further.
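In sklearn, these two knobs are the `max_depth` and `min_samples_split` parameters of `DecisionTreeClassifier`. A minimal sketch (the specific values here are arbitrary):

```python
from sklearn.tree import DecisionTreeClassifier

# Never grow deeper than 3 levels, and never split a node with fewer than 3 samples.
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
```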
Code Samples
In sklearn:
```python
from sklearn import tree

# X is the feature matrix and Y is the vector of class labels.
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
```
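Putting it together with the table above, here is a self-contained sketch; the 0/1 encoding of the categorical values is my own choice, and `criterion="entropy"` makes sklearn use the same information-gain criterion we computed by hand:

```python
from sklearn import tree

# Encode the table: Grade (Steep=1, Flat=0), Bumpiness (Bumpy=1, Smooth=0),
# Speed limit (Yes=1, No=0); labels: Slow=0, Fast=1.
X = [[1, 1, 1],
     [1, 0, 1],
     [0, 1, 0],
     [1, 0, 0]]
Y = [0, 0, 1, 1]

clf = tree.DecisionTreeClassifier(criterion="entropy")
clf = clf.fit(X, Y)

# The root split should land on Speed limit (feature index 2), matching the hand computation.
print(clf.predict([[1, 1, 0]]))  # steep, bumpy, no speed limit -> [1] (fast)
```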