May 15
13 min read

What Do Trees Think About?

In this article, we will start exploring the classification problem. And the first algorithm will be the most intuitively understandable one — decision trees.


Vyacheslav Gorash

3D Graphics and Machine Learning Developer with 6 years of experience

This article is part of a series on the fundamentals of machine learning.

In the previous article, we started exploring the simplest machine learning methods using linear regression as an example. Today we will look at another model, but this time for a classification task.

Classification vs Regression

Let's start by recalling the first article and formulating the problem. The input data is the same as before: a number or a vector of numbers, which we will denote as $X$. But the output, $\widehat{y}$, will now change. Previously, it was simply a number, and for linear regression there were no restrictions on its value or range. Now, $\widehat{y}$ can take only one of several predefined values, which we will call classes. As before, we also know the correct answers for our data, which we will denote as $y$ (without the hat).

Let's return to the example from the previous article about predicting weight from height. Before, we predicted the exact weight in kilograms; now we are more like officials at a boxing weigh-in. We don't care about the exact weight, only about the weight class (light, middle, heavy, and so on).

Another option is binary classification, where there are only two classes. This usually means we are answering some question with "yes" or "no". For example, using the same height and weight data, we could try to answer the question "Is this person taller/heavier than the city average?". This is exactly the binary classification we will focus on.

How to Classify?

So, we have an input vector $X$ and only two possible answers. Let's denote the answer "yes" as 1 and the answer "no" as 0.

Let's take the simplest case, where $X$ is a single number. Linear regression won't work here because it produces answers from a continuous range (spoiler: there is a way to make regression solve our current task, but that's a topic for the next article). So let's come up with something else for now.

Let's visualize our data. It's one-dimensional, so it's easy to display on a number line. Let's mark with blue dots the data for which the correct answer is 0, and with red dots those for which the answer is 1.

Classification data visualization 1

In this case, everything is simple. We can choose some value on the line, and if our input number is less than this value, the answer is 0, and if it's greater — 1. How do we choose such a number? Let's try all possible ways to split the data. Usually, the split is made at the midpoint between two points. This gives us 8 options:

Classification data visualization 2

Now let's check how many correct answers each option yields. Obviously, in this particular case, we have an option that gives all correct answers:

Classification data visualization 3

But this doesn't always happen. Let's change the initial conditions:

Classification data with different initial conditions

In this case, we cannot perfectly separate our data. But we can minimize the error. If we split the data at the same point as before, we get 7 correct answers out of 8. And this is the best we can achieve with a single split.
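The single-split search described above can be sketched in a few lines of Python. This sketch assumes the rule from the text (values below the threshold get class 0, the rest get class 1). The function name `best_threshold` and the sample data are made up for illustration, and on ties the function keeps the first candidate it finds.

```python
def best_threshold(xs, ys):
    """Try each midpoint between neighboring values as a threshold.

    Rule (as in the text): x < threshold -> class 0, otherwise class 1.
    Returns (threshold, number_of_correct_answers); ties keep the first one.
    """
    values = sorted(set(xs))
    candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]
    best = None
    for t in candidates:
        # Count how many objects the rule classifies correctly.
        correct = sum((0 if x < t else 1) == y for x, y in zip(xs, ys))
        if best is None or correct > best[1]:
            best = (t, correct)
    return best


# Made-up data: first perfectly separable, then with one class-0 point among the class-1 points.
xs = [1, 2, 3, 4, 5, 6, 7, 8]
print(best_threshold(xs, [0, 0, 0, 0, 1, 1, 1, 1]))  # (4.5, 8)
print(best_threshold(xs, [0, 0, 0, 0, 1, 0, 1, 1]))  # (4.5, 7)
```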

Building Trees

But there are cases where it's impossible to achieve reasonable quality with just one split. For example:

Data where a single split is ineffective

Here, with one split, we get at most 5 correct answers out of 8, which is not very good. So we need to come up with something else.

The solution seems intuitive: split not once, but multiple times. For example, we can first split like this:

Data with multiple splits

After this split, we consider that all objects below this value have class 1. The remaining part is split once more (the points filled in white were already classified in the previous step, so they can be ignored):

Data with multiple splits

As a result, we get an algorithm that we need to execute to classify any element:

Data splitting algorithm visualization

This algorithm is called a decision tree. A tree has a root, where the algorithm starts. The points where it ends, that is, the places where we get a class (in the figure above, the colored numbers, which are class labels), are called the leaves of the tree. All intermediate points (rectangles in the figure) are called nodes. Each node describes how the values are split into two parts.
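To make the terminology concrete, here is one possible way to represent such a tree in Python. The classes `Leaf` and `Node` and the function `classify` are our own illustrative names, and the thresholds 3.5 and 5.5 are made-up values, since the figure gives no coordinates.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    label: int  # class stored in the leaf (0 or 1)


@dataclass
class Node:
    threshold: float  # split point of this node
    left: "Tree"      # branch for x < threshold
    right: "Tree"     # branch for x >= threshold


Tree = Union[Leaf, Node]


def classify(tree: Tree, x: float) -> int:
    # Start at the root and follow the branches until we reach a leaf.
    while isinstance(tree, Node):
        tree = tree.left if x < tree.threshold else tree.right
    return tree.label


# A tree shaped like the one in the figure: small values -> 1, a middle band -> 0, large values -> 1.
tree = Node(3.5, Leaf(1), Node(5.5, Leaf(0), Leaf(1)))
print(classify(tree, 2.0), classify(tree, 4.2), classify(tree, 7.0))  # 1 0 1
```

Keeping leaves and nodes as separate types makes the traversal loop trivial: it stops exactly when it reaches a leaf.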

Now we need to understand how to build such a tree. So, at the very beginning, there are no nodes or leaves. We need to make the first split of the data, that is, create a node. Again, we select points — candidates for splitting. These will be the midpoints of the segments between neighboring data points (just like in the simplest case example):

Splitting candidates visualization
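A sketch of this candidate selection, assuming numeric input values (the name `split_candidates` is ours). Duplicate values are merged first, because there is no point in splitting between two identical numbers.

```python
def split_candidates(xs):
    """Midpoints between neighboring distinct values, in increasing order."""
    values = sorted(set(xs))  # identical values cannot be separated, so merge them
    return [(a + b) / 2 for a, b in zip(values, values[1:])]


print(split_candidates([3, 1, 8, 2, 5, 4, 7, 6]))  # [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]
```

Eight distinct points give seven candidates, which matches the seven candidates A to G used below.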

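Now we need to determine which candidate is the most suitable, and simply counting the number of correct answers won't work here. If we choose the split that maximizes the number of correct answers at the current step (always taking the locally best option is called a greedy algorithm in programming), we might make the task harder at subsequent steps. Several candidates can give the same number of correct answers while leaving groups that are much easier or much harder to split further. We will still build the tree one split at a time, but we need a measure of split quality that rewards cleaner groups, not just correct answers.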

So we'll have to take a slightly more complex approach. We will still iterate through all candidates one by one, but the metric will be different. First, we count how many elements end up to the right and to the left of our split:

Point   $|S_{left}|$   $|S_{right}|$
A       1              7
B       2              6
C       3              5
D       4              4
E       5              3
F       6              2
G       7              1
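A sketch that reproduces these counts. Since the article gives no coordinates, we assume the eight points sit at positions 1 to 8, so candidates A to G are the midpoints 1.5 to 7.5. The helper name `side_sizes` is ours.

```python
xs = [1, 2, 3, 4, 5, 6, 7, 8]  # assumed positions of the 8 points (not given in the text)
values = sorted(set(xs))
candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]  # candidates A..G


def side_sizes(xs, threshold):
    """Return (|S_left|, |S_right|) for a split at `threshold`."""
    n_left = sum(1 for x in xs if x < threshold)
    return n_left, len(xs) - n_left


for name, t in zip("ABCDEFG", candidates):
    print(name, *side_sizes(xs, t))
```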

Let's set these values aside for now; they will be useful a bit later. Now we need to measure how well the data is separated on each side of the split, in other words, how strongly one class dominates on the left and on the right. For this, we'll use the Gini criterion (also known as Gini impurity). It's calculated using the formula

$$G = 1 - \left( p_{0}^{2} + p_{1}^{2} \right)$$

where $p_{0}$ and $p_{1}$ are the proportions of class 0 and class 1 objects on one side of the split.

Suppose the left side contains three objects of class 0 and one object of class 1. Then

$$G_{left} = 1 - \left( \left( \frac{3}{4} \right)^{2} + \left( \frac{1}{4} \right)^{2} \right) = 1 - 0.625 = 0.375$$

The lower the criterion, the better. For example, if there are only ones or only zeros on the left, the criterion value will be zero. This means that on this side, we have perfectly separated one of the classes.
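The formula translates directly into a small function. This sketch assumes the labels are given as a list of 0s and 1s; treating an empty side as pure is our own convention, since that case does not occur in the text. For two classes, the value ranges from 0 (only one class) to 0.5 (an even mix).

```python
def gini(labels):
    """Gini criterion G = 1 - (p0^2 + p1^2) for a list of 0/1 class labels."""
    if not labels:
        return 0.0  # assumption: an empty side is treated as perfectly pure
    p1 = sum(labels) / len(labels)  # proportion of class 1
    p0 = 1 - p1                     # proportion of class 0
    return 1 - (p0 ** 2 + p1 ** 2)


print(gini([0, 0, 0, 1]))  # 0.375, the example above
print(gini([1, 1, 1]))     # 0.0, only one class
print(gini([0, 1]))        # 0.5, an even mix
```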

Let's calculate the criteria for all our candidates:

Point   $G_{left}$   $G_{right}$
A       0            0.41
B       0            0.44
C       0            0.48
D       0.375        0.375
E       0.48         0
F       0.44         0
G       0.41         0
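The Gini values in this table can be reproduced with a short sketch. We assume the points sit at positions 1 to 8 and carry the labels 1, 1, 1, 0, 0, 1, 1, 1 from left to right. This is our reconstruction, chosen because it is consistent with both tables and with the later splits at C and E; the article does not list the data.

```python
def gini(labels):
    p1 = sum(labels) / len(labels)  # proportion of class 1
    return 1 - ((1 - p1) ** 2 + p1 ** 2)


xs = [1, 2, 3, 4, 5, 6, 7, 8]  # assumed positions
ys = [1, 1, 1, 0, 0, 1, 1, 1]  # labels reconstructed from the tables
values = sorted(set(xs))
candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]

rows = []
for name, t in zip("ABCDEFG", candidates):
    left = [y for x, y in zip(xs, ys) if x < t]
    right = [y for x, y in zip(xs, ys) if x >= t]
    rows.append((name, gini(left), gini(right)))
    print(name, round(gini(left), 3), round(gini(right), 3))
```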

Now we just need to calculate the final metric for each candidate. It's a weighted sum of the two Gini values. Remember that we counted the number of elements on the left and right: the weights are exactly these counts divided by the total number of elements $|S|$, so together they add up to 1. We have 8 elements in total, so $|S| = 8$. The final formula:

$$M = \frac{|S_{left}|}{|S|} G_{left} + \frac{|S_{right}|}{|S|} G_{right}$$
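For example, using the counts and Gini values for candidates C and D from the tables above:

$$M_C = \frac{3}{8} \cdot 0 + \frac{5}{8} \cdot 0.48 = 0.3$$

$$M_D = \frac{4}{8} \cdot 0.375 + \frac{4}{8} \cdot 0.375 = 0.375$$

Even though the right side of C is more mixed than either side of D, its perfectly pure left side brings the weighted value down.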

Calculating for all our candidates:

Point   $M$
A       0.36
B       0.33
C       0.3
D       0.375
E       0.3
F       0.33
G       0.36
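Putting the pieces together, here is a sketch that computes $M$ for every candidate and picks the smallest. As before, the positions 1 to 8 are assumed and the labels (1, 1, 1, 0, 0, 1, 1, 1 from left to right) are our reconstruction from the tables, not data given in the article. The name `weighted_gini` is ours.

```python
def gini(labels):
    p1 = sum(labels) / len(labels)
    return 1 - ((1 - p1) ** 2 + p1 ** 2)


def weighted_gini(xs, ys, t):
    """The final metric M for a split at threshold t."""
    left = [y for x, y in zip(xs, ys) if x < t]
    right = [y for x, y in zip(xs, ys) if x >= t]
    n = len(ys)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


xs = [1, 2, 3, 4, 5, 6, 7, 8]  # assumed positions
ys = [1, 1, 1, 0, 0, 1, 1, 1]  # labels reconstructed from the tables
values = sorted(set(xs))
candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]
scores = {name: weighted_gini(xs, ys, t) for name, t in zip("ABCDEFG", candidates)}
best = min(scores, key=scores.get)  # on a tie, min keeps the first candidate
print({name: round(m, 2) for name, m in scores.items()})
print(best)  # C (tied with E)
```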

Points C and E have the lowest (best) value of $M$. We can split at either one; let's split at C.

Best candidate for the first split

Now let's see what we got. On the left, there's only class 1. This means we have a leaf here and we don't split further. On the right, there are different classes. This means it's not a leaf, but another node. We repeat the splitting for this node:

Best candidate for the second split

After calculations, we'll see that we need to split at E. On the left, we'll have class 0. On the right, class 1. That is, all objects are assigned to classes. We add two leaves. The tree is built.

The final algorithm (initially the tree is empty):

  1. Add a node at the root
  2. Take a node that has not been split yet and split it at the best candidate (the one with the lowest $M$)
  3. If all objects on the left are of the same class, create a leaf, otherwise create a node
  4. If all objects on the right are of the same class, create a leaf, otherwise create a node
  5. Repeat steps 2-4 as long as the tree contains at least one node that has not been split yet
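The procedure above can be sketched recursively for one-dimensional data, reusing the weighted Gini metric $M$. Recursion replaces the explicit "repeat while unsplit nodes remain" loop: each call either returns a leaf or splits and recurses into both halves. The names are our own, the data is the reconstruction used in the earlier sketches, and the majority-class fallback for identical values with different classes is our addition (the text does not discuss that case).

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Leaf:
    label: int


@dataclass
class Node:
    threshold: float
    left: "Tree"   # objects with x < threshold
    right: "Tree"  # objects with x >= threshold


Tree = Union[Leaf, Node]


def gini(labels: List[int]) -> float:
    p1 = sum(labels) / len(labels)
    return 1 - ((1 - p1) ** 2 + p1 ** 2)


def build(xs: List[float], ys: List[int]) -> Tree:
    if len(set(ys)) == 1:  # all objects of one class: create a leaf
        return Leaf(ys[0])
    values = sorted(set(xs))
    if len(values) == 1:   # identical inputs with different classes: majority leaf
        return Leaf(1 if 2 * sum(ys) > len(ys) else 0)
    best_t, best_m = None, float("inf")
    for a, b in zip(values, values[1:]):
        t = (a + b) / 2
        left = [y for x, y in zip(xs, ys) if x < t]
        right = [y for x, y in zip(xs, ys) if x >= t]
        m = (len(left) * gini(left) + len(right) * gini(right)) / len(ys)
        if m < best_m:  # ties keep the first (leftmost) candidate
            best_t, best_m = t, m
    left_part = [(x, y) for x, y in zip(xs, ys) if x < best_t]
    right_part = [(x, y) for x, y in zip(xs, ys) if x >= best_t]
    return Node(
        best_t,
        build([x for x, _ in left_part], [y for _, y in left_part]),
        build([x for x, _ in right_part], [y for _, y in right_part]),
    )


def classify(tree: Tree, x: float) -> int:
    while isinstance(tree, Node):
        tree = tree.left if x < tree.threshold else tree.right
    return tree.label


tree = build([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 1, 0, 0, 1, 1, 1])
print(tree)
# Node(threshold=3.5, left=Leaf(label=1), right=Node(threshold=5.5, left=Leaf(label=0), right=Leaf(label=1)))
```

With the reconstructed data, the sketch picks C (3.5) first and then E (5.5), just as in the walkthrough.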

Once the tree is built, for each new object we simply traverse it from the root until we reach a leaf. The class in the leaf will be the class for our object.

More Dimensions

We've learned how to build trees for one-dimensional input data, that is, for single numbers. For example, with the tree described in the previous section, we could try to tell whether we're looking at an apple or a pear using a single parameter, such as its diameter. However, such a classification would be very inaccurate: diameter alone says almost nothing about whether a fruit is a pear or an apple.

Data that cannot be effectively separated in one dimension

In all the figures below, apples are marked in red and pears in blue. As the figure shows, the data cannot be split into pure groups (containing only one kind of fruit) of more than one element.

The situation changes dramatically if we add a second dimension. Now each fruit will be described not by one parameter, but by two: width (or diameter) and height. Adding a second parameter will allow us to distinguish apples from pears much more effectively.

Moving data to 2 dimensions

Notice that nothing has changed along the X axis. But now the data is very well separated along the Y axis.

If a human were doing the classification, they would look at the elongation (the difference between height and width). To do something similar with machine learning, we can again use decision trees. We just need to figure out how to build a tree for our case.

In fact, the algorithm will be almost the same. At each step, we split the data into two parts. The main difference is that before, we didn't need to choose which parameter to split by, since there was only one. Now there are two parameters, and first we need to decide which one to use for splitting. After that, everything reduces to choosing the split point, which we already learned to do in the previous chapter.

What will such splits look like on a graph? Before, we had only one axis, and we simply placed a point, to the left of which all objects go to one branch of the tree, and to the right — to another. Now we have two axes. And the first step is choosing which axis to split by. If we choose the X axis, it will be a vertical line; if Y — a horizontal line.

Splitting candidate in 2 dimensions

Orange and green colors show splits along different features (axes).

At the second stage, we choose exactly where such a line will pass. After that, objects on different sides of the line (above/below or left/right) go to different branches of the tree.

It remains to figure out how exactly to choose the axis. In fact, it's very simple. We now iterate through the splitting candidates of every feature, not just one, and choose the best option among all of them. When we then split along the chosen axis, we already know where to do it, since we found the best split point while comparing the candidates. If we visualize the tree's operation on the graph, we get this result:

Splitting result in 2 dimensions

The final algorithm will look like this:

  1. Add a node at the root
  2. Take a node that has not been split yet and find the best split for each feature
  3. Perform the split at the node using the best feature
  4. If all objects on the left are of the same class, create a leaf, otherwise create a node
  5. If all objects on the right are of the same class, create a leaf, otherwise create a node
  6. Repeat steps 2-5 as long as the tree contains at least one node that has not been split yet
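The only new step compared with the one-dimensional case is choosing both the feature and the threshold. Here is a sketch of that step for any number of features. The function name `best_split` and the fruit measurements are hypothetical; each object is a tuple of feature values, and the metric is the same weighted Gini $M$ as before.

```python
def gini(labels):
    p1 = sum(labels) / len(labels)
    return 1 - ((1 - p1) ** 2 + p1 ** 2)


def best_split(X, y):
    """Return (feature_index, threshold, M) with the lowest weighted Gini, or None."""
    best = None
    for j in range(len(X[0])):                # every feature (axis)
        values = sorted({row[j] for row in X})
        for a, b in zip(values, values[1:]):  # every candidate on that axis
            t = (a + b) / 2
            left = [lab for row, lab in zip(X, y) if row[j] < t]
            right = [lab for row, lab in zip(X, y) if row[j] >= t]
            m = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or m < best[2]:
                best = (j, t, m)
    return best


# Hypothetical fruits as (width, height); 1 = apple, 0 = pear.
X = [(7, 7), (8, 7.5), (6, 6.5), (7.5, 7), (6, 9), (7, 10), (5.5, 8.5), (6.5, 9.5)]
y = [1, 1, 1, 1, 0, 0, 0, 0]
print(best_split(X, y))  # (1, 8.0, 0.0): split on height at 8.0
```

Widths of apples and pears overlap in this made-up data, so no width split is perfect, while a horizontal line at height 8.0 separates the two classes completely.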

If you look at this algorithm, it becomes clear that nothing prevents us from working with higher-dimensional data. For example, for three-dimensional data, we would separate it not with a line, but with a plane. This can still be visualized. But for four or more dimensions, there is no clear visualization, since the data would be separated by a hyperplane.

Overfitting and What to Do About It

So far, we've been building the tree until all objects are correctly classified. At first glance, this seems like the only right approach — after all, we don't want to make errors. But this is only at first glance. Perfectionism in machine learning does more harm than good.

Let's return to the apple and pear example and imagine that a defective elongated apple made it into our training set. The decision tree, if left as is, will build branches to correctly classify this object as an apple, even though it's surrounded by pears. As a result, when we use this tree for classification, there's a risk that a normal pear will be classified as an apple.

Overfitting in two dimensions

What can we do about this? We need to make the tree ignore single unusual examples. There are several ways to do this:

  1. Limit the tree depth. For example, we can allow no more than three consecutive splits on any path from the root. Nodes that reach this depth become leaves.
  2. Set a minimum number of objects a node must contain to be split. For example, if a node has four or fewer objects, it automatically becomes a leaf and is not split further.
  3. Limit the total number of leaves. Once the limit is reached, no new splits occur, and all remaining unsplit nodes become leaves.
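Here is a sketch of the first two limits, under the same assumptions as the earlier examples (0/1 labels, weighted Gini, hypothetical names and data). Leaves that are not pure predict the majority class; ties go to class 0, which is our arbitrary choice. The leaf-count limit is omitted because it also requires deciding which node to split first. For reference, scikit-learn's `DecisionTreeClassifier` exposes all three limits as the `max_depth`, `min_samples_split` and `max_leaf_nodes` parameters.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    label: int


@dataclass
class Node:
    feature: int      # index of the feature (axis) used for the split
    threshold: float  # objects with row[feature] < threshold go left
    left: "Tree"
    right: "Tree"


Tree = Union[Leaf, Node]


def gini(labels):
    p1 = sum(labels) / len(labels)
    return 1 - ((1 - p1) ** 2 + p1 ** 2)


def majority(labels):
    # Class predicted by a leaf; a tie goes to class 0 (an arbitrary choice).
    return 1 if 2 * sum(labels) > len(labels) else 0


def best_split(X, y):
    """(feature, threshold, M) with the lowest weighted Gini, or None."""
    best = None
    for j in range(len(X[0])):
        values = sorted({row[j] for row in X})
        for a, b in zip(values, values[1:]):
            t = (a + b) / 2
            left = [lab for row, lab in zip(X, y) if row[j] < t]
            right = [lab for row, lab in zip(X, y) if row[j] >= t]
            m = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or m < best[2]:
                best = (j, t, m)
    return best


def build(X, y, max_depth=None, min_samples_split=2, depth=0):
    stop = (
        len(set(y)) == 1                                   # pure node
        or (max_depth is not None and depth >= max_depth)  # depth limit
        or len(y) < min_samples_split                      # too few objects
    )
    split = None if stop else best_split(X, y)
    if split is None:  # also reached when no candidate split exists
        return Leaf(majority(y))
    j, t, _ = split
    left_idx = [i for i, row in enumerate(X) if row[j] < t]
    right_idx = [i for i, row in enumerate(X) if row[j] >= t]

    def grow(idx):
        return build([X[i] for i in idx], [y[i] for i in idx],
                     max_depth, min_samples_split, depth + 1)

    return Node(j, t, grow(left_idx), grow(right_idx))


def classify(tree, row):
    while isinstance(tree, Node):
        tree = tree.left if row[tree.feature] < tree.threshold else tree.right
    return tree.label


X = [(7, 7), (8, 7.5), (6, 6.5), (7.5, 7),     # apples
     (6.2, 9.2),                               # a defective, elongated apple
     (6, 9), (7, 10), (5.5, 8.5), (6.5, 9.5)]  # pears
y = [1, 1, 1, 1, 1, 0, 0, 0, 0]               # 1 = apple, 0 = pear

full = build(X, y)                         # grown until every leaf is pure
shallow = build(X, y, max_depth=1)         # a single split
coarse = build(X, y, min_samples_split=5)  # nodes with 4 or fewer objects stay leaves
pear_like = (6.25, 9.25)                   # a new pear close to the defective apple
print(classify(full, pear_like), classify(shallow, pear_like), classify(coarse, pear_like))  # 1 0 0
```

The fully grown tree carves out a small region around the defective apple and labels the new pear as an apple, while both limited trees ignore that single unusual example.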

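With any of these limits, a leaf may still contain objects of both classes; such a leaf predicts the class that is in the majority. Each of these methods can help combat overfitting, but there is no universal recipe. Usually, a combination of methods is applied, and the parameters are tuned experimentally.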

For such experiments (and not only for decision trees, but for many other machine learning models as well), having a test set is very useful. What is it? Before training begins, we set aside 10-20 percent of the data as the test set. And we don't use it during training. This is a very important point — the test data must never be seen by the model during training.

When training is complete, we run the test data through the model and see how well it performs. If the results on the training data are very good but poor on the test data, this is most likely overfitting, and measures need to be taken. Usually, the goal is for the percentage of correct answers on the training and test sets to be roughly the same.
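A minimal sketch of this procedure in plain Python. The helpers `holdout_split` and `accuracy` are hypothetical names, and the 20 percent default follows the range mentioned above. Libraries provide ready-made equivalents, for example `train_test_split` in scikit-learn's `sklearn.model_selection` module.

```python
import random


def holdout_split(X, y, test_fraction=0.2, seed=0):
    """Shuffle the data once and set aside `test_fraction` of it as a test set."""
    indices = list(range(len(y)))
    random.Random(seed).shuffle(indices)  # fixed seed makes the split reproducible
    n_test = max(1, round(len(y) * test_fraction))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    X_train = [X[i] for i in train_idx]
    y_train = [y[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, y_train, X_test, y_test


def accuracy(predict, X, y):
    """Share of correct answers of the classifier `predict` on (X, y)."""
    return sum(predict(row) == label for row, label in zip(X, y)) / len(y)


X = [[i] for i in range(10)]
y = [0] * 5 + [1] * 5
X_train, y_train, X_test, y_test = holdout_split(X, y)


def predict(row):
    # Stand-in for a tree trained on X_train and y_train only.
    return 0 if row[0] < 4.5 else 1


print(accuracy(predict, X_train, y_train), accuracy(predict, X_test, y_test))  # 1.0 1.0
```

A large gap between the two numbers (for example, 1.0 on the training set and much lower on the test set) is the overfitting signal described above.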
