Techno Blender
Digitally Yours.

Neural Networks as Decision Trees | by Nakul Upadhya | Apr, 2023



Photo by Jens Lelie on Unsplash

Get the power of a Neural Network with the interpretable structure of a Decision Tree

The recent boom in AI has clearly shown the power of deep neural networks across many tasks, especially classification problems where the data is high-dimensional and has complex, non-linear relationships with the target variable. However, explaining the decisions of a neural classifier is a hard problem. While post-hoc methods such as DeepLIFT [2] and Layer-Wise Relevance Propagation [3] can help explain individual decisions, explaining the global decision mechanism (what the model generally looks for) is much more difficult.

Because of this, many practitioners in high-stakes fields opt instead for more interpretable models like basic decision trees, since the decision hierarchy can be clearly visualized and understood by stakeholders. However, a single tree often does not provide enough accuracy for the task at hand, so ensemble methods like bagging or boosting are used to improve performance. This again sacrifices some interpretability: to understand a single decision, a practitioner may need to look through hundreds of trees. Even so, ensembles are still preferred over deep networks because feature importance (on both a local and a global scale) can at least be easily retrieved and displayed.

So the problem is this: we want the discriminative power of a neural network with the interpretability of a decision tree. Why not simply structure the network as a tree? That is the main approach taken by Frosst and Hinton (2017) in their paper “Distilling a Neural Network Into a Soft Decision Tree” [1]. In this article, I will break down the key mechanisms behind a neural decision tree, explain some of the benefits of the approach, and discuss factors to consider when implementing it in practice. While we will mainly discuss classification trees, the same ideas can be applied to regression trees with relatively few tweaks.

Soft vs. Hard Decision Trees

Before diving into how to structure a neural network as a soft decision tree, let’s first define what a soft decision tree is.

When people think of decision trees (such as the ones implemented in sklearn), they usually mean hard decision trees, where every decision is deterministic.

Example of a Hard Decision Tree (Image by Author)

If a condition is met, we go to the left branch; otherwise, we go right. Each leaf node has a class, and a prediction is made by simply following the tree down and picking the class of the leaf we end up in. The larger we allow the tree to grow, the more paths there are to a final decision.

Soft decision trees are similar in many ways but work slightly differently:

Example of a soft decision tree (Image by Author)

While each branching in a hard decision tree is deterministic, a soft decision tree assigns a probability to going down each branch. So while a hard decision tree outputs a single class, a soft decision tree outputs a probability distribution over all possible classes: the probability of reaching a leaf is the product of the branch probabilities along the path to it, and the probability of a class is the sum of these path probabilities over the leaves that predict that class. For example, the probability of approval in the tree above is P(b1|X)(1-P(b2|X)) + (1-P(b1|X))(1-P(b3|X)). The classification decision is then simply the class with the highest probability.
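The path arithmetic can be sketched in a few lines of plain Python. This is a minimal illustration rather than the paper’s code: it assumes a perfect binary tree stored in breadth-first order, assumes each internal node’s number is the probability of taking the right branch (the convention Frosst & Hinton use), and uses a hypothetical leaf-to-class layout with made-up probabilities. The function names are illustrative only.

```python
from typing import Dict, List, Sequence


def leaf_path_probabilities(right_probs: Sequence[float]) -> List[float]:
    """Probability of reaching each leaf of a perfect binary soft tree.

    right_probs[i] is the probability that internal node i (breadth-first order)
    sends the sample to its right child. For a tree of depth d (2**d - 1 internal
    nodes) the 2**d leaves are returned from left to right.
    """
    n_internal = len(right_probs)
    reach = [1.0] + [0.0] * (2 * n_internal)  # the root is reached with certainty
    for i, p_right in enumerate(right_probs):
        reach[2 * i + 1] = reach[i] * (1.0 - p_right)  # left child
        reach[2 * i + 2] = reach[i] * p_right  # right child
    return reach[n_internal:]  # leaves occupy the last n_internal + 1 slots


def class_probabilities(
    right_probs: Sequence[float], leaf_classes: Sequence[str]
) -> Dict[str, float]:
    """Sum the path probabilities of all leaves that predict the same class."""
    totals: Dict[str, float] = {}
    for prob, label in zip(leaf_path_probabilities(right_probs), leaf_classes):
        totals[label] = totals.get(label, 0.0) + prob
    return totals


# Hypothetical depth-2 loan tree: internal nodes [b1, b2, b3], four leaves.
probs = class_probabilities([0.7, 0.4, 0.8], ["deny", "approve", "deny", "approve"])
print(round(probs["approve"], 2))  # 0.68, i.e. 0.7 * 0.8 + 0.3 * 0.4
```

Setting every node probability to exactly 0 or 1 puts all of the mass on a single leaf, which is the hard-tree special case.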

This structure has several benefits. For one, non-deterministic decisions give users an idea of the uncertainty present in a given classification. Additionally, a hard tree is technically just a special case of a soft tree in which every branching probability is either 0 or 1.

One downside of these trees is a slight reduction in interpretability. From a stakeholder’s standpoint, it’s easier to understand “we approved a loan because the individual made $100k a year and had less than $400k in debt” than:

If the income is over $100k, we have a 0.7 probability of going right, and if the debt is under $400k, we have a 0.8 probability of approving, which results in a 0.56 probability plus whatever happens in the left branch.

That doesn’t mean these trees are uninterpretable (one can still see exactly what the model is looking at); they just require a bit more help from the model developer.

Oblique Decision Trees

The second concept we need before getting into neural decision trees is the “oblique” decision tree.

Traditional decision trees are considered “orthogonal” trees because each split boundary is orthogonal to a feature axis. Simply put, only one variable is used in any given decision. Oblique trees, on the other hand, use multiple variables in each decision, usually through a linear combination.

Example of an oblique decision boundary (Figure from Zhang et al. 2017 [4])

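An example decision node could be “Income - Debt > 0”. Such splits can produce more expressive decision boundaries, since a single oblique split can capture a diagonal boundary that an orthogonal tree could only approximate with a staircase of many axis-aligned splits. One downside is that, without proper regularization, these boundaries can become increasingly complex.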
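A minimal sketch contrasting the two kinds of split, using the income/debt rule above with features in thousands of dollars. The function names and applicant values are illustrative assumptions, not part of any library.

```python
from typing import Sequence


def orthogonal_split(x: Sequence[float], feature: int, threshold: float) -> bool:
    """Axis-aligned test: looks at a single feature only."""
    return x[feature] > threshold


def oblique_split(x: Sequence[float], weights: Sequence[float], bias: float) -> bool:
    """Oblique test: thresholds a linear combination of several features."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias > 0


# Features are (income, debt) in thousands of dollars.
applicant = (120.0, 150.0)
print(orthogonal_split(applicant, feature=0, threshold=100.0))  # income > 100 -> True
print(oblique_split(applicant, weights=(1.0, -1.0), bias=0.0))  # income - debt > 0 -> False
```

The two rules disagree for this applicant: the orthogonal split only sees the high income, while the oblique split weighs income against debt in one decision.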

Putting it together

Now that we understand soft and oblique decision trees, we can combine them to arrive at the neural formulation.

The first component is the decision nodes. Each node needs to produce a branching probability based on the input. To achieve this, we can use the bread and butter of neural networks: weights and activations. In each decision node, we first take a linear combination of the input variables (plus a bias) and then apply a sigmoid function to the sum, which gives the branching probability.

To prevent overly soft decisions (and make the tree behave more like a hard decision tree), a tempered sigmoid can be used instead: the linear combination is multiplied by an inverse temperature β > 1 before the sigmoid is applied.
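A single decision node, including the optional inverse temperature, might look like the sketch below. The weights, bias, input and β values are made-up numbers, and `branch_probability` is a hypothetical helper rather than anything from the paper’s code. Following the paper’s convention, the output is read as the probability of taking the right branch.

```python
import math
from typing import Sequence


def stable_sigmoid(z: float) -> float:
    """Logistic function written to avoid overflow in math.exp for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def branch_probability(
    x: Sequence[float], weights: Sequence[float], bias: float, beta: float = 1.0
) -> float:
    """Probability that a decision node sends x down its right branch.

    beta is an inverse temperature: beta = 1 is the plain sigmoid, and beta > 1
    pushes the output toward 0 or 1, i.e. toward a harder decision.
    """
    logit = sum(w * xi for w, xi in zip(weights, x)) + bias  # oblique linear split
    return stable_sigmoid(beta * logit)


x = (1.0, 0.5)
print(round(branch_probability(x, (0.8, -0.4), 0.1), 3))  # 0.668
print(round(branch_probability(x, (0.8, -0.4), 0.1, beta=5.0), 3))  # 0.971
```

The same linear combination (0.7 here) gives a fairly soft 0.668 with the plain sigmoid but a much more decisive 0.971 with β = 5.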

Each leaf node contains a vector of N values, where N is the number of classes. This vector represents the probability distribution over classes for the samples that reach that leaf.

Neural Network as a Decision Tree (Replication of a figure from Frosst & Hinton 2017 [1])

As in any soft decision tree, the output of this neural tree is a probability distribution over the classes. The output distribution is the sum of the leaf distributions, each weighted by the path probability of reaching that leaf.
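The leaf and output computations can be sketched as follows. Following Frosst & Hinton, each leaf stores learned logits that a softmax turns into its class distribution; the path probabilities and logits below are made-up numbers for a hypothetical depth-2, three-class tree, and the function names are illustrative.

```python
import math
from typing import List, Sequence


def softmax(phi: Sequence[float]) -> List[float]:
    """Turn a leaf's learned logits into a probability distribution over classes."""
    m = max(phi)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in phi]
    total = sum(exps)
    return [e / total for e in exps]


def tree_output(
    path_probs: Sequence[float], leaf_logits: Sequence[Sequence[float]]
) -> List[float]:
    """Mix the leaf distributions, weighting each by the probability of reaching that leaf."""
    leaf_dists = [softmax(phi) for phi in leaf_logits]
    n_classes = len(leaf_dists[0])
    return [
        sum(p * q[k] for p, q in zip(path_probs, leaf_dists))
        for k in range(n_classes)
    ]


path_probs = [0.18, 0.12, 0.14, 0.56]  # must sum to 1
leaf_logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [0.0, 2.0, 0.0]]
print([round(v, 3) for v in tree_output(path_probs, leaf_logits)])
```

Because the two leaves favouring class 1 carry 0.68 of the path probability between them, class 1 ends up as the most likely output.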

Training the Tree

One benefit of neural trees is that they can be trained with continuous optimization algorithms like gradient descent, instead of the greedy algorithms used to construct ordinary decision trees. All we need to do is define the loss function:

The loss function for the neural tree (Image from Frosst & Hinton 2017 [1])

This loss is similar to cross-entropy loss. In this equation, P^l(x) is the probability of reaching leaf node l given the data point x, T_k is the target probability of class k (1 for the true class and 0 otherwise), and Q_k^l is the element of the probability distribution stored in leaf node l that corresponds to class k.
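Since the loss is only shown as an image, the sketch below assumes it takes the form L(x) = -Σ_l P^l(x) Σ_k T_k log Q_k^l, i.e. the cross-entropy of each leaf, averaged with the path probabilities as weights. Treat that as our reading of the description above rather than a transcription of the paper’s equation; the function name and numbers are hypothetical.

```python
import math
from typing import Sequence


def soft_tree_loss(
    path_probs: Sequence[float],
    leaf_dists: Sequence[Sequence[float]],
    target: Sequence[float],
) -> float:
    """Path-weighted cross-entropy for one sample.

    path_probs[l] = P^l(x), leaf_dists[l][k] = Q_k^l, target[k] = T_k.
    """
    loss = 0.0
    for p_leaf, q in zip(path_probs, leaf_dists):
        # Cross-entropy between the target and this leaf's distribution;
        # classes with T_k = 0 contribute nothing, so skip them to avoid log(0).
        ce = -sum(t * math.log(qk) for t, qk in zip(target, q) if t > 0)
        loss += p_leaf * ce
    return loss


# Hypothetical two-leaf tree with two classes; the true class is 1.
print(soft_tree_loss([0.25, 0.75], [[0.9, 0.1], [0.2, 0.8]], [0.0, 1.0]))
```

Because each leaf is penalized in proportion to how likely the sample is to reach it, minimizing this loss pushes the decision nodes to route a sample toward leaves that already predict its class, and pushes those leaves toward the target.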

One important point about this structure is that the tree’s shape is fixed. Unlike ordinary decision trees, which use a greedy algorithm to grow the tree one split at a time, with a soft decision tree we first fix the size of the tree and then use gradient descent to update all parameters simultaneously. One benefit of this approach is that it is much easier to constrain the size of the tree without losing much discriminative power.
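To make the “fixed shape, all parameters at once” idea concrete, here is a toy sketch of the smallest possible soft tree: one sigmoid decision node and two softmax leaves, trained with per-sample gradient descent using hand-derived gradients of the path-weighted cross-entropy. It is a simplification under our own assumptions (depth 1, no temperature, no balance penalty, made-up data, hypothetical function names), not the training setup from the paper.

```python
import math
import random
from typing import List, Sequence, Tuple


def _sigmoid(z: float) -> float:
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def _softmax(phi: Sequence[float]) -> List[float]:
    m = max(phi)
    exps = [math.exp(v - m) for v in phi]
    total = sum(exps)
    return [e / total for e in exps]


def train_soft_stump(
    data: Sequence[Tuple[Sequence[float], int]],
    n_classes: int,
    lr: float = 0.5,
    epochs: int = 200,
    seed: int = 0,
):
    """Per-sample gradient descent on a depth-1 soft tree.

    Loss per sample: (1 - p) * CE(left leaf) + p * CE(right leaf), where
    p = sigmoid(w . x + b) is the probability of going right.
    """
    rng = random.Random(seed)
    w = [rng.uniform(-0.1, 0.1) for _ in range(len(data[0][0]))]
    b = 0.0
    phi = [[0.0] * n_classes, [0.0] * n_classes]  # leaf logits: [left, right]
    for _ in range(epochs):
        for x, y in data:
            p = _sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            q_left, q_right = _softmax(phi[0]), _softmax(phi[1])
            ce_left, ce_right = -math.log(q_left[y]), -math.log(q_right[y])
            # d loss / d node logit = p * (1 - p) * (CE_right - CE_left)
            g_z = p * (1.0 - p) * (ce_right - ce_left)
            # Node weights, bias and both leaves are updated from the same forward pass.
            w = [wi - lr * g_z * xi for wi, xi in zip(w, x)]
            b -= lr * g_z
            for k in range(n_classes):
                t = 1.0 if k == y else 0.0
                # d loss / d leaf logit k = (path probability) * (Q_k - T_k)
                phi[0][k] -= lr * (1.0 - p) * (q_left[k] - t)
                phi[1][k] -= lr * p * (q_right[k] - t)
    return w, b, phi


def predict_proba(x, w, b, phi) -> List[float]:
    """Mixture of the two leaf distributions, weighted by the path probabilities."""
    p = _sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    q_left, q_right = _softmax(phi[0]), _softmax(phi[1])
    return [(1.0 - p) * ql + p * qr for ql, qr in zip(q_left, q_right)]


# Hypothetical 1-D toy data: positive inputs are class 1, negative inputs class 0.
toy = [((1.0,), 1), ((-1.0,), 0)]
w, b, phi = train_soft_stump(toy, n_classes=2)
print(predict_proba((1.0,), w, b, phi), predict_proba((-1.0,), w, b, phi))
```

Nothing here grows the tree: the structure is fixed up front and only the parameters move. A practical implementation would rely on an autodiff framework such as PyTorch rather than hand-written gradients.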

One potential pitfall during training is that the model can heavily favor a single branch and never make use of the rest of the tree. To avoid getting stuck in such poor solutions, it is recommended to add a penalty to the loss function that encourages each node to use both its left and right subtrees.

The penalty is the cross-entropy between the desired average distribution (a 50/50 split between the left and right subtrees) and the actual average distribution, which is defined by alpha:

Definition of alpha for a given node i (Image from Frosst & Hinton 2017 [1])

In this equation, P^i(x) is the path probability from the root node to node i for a given data point x, and p_i(x) is the probability that node i sends x down its right branch. The penalty is then summed across all internal nodes.
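In the paper, alpha for node i is the path-probability-weighted average of the node’s right-branch probability over a batch: alpha_i = Σ_x P^i(x) p_i(x) / Σ_x P^i(x). A direct sketch, with a hypothetical function name and made-up batch values:

```python
from typing import Sequence


def node_alpha(path_probs: Sequence[float], right_probs: Sequence[float]) -> float:
    """Average right-branch probability at node i over a batch.

    path_probs[n]  = P^i(x_n), the probability of reaching node i.
    right_probs[n] = p_i(x_n), the node's probability of sending x_n right.
    Samples that barely reach the node get little weight in the average.
    """
    numerator = sum(P * p for P, p in zip(path_probs, right_probs))
    denominator = sum(path_probs)
    return numerator / denominator


# Hypothetical batch of three samples reaching the same node.
print(node_alpha([1.0, 0.5, 0.1], [0.9, 0.2, 0.6]))
```

Here alpha comes out at about 0.66, meaning this node currently sends somewhat more of its (weighted) data to the right than to the left.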

Definition of the penalty (Image from Frosst & Hinton 2017 [1])

In this equation, lambda is a hyperparameter that determines the strength of the penalty. However, a single fixed lambda can cause problems: the deeper we descend into the tree, the less likely it is that the data reaching a node can be split evenly, so it is recommended to use an adaptive lambda that depends on the depth of the node instead. This modifies the penalty to:

Modified Penalty function (Image by author, taken from text in Frosst & Hinton 2017 [1])

Specifically, Frosst and Hinton recommend decaying lambda as we go deeper into the tree, making it proportional to 2^-d, where d is the depth of the node.
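Putting the balance penalty and the depth-dependent lambda together gives the sketch below. It assumes the root sits at depth 0, so lambda_d = lambda * 2^-d, and that each node contributes -lambda_d (0.5 log alpha_i + 0.5 log(1 - alpha_i)), the cross-entropy between a 50/50 target and (alpha_i, 1 - alpha_i). The alphas, depths and function name are hypothetical; in training, this term would be added to the classification loss.

```python
import math
from typing import Sequence


def balance_penalty(alphas: Sequence[float], depths: Sequence[int], lam: float) -> float:
    """Sum of depth-scaled cross-entropies between a 50/50 split and each node's alpha.

    alphas[i] must lie strictly between 0 and 1; depths[i] is the node's depth,
    with the root at depth 0.
    """
    penalty = 0.0
    for alpha, depth in zip(alphas, depths):
        lam_d = lam * 2.0 ** (-depth)  # decay the penalty strength by 2^-d
        penalty -= lam_d * (0.5 * math.log(alpha) + 0.5 * math.log(1.0 - alpha))
    return penalty


# Depth-2 tree: the root plus its two children, with hypothetical alphas.
print(balance_penalty([0.5, 0.9, 0.3], [0, 1, 1], lam=0.1))
```

Each node’s term is smallest (lambda_d * log 2) when alpha_i = 0.5 and grows without bound as alpha_i approaches 0 or 1, so a node that sends almost everything down one branch is penalized heavily, while halving lambda at every level lets deeper nodes split unevenly more cheaply.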

Visualizing Results

While reformulating a neural network as a tree is interesting in its own right, the main reason to pursue this approach is greater model interpretability.

Let’s first look at the interpretation on a classic problem, digit classification on MNIST:

Example on MNIST (Image from Frosst & Hinton 2017 [1])

In the figure above, the images at the inner nodes are the learned filters, and the images at the leaves visualize the learned probability distribution over classes. For each leaf and inner node, the most likely class is annotated in blue.

Looking at this tree, we can see some interesting features. For example, at the right-most internal node, the candidate classes are 3 and 8, and we can actually see the outline of a 3 in that node’s filter. The white areas suggest that the model looks for the lines that close the inner loops of a 3 to turn it into an 8. We can also see that the model looks for the shape of a zero in the third node from the left in the bottom row of inner nodes.

Another interesting example is predicting the winner of Connect4 games:

Visualization of the first 2 layers of a neural decision tree predicting who would win a Connect4 game (Image from Frosst & Hinton 2017 [1])

The learned filters in this example show that games can be split into two distinct types: those where the players have primarily focused on the edges of the board and those where the players have placed pieces in the center of the board.

Constructing a neural network as a soft decision tree allows us to leverage the power of neural networks while still preserving some interpretability. As the MNIST results show, the learned filters can provide both local and global explainability, a welcome trait for higher-stakes tasks. Additionally, the training method (optimizing the entire tree at once) gives us more discriminative power at a fixed tree size, something we cannot achieve with ordinary decision trees.

Despite this, neural trees are not perfect. Their soft nature means that the data scientist using them needs to “preprocess” the tree before presenting it to non-technical stakeholders, while ordinary decision trees can be shown as is (they are fairly self-explanatory). Additionally, while oblique splits help with accuracy, having too many variables in a single node can again make the tree harder to explain. This means that regularization is not only recommended but, to a certain extent, necessary. Finally, no matter how interpretable a model is, stakeholders still need interpretable features that they can understand.

These drawbacks, however, do not detract from the potential of these models to push the interpretability-performance frontier. I highly encourage everyone to try these models in their next data science task and to read the original paper.

  1. Implementation of the Soft Decision Tree in PyTorch

References

[1] N. Frosst, G. Hinton. Distilling a Neural Network Into a Soft Decision Tree (2017). 2017 Artificial Intelligence in Action Conference

[2] A. Shrikumar, P. Greenside, A. Kundaje. Learning Important Features Through Propagating Activation Differences (2017). International Conference on Machine Learning, PMLR 2017.

[3] S. Bach, A. Binder, G. Montavon, F. Klauschen, K.-R. Müller, W. Samek. On Pixel-Wise Explanations for Non-Linear Classifier Decisions by Layer-Wise Relevance Propagation (2015). PLoS ONE, 10(7), e0130140.

[4] L. Zhang, J. Varadarajan, P. N. Suganthan, N. Ahuja, P. Moulin. Robust Visual Tracking Using Oblique Random Forests (2017). Conference on Computer Vision and Pattern Recognition 2017.

