
Improving TabTransformer Part 1: Linear Numerical Embeddings

By Anton Rubert, October 2022



Deep learning for tabular data with FT-Transformer

Photo by Nick Hillier on Unsplash

In the previous post about TabTransformer, I described how the model works and how it can be applied to your data. This post builds on it, so if you haven’t read it yet, I highly recommend starting there and returning to this post afterwards.

TabTransformer was shown to outperform traditional multi-layer perceptrons (MLPs) and came close to the performance of Gradient Boosted Trees (GBTs) on some datasets. However, there is one noticeable drawback with the architecture — it doesn’t take numerical features into account when constructing contextual embeddings. This post dives into the paper by Gorishniy et al. (2021), which addresses this issue by introducing the FT-Transformer (Feature Tokenizer + Transformer).

Both models use Transformers (Vaswani et al., 2017) as their backbone, but there are two main differences:

  • Use of numerical embeddings
  • Use of CLS token for output

Numerical Embeddings

Traditional TabTransformer takes categorical embeddings and passes them through the Transformer blocks to transform them into contextual ones. Then, numerical features are concatenated with these contextual embeddings and are passed through the MLP to get a prediction.

TabTransformer diagram. Image by author.

Most of the magic happens inside the Transformer blocks, so it’s a shame that numerical features are left out and are only used in the final layers of the model. Gorishniy et al. (2021) propose to address this issue by embedding numerical features as well.

The embeddings that the FT-Transformer uses are linear, meaning that each feature gets transformed into a dense vector after passing through a simple fully connected layer. It should be noted that these dense layers don’t share weights, so there is a separate embedding layer per numeric feature.

Linear Numerical Embeddings. Image by author.
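To make the idea concrete, here is a minimal Keras sketch of per-feature linear embeddings. This is not the tabtransformertf implementation; the layer name, embedding dimension, and example shapes are my own choices.

```python
import tensorflow as tf

class LinearNumericalEmbedding(tf.keras.layers.Layer):
    """One Dense layer per numeric feature, with no weight sharing."""

    def __init__(self, n_features, embedding_dim=32):
        super().__init__()
        # A separate fully connected layer for each numeric feature
        self.feature_layers = [
            tf.keras.layers.Dense(embedding_dim) for _ in range(n_features)
        ]

    def call(self, x):
        # x: (batch, n_features) -> (batch, n_features, embedding_dim)
        embedded = [layer(x[:, i:i + 1]) for i, layer in enumerate(self.feature_layers)]
        return tf.stack(embedded, axis=1)


# Example: 3 numeric features, each mapped to its own 32-dimensional token
tokens = LinearNumericalEmbedding(n_features=3)(tf.random.normal((8, 3)))
print(tokens.shape)  # (8, 3, 32)
```

The resulting tokens have the same shape as categorical embeddings, which is exactly what allows them to enter the Transformer blocks.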

You might find yourself asking — why would you do that if these features are already numeric? The main reason is that numerical embeddings can be passed through the Transformer blocks together with the categorical ones. This adds more context to learn from and hence improves the representation quality.

Transformer with Numerical Embeddings. Image by author.

Interestingly, it was demonstrated (e.g. here) that the addition of these numerical embeddings can improve the performance of various deep learning models (not only TabTransformer), so they can be applied even to simple MLPs.

MLP with Numerical Embeddings. Image by author.

CLS Token

The usage of the CLS token is adapted from the NLP domain, but it translates quite nicely to tabular tasks. The basic idea is that after we’ve embedded our features, we append to them another “embedding” which represents the CLS token. This way, categorical, numerical, and CLS embeddings all get contextualised by passing through the Transformer blocks. Afterwards, the contextualised CLS token embedding serves as the input to a simple MLP classifier which produces the desired output.
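A minimal sketch of the mechanism (variable names and dimensions are assumptions, not the package’s API): prepend a learnable CLS embedding to the feature tokens, run the Transformer blocks, then feed only the contextualised CLS vector to the classifier head.

```python
import tensorflow as tf

embedding_dim = 32

# Learnable CLS embedding, shared across all samples
cls_embedding = tf.Variable(tf.random.normal((1, 1, embedding_dim)), trainable=True)

def prepend_cls(feature_tokens):
    # feature_tokens: (batch, n_tokens, embedding_dim)
    batch_size = tf.shape(feature_tokens)[0]
    cls = tf.tile(cls_embedding, [batch_size, 1, 1])   # (batch, 1, embedding_dim)
    return tf.concat([cls, feature_tokens], axis=1)    # (batch, n_tokens + 1, embedding_dim)

# After the Transformer blocks, only the contextualised CLS position feeds the MLP head:
#   contextual = transformer_blocks(prepend_cls(feature_tokens))
#   prediction = mlp_head(contextual[:, 0, :])
```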

FT-Transformer

By augmenting TabTransformer with numerical embeddings and a CLS token, we get the final proposed architecture.

FT-Transformer. Image by author.
Reported results for FT-Transformer. Source: Gorishniy et al. (2021)

From the results we can see that FT-Transformer outperforms gradient boosting models on a variety of datasets. In addition, it outperforms ResNet, which is a strong deep learning baseline for tabular data. Interestingly, hyperparameter tuning doesn’t change the FT-Transformer results that much, which might indicate that it is not very sensitive to hyperparameters.

This section is going to show you how to use FT-Transformer by validating the results for the Adult Income Dataset. I’m going to use a package called tabtransformertf which can be installed using pip install tabtransformertf. It allows us to use the tabular transformer models without extensive pre-processing. Below you can see the main steps and results of the analysis, but make sure to look into the supplementary notebook for more details.

Data pre-processing

Data can be downloaded from here or using a number of APIs. Data pre-processing steps are not that relevant for this post, so you can find a full working example on GitHub. FT-Transformer-specific pre-processing is similar to that of TabTransformer, since we need to create the categorical preprocessing layers and transform the data into TF Datasets.
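The package provides helpers for these steps, but conceptually they look roughly like the plain Keras/tf.data sketch below. The file and column names are hypothetical stand-ins for the Adult dataset, and the helper function here is my own, not the package’s.

```python
import pandas as pd
import tensorflow as tf

train_df = pd.read_csv("adult_train.csv")  # hypothetical file name

CATEGORICAL = ["workclass", "education", "marital-status"]  # subset, for illustration
NUMERIC = ["age", "capital-gain", "hours-per-week"]
LABEL = "income_over_50k"                                   # hypothetical 0/1 label column

# One StringLookup layer per categorical feature (the category_lookup preprocessing layers)
category_lookup = {
    col: tf.keras.layers.StringLookup(vocabulary=train_df[col].unique())
    for col in CATEGORICAL
}

def df_to_tf_dataset(df, batch_size=256, shuffle=True):
    features = {col: df[col].values for col in CATEGORICAL + NUMERIC}
    ds = tf.data.Dataset.from_tensor_slices((features, df[LABEL].values))
    if shuffle:
        ds = ds.shuffle(len(df))
    return ds.batch(batch_size)

train_ds = df_to_tf_dataset(train_df)
val_ds = df_to_tf_dataset(pd.read_csv("adult_val.csv"), shuffle=False)  # hypothetical split
```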

FT-Transformer Initialisation

Initialisation of the model is relatively straightforward, and each of the parameters is commented on. Three FT-Transformer-specific parameters are numerical_embeddings, numerical_embedding_type, and explainable; a rough initialisation sketch follows the list below.

  • numerical_embeddings — similar to category_lookup, these are preprocessing layers. They are set to None for FT-Transformer because we don’t pre-process numerical features.
  • numerical_embedding_type — set to linear for linear embeddings. More types will be covered in the next post.
  • explainable — if set to True, the model will output feature importances for each row. They’re inferred from the attention weights.
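As a rough illustration, initialisation could look like the sketch below. The constructor signature here is an assumption pieced together from the parameters discussed in this post and may differ from the current tabtransformertf API; the supplementary notebook contains the exact working call.

```python
from tabtransformertf.models.fttransformer import FTTransformer  # import path may differ

ft_model = FTTransformer(
    numerical_features=NUMERIC,          # numeric column names (assumed parameter name)
    categorical_features=CATEGORICAL,    # categorical column names (assumed parameter name)
    category_lookup=category_lookup,     # categorical preprocessing layers from the previous step
    numerical_embeddings=None,           # no numeric preprocessing for FT-Transformer
    numerical_embedding_type="linear",   # linear numerical embeddings
    embedding_dim=32,                    # token dimension (assumed value)
    depth=4,                             # number of Transformer blocks (assumed value)
    heads=8,                             # number of attention heads (assumed value)
    explainable=True,                    # also output attention-based feature importances
    out_dim=1,                           # single sigmoid output for binary classification
    out_activation="sigmoid",
)
```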

Model Training

The training procedure is similar to that of any Keras model. The only thing to watch out for: if you’ve specified explainable as True, you need two losses and two metrics instead of one.
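A compile-and-fit sketch is shown below, assuming the model exposes one output head for predictions and one for importances; the output names and the way the second (importance) loss is skipped here are assumptions, and the package may expect a slightly different specification.

```python
ft_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    # Two outputs: binary cross-entropy for the predictions, nothing to optimise for the importances
    loss={"output": tf.keras.losses.BinaryCrossentropy(), "importances": None},
    metrics={"output": [tf.keras.metrics.AUC(curve="PR", name="pr_auc")]},
)

history = ft_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=1000,  # early stopping halts training long before this
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)],
)
```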

Training takes roughly 70 epochs; below you can see the progress of the loss and metric values. To speed up training, you can reduce the number of early stopping rounds or simplify the model further (e.g. fewer attention heads).

Training/Validation loss and metric. Plots by author.

Evaluation

The test dataset is evaluated using ROC AUC and PR AUC, since it’s an imbalanced binary classification problem. To validate the reported results, I’m also including the accuracy metric, assuming a threshold of 0.5.
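A minimal sketch of the evaluation with scikit-learn metrics; how the probabilities are extracted from the model output, as well as the test_ds and y_test objects, are assumptions.

```python
import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

test_preds = ft_model.predict(test_ds)  # test_ds built like train_ds, without shuffling
# If explainable=True the model returns several outputs; assume the prediction head is "output"
probs = np.ravel(test_preds["output"] if isinstance(test_preds, dict) else test_preds)

print("ROC AUC :", roc_auc_score(y_test, probs))
print("PR AUC  :", average_precision_score(y_test, probs))
print("Accuracy:", accuracy_score(y_test, (probs >= 0.5).astype(int)))
```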

The resulting accuracy score is 0.8576, which is just slightly below the reported score of 0.86. This difference might be due to random variation during training or due to different hyperparameters. Still, the results are close enough to the reported ones, which is a good sign that the research is reproducible.

Explainability

One of the biggest advantages of FT-Transformer is its in-built explainability. Since all the features are passed through a Transformer, we can get their attention maps and infer feature importances. These importances are calculated using the following formula:

Feature importances formula. Source: Gorishniy et al. (2021)

where p_ihl is the h-th head’s attention map for the [CLS] token from the forward pass of the l-th layer on the i-th sample. The formula sums up all the attention scores for the [CLS] token across the different attention heads (heads parameter) and Transformer layers (depth parameter) and then divides them by heads × depth. Local importances (p_i) can be averaged across all rows to get the global importances (p).
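Written out, the formula described above is the following (a reconstruction from that description, with H the number of heads, L the depth, and N the number of rows):

```latex
p_{i} = \frac{1}{H \cdot L} \sum_{h=1}^{H} \sum_{l=1}^{L} p_{ihl},
\qquad
p = \frac{1}{N} \sum_{i=1}^{N} p_{i}
```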

Now, let’s see what the importances look like for the Adult Income dataset.

From the code above you can see that the model already outputs most of the information we need. Processing and plotting it gives the following results.
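Roughly, the processing looks like the sketch below; the "importances" output key and the column ordering are assumptions about what the model returns.

```python
import matplotlib.pyplot as plt
import pandas as pd

test_out = ft_model.predict(test_ds)

# Local importances: one row per sample, one column per feature plus the CLS token
local_importances = pd.DataFrame(
    test_out["importances"],                  # assumed output key
    columns=NUMERIC + CATEGORICAL + ["cls"],  # assumed column order
)

# Global importances: average the local ones over all rows (CLS excluded)
global_importances = (
    local_importances.drop(columns="cls").mean().sort_values(ascending=False)
)
global_importances.plot(kind="bar", title="Global feature importances")
plt.tight_layout()
plt.show()
```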

Feature importances. Plot by author.

The top 5 features indeed make sense, since people with larger incomes tend to be older, married, and more educated. We can sense-check the local importances as well by looking at the importances for the largest and smallest predictions.
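Continuing the sketch above, pulling out those two rows could look like this:

```python
# Rows with the highest and lowest predicted probability of earning >50K
highest, lowest = probs.argmax(), probs.argmin()
per_feature = local_importances.drop(columns="cls")

print("Top 3 features, highest prediction:")
print(per_feature.iloc[highest].nlargest(3))

print("Top 3 features, lowest prediction:")
print(per_feature.iloc[lowest].nlargest(3))
```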

Top 3 contributions. Created by author.

Again, the importances make intuitive sense. The person with the largest probability of earning more than 50K has large capital gains, 15 years of education, and is older. The person with the lowest chances is just 18 years old, has 10 years of education, and works 15 hours a week.

In this post you saw what the FT-Transformer is, how it differs from the TabTransformer, and how it can be trained using the tabtransformertf package.

Overall, the FT-Transformer is a promising addition to the deep tabular learning domain. By embedding not only categorical but also numerical features, the model significantly improves on TabTransformer and further reduces the gap between deep models and gradient boosted models like XGBoost. In addition, the model is explainable, which is beneficial in many domains.

My next post is going to cover different numerical embedding types (not just linear) which improve the performance even further. Stay tuned!

  • Adult Income Dataset (Creative Commons Attribution 4.0 International license (CC BY 4.0)) — Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.
  • Gorishniy, Y., et al., 2021, Revisiting Deep Learning Models for Tabular Data, https://arxiv.org/abs/2106.11959
  • Vaswani, A., et al., 2017, Attention Is All You Need, https://arxiv.org/abs/1706.03762


