TabTransformer: Tabular Data Modeling Using Contextual Embeddings
...explore the average and maximum pooling strategies (Howard and Ruder 2018), rather than the concatenation of embeddings, as the features for the linear model. The upward pattern clearly shows that the embeddings become more effective as the Transformer layers progress. In contrast, the embeddings from the MLP (the single black markers) perform worse with a linear model. Furthermore, the last value in each line being close to 1.0 indicates that a linear model using the last-layer embeddings as features can achieve reliable accuracy, which confirms our assumption.

Figure 3: Predictions of linear models using as features the embeddings extracted from different Transformer layers in TabTransformer. Layer 0 corresponds to the embeddings before being passed into the Transformer layers. For each dataset, each prediction score is normalized by the "best score" from an end-to-end trained TabTransformer.

3.2 The Robustness of TabTransformer

We further demonstrate the robustness of TabTransformer on noisy data and on data with missing values, against the baseline MLP. We consider these two scenarios only on categorical features, to specifically demonstrate the robustness of the contextual embeddings from the Transformer layers.

Noisy Data. On the test examples, we first contaminate the data by replacing a certain number of values with values randomly generated from the corresponding columns (features). Next, the noisy data are passed into a trained TabTransformer to compute the prediction AUC score. Results on a set of 3 different datasets are presented in Figure 4. As the noise rate increases, TabTransformer performs better in prediction accuracy and is thus more robust than MLP. In particular, notice the Blastchar dataset, where the performance is nearly identical with no noise, yet as the noise increases, TabTransformer becomes significantly more performant than the baseline. We conjecture that the robustness comes from the contextual property of the embeddings: even when a feature is noisy, it draws information from the correct features, allowing for a certain amount of correction.

Figure 4: Performance of TabTransformer and MLP with noisy data. For each dataset, each prediction score is normalized by the score of TabTransformer at zero noise.

Data with Missing Values. Similarly, on the test data we artificially select a number of values to be missing and send the data with missing values to a trained TabTransformer to compute the prediction score. There are two options to handle the embeddings of missing values: (1) use the average learned embedding over all classes in the corresponding column; (2) use the embedding for the class of missing values, i.e. the additional embedding for each column mentioned in Section 2. Since the benchmark datasets do not contain enough missing values to effectively train the embedding in option (2), we use the average embedding of option (1) for imputation. Results on the same 3 datasets are presented in Figure 5. We can see the same pattern as in the noisy data case, i.e. TabTransformer shows better stability than MLP in handling missing values.

Figure 5: Performance of TabTransformer and MLP under the missing-data scenario. For each dataset, each prediction score is normalized by the score of TabTransformer trained without missing values.
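To make the noise-injection protocol above concrete, the following is a minimal sketch of how the contamination and the normalized AUC curve of Figure 4 could be computed. It assumes a pandas test frame, a list of categorical column names, and a model exposing a scikit-learn-style `predict_proba`; all of these names are illustrative, not the paper's released code.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

def contaminate(df_test, cat_cols, noise_rate, rng):
    """Replace a fraction of categorical values with values drawn
    uniformly from the same column, mimicking the noisy-data setup."""
    noisy = df_test.copy()
    n = len(noisy)
    for col in cat_cols:
        idx = rng.choice(n, size=int(noise_rate * n), replace=False)
        # draw replacements from the empirical values of the column
        noisy.loc[noisy.index[idx], col] = rng.choice(df_test[col].values, size=len(idx))
    return noisy

def robustness_curve(model, df_test, y_test, cat_cols,
                     rates=(0.0, 0.1, 0.2, 0.3, 0.4, 0.5)):
    """AUC at each noise rate, normalized by the score at zero noise."""
    rng = np.random.default_rng(0)
    scores = {}
    for r in rates:
        noisy = contaminate(df_test, cat_cols, r, rng)
        scores[r] = roc_auc_score(y_test, model.predict_proba(noisy)[:, 1])
    return {r: s / scores[0.0] for r, s in scores.items()}
```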
3.3 Supervised Learning

Here we compare the performance of TabTransformer against the following four categories of methods: (a) logistic regression and GBDT; (b) MLP and a sparse MLP following Morcos et al. (2019); (c) the TabNet model of Arik and Pfister (2019); and (d) the Variational Information Bottleneck model (VIB) of Alemi et al. (2017).

Table 2: Model performance in supervised learning. The evaluation metric is the mean ± standard deviation of the AUC score over the 15 datasets for each model. The larger the number, the better the result. The top 2 numbers are in bold.

Model Name            Mean AUC (%)
TabTransformer        82.8 ± 0.4
MLP                   81.8 ± 0.4
GBDT                  82.9 ± 0.4
Sparse MLP            81.4 ± 0.4
Logistic Regression   80.4 ± 0.4
TabNet                77.1 ± 0.5
VIB                   80.5 ± 0.4

Results are summarized in Table 2. TabTransformer, MLP, and GBDT are the top 3 performers. TabTransformer outperforms the baseline MLP with an average 1.0% gain and performs comparably with GBDT. Furthermore, TabTransformer is significantly better than TabNet and VIB, the recent deep networks for tabular data. For experiment and model details, see Appendix B. The models' performances on each individual dataset are presented in Tables 16 and 17 in Appendix C.

3.4 Semi-supervised Learning

Lastly, we evaluate TabTransformer under the semi-supervised learning scenario, where few labeled training examples are available together with a significant number of unlabeled samples. Specifically, we compare our pre-trained and then fine-tuned TabTransformer-RTD/MLM against the following semi-supervised models: (a) Entropy Regularization (ER) (Grandvalet and Bengio 2006) combined with MLP and TabTransformer; (b) Pseudo Labeling (PL) (Lee 2013) combined with MLP, TabTransformer, and GBDT (Jain 2017); (c) MLP (DAE): an unsupervised pre-training method designed for deep models on tabular data, the swap-noise Denoising AutoEncoder (Jahrer 2018).

The pre-training models TabTransformer-MLM, TabTransformer-RTD, and MLP (DAE) are first pre-trained on the entire unlabeled training data and then fine-tuned on labeled data. The semi-supervised learning methods, Pseudo Labeling and Entropy Regularization, are trained on the mix of labeled and unlabeled training data. To better present the results, we split the set of 15 datasets into two subsets. The first set includes 6 datasets with more than 30K data points and the second set includes the remaining 9 datasets.

The results are presented in Table 3 and Table 4. When the number of unlabeled data points is large, Table 3 shows that our TabTransformer-RTD and TabTransformer-MLM significantly outperform all the other competitors. In particular, TabTransformer-RTD/MLM improves over all the other competitors by at least 1.2%, 2.0%, and 2.1% on mean AUC for the scenarios of 50, 200, and 500 labeled data points, respectively. The Transformer-based semi-supervised learning methods TabTransformer (ER) and TabTransformer (PL) and the tree-based semi-supervised learning method GBDT (PL) perform worse than the average of all the models. When the number of unlabeled data points becomes smaller, as shown in Table 4, TabTransformer-RTD still outperforms most of its competitors, but with a marginal improvement.
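For reference, here is a minimal sketch of the Pseudo Labeling and Entropy Regularization objectives used by the baselines above, written for PyTorch logits on a generic classification task; the confidence threshold and the weight `alpha` are illustrative choices, not the exact settings of this paper (the schedule actually used is given in Appendix B).

```python
import torch
import torch.nn.functional as F

def pseudo_label_loss(logits_unlabeled, threshold=0.95):
    """Pseudo Labeling (Lee 2013): take the most confident class as a
    hard label and apply the usual cross entropy to it."""
    probs = torch.softmax(logits_unlabeled, dim=1)
    conf, pseudo = probs.max(dim=1)
    mask = conf >= threshold                  # keep only confident predictions
    if mask.sum() == 0:
        return logits_unlabeled.new_zeros(())
    return F.cross_entropy(logits_unlabeled[mask], pseudo[mask])

def entropy_regularization(logits_unlabeled):
    """Entropy Regularization (Grandvalet and Bengio 2005): average
    per-sample prediction entropy on the unlabeled examples."""
    log_probs = F.log_softmax(logits_unlabeled, dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1).mean()

def semi_supervised_loss(logits_l, y_l, logits_u, alpha=0.5, method="ER"):
    """Supervised cross entropy plus a weighted unsupervised term."""
    supervised = F.cross_entropy(logits_l, y_l)
    if method == "ER":
        return supervised + alpha * entropy_regularization(logits_u)
    return supervised + alpha * pseudo_label_loss(logits_u)
```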
Table 3: Semi-supervised learning results for 8 datasets each with more than 30K data points, for different numbers of labeled data points. The evaluation metric is the mean AUC in percentage. The larger the number, the better the result.

# Labeled data          50            200           500
TabTransformer-RTD   66.6 ± 0.6   70.9 ± 0.6   73.1 ± 0.6
TabTransformer-MLM   66.8 ± 0.6   71.0 ± 0.6   72.9 ± 0.6
MLP (ER)             65.6 ± 0.6   69.0 ± 0.6   71.0 ± 0.6
MLP (PL)             65.4 ± 0.6   68.8 ± 0.6   71.0 ± 0.6
TabTransformer (ER)  62.7 ± 0.6   67.1 ± 0.6   69.3 ± 0.6
TabTransformer (PL)  63.6 ± 0.6   67.3 ± 0.7   69.3 ± 0.6
MLP (DAE)            65.2 ± 0.5   68.5 ± 0.6   71.0 ± 0.6
GBDT (PL)            56.5 ± 0.5   63.1 ± 0.6   66.5 ± 0.7

Table 4: Semi-supervised learning results for 12 datasets each with less than 30K data points, for different numbers of labeled data points. The evaluation metric is the mean AUC in percentage. The larger the number, the better the result.

# Labeled data          50            200           500
TabTransformer-RTD   78.6 ± 0.6   81.6 ± 0.5   83.4 ± 0.5
TabTransformer-MLM   78.5 ± 0.6   81.0 ± 0.6   82.4 ± 0.5
MLP (ER)             79.4 ± 0.6   81.1 ± 0.6   82.3 ± 0.6
MLP (PL)             79.1 ± 0.6   81.1 ± 0.6   82.0 ± 0.6
TabTransformer (ER)  77.9 ± 0.6   81.2 ± 0.6   82.1 ± 0.6
TabTransformer (PL)  77.8 ± 0.6   81.0 ± 0.6   82.1 ± 0.6
MLP (DAE)            78.5 ± 0.7   80.7 ± 0.6   82.2 ± 0.6
GBDT (PL)            73.4 ± 0.7   78.8 ± 0.6   81.3 ± 0.6

Furthermore, we observe that when the number of unlabeled data points is small, as shown in Table 4, TabTransformer-RTD performs better than TabTransformer-MLM, thanks to its easier pre-training task (a binary classification) compared with that of MLM (a multi-class classification). This is consistent with the findings of the ELECTRA paper (Clark et al. 2020). In Table 4, with only 50 labeled data points, MLP (ER) and MLP (PL) beat our TabTransformer-RTD/MLM. This can be attributed to the fact that there is room for improvement in our fine-tuning procedure. In particular, our approach allows us to obtain informative embeddings but does not allow the weights of the classifier itself to be trained with unlabeled data. Since this issue does not occur for ER and PL, they obtain an advantage on extremely small labeled sets. We point out, however, that this only means that the methods are complementary, and a possible follow-up could combine the best of all approaches.

Both evaluation results, Table 3 and Table 4, show that our TabTransformer-RTD and TabTransformer-MLM models are promising in extracting useful information from unlabeled data to help supervised training, and are particularly useful when the size of the unlabeled data is large. For model performance on each individual dataset, see Tables 10, 11, 12, 13, 14, and 15 in Appendix C.

4 Related Work

Supervised learning. Standard MLPs have been applied to tabular data for many years (De Brébisson et al. 2015). Among deep models designed specifically for tabular data, there are deep versions of factorization machines (Guo et al. 2018; Xiao et al. 2017), Transformer-based methods (Song et al. 2019; Li et al. 2020; Sun et al. 2019), and deep versions of decision-tree-based algorithms (Ke et al. 2019; Yang, Morillo, and Hospedales 2018). In particular, Song et al. (2019) apply one layer of multi-head attention on embeddings to learn higher-order features. The higher-order features are concatenated and fed into a fully connected layer to make the final prediction. Li et al. (2020) use self-attention layers and track the attention scores to obtain feature importance scores. Sun et al. (2019) combine the factorization machine model with the Transformer mechanism. All three papers focus on recommendation systems, which makes a clear comparison with this paper difficult. Other models have been designed around the purported properties of tabular data, such as low-order and sparse feature interactions. These include Deep & Cross Networks (Wang et al. 2017), Wide & Deep Networks (Cheng et al. 2016), TabNet (Arik and Pfister 2019), and AdaNet (Cortes et al. 2016).

Semi-supervised learning. Izmailov et al. (2019) give a semi-supervised method based on density estimation and evaluate their approach on tabular data. Pseudo labeling (Lee 2013) is a simple, efficient, and popular baseline method. Pseudo labeling uses the current network to infer pseudo-labels for unlabeled examples by choosing the most confident class. These pseudo-labels are treated like human-provided labels in the cross-entropy loss. Label propagation (Zhu and Ghahramani 2002; Iscen et al. 2019) is a similar approach, where a node's labels propagate to all nodes according to their proximity and are used by the training model as if they were the true labels. Another standard method in semi-supervised learning is entropy regularization (Grandvalet and Bengio 2005; Sajjadi, Javanmardi, and Tasdizen 2016). It adds the average per-sample entropy for the unlabeled examples to the original loss function for the labeled examples. Another classical approach to semi-supervised learning is co-training (Nigam and Ghani 2000). However, the more recent approaches, entropy regularization and pseudo labeling, are typically better and more popular. A succinct review of semi-supervised learning methods in general can be found in Oliver et al. (2019) and Chappelle, Schölkopf, and Zien (2010).

5 Conclusion

We proposed TabTransformer, a novel deep tabular data modeling architecture for supervised and semi-supervised learning. We provide extensive empirical evidence showing that TabTransformer significantly outperforms MLP and recent deep networks for tabular data while matching the performance of tree-based ensemble models (GBDT). We provide and extensively study a two-phase pre-train-then-fine-tune procedure for tabular data, beating the state-of-the-art performance of semi-supervised learning methods. TabTransformer shows promising results for robustness against noisy and missing data, and for interpretability of the contextual embeddings. For future work, it would be interesting to investigate these in detail.
References

Alemi, A. A.; Fischer, I.; and Dillon, J. V. 2018. Uncertainty in the Variational Information Bottleneck. arXiv:1807.00906 [cs, stat]. URL http://arxiv.org/abs/1807.00906.

Alemi, A. A.; Fischer, I.; Dillon, J. V.; and Murphy, K. 2017. Deep Variational Information Bottleneck. In International Conference on Learning Representations. URL https://arxiv.org/abs/1612.00410.

Arik, S. O.; and Pfister, T. 2019. TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442. URL https://arxiv.org/abs/1908.07442.

Ban, G.-Y.; El Karoui, N.; and Lim, A. E. 2018. Machine learning and portfolio optimization. Management Science 64(3): 1136–1154.

Bradley, A. P. 1997. The use of the area under the ROC curve in the evaluation of machine learning algorithms. Pattern Recognition 30(7): 1145–1159.

Brunner, G.; Liu, Y.; Pascual, D.; Richter, O.; and Wattenhofer, R. 2019. On the validity of self-attention as explanation in transformer models. arXiv preprint arXiv:1908.04211.

Chapelle, O.; Scholkopf, B.; and Zien, A. 2009. Semi-supervised learning. IEEE Transactions on Neural Networks 20(3): 542–542.

Chappelle, O.; Schölkopf, B.; and Zien, A. 2010. Semi-supervised learning. Adaptive Computation and Machine Learning.

Chen, T.; and Guestrin, C. 2016. XGBoost: A scalable tree boosting system. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 785–794.

Cheng, H.-T.; Koc, L.; Harmsen, J.; Shaked, T.; Chandra, T.; Aradhye, H.; Anderson, G.; Corrado, G.; Chai, W.; Ispir, M.; et al. 2016. Wide & deep learning for recommender systems. In Proceedings of the 1st Workshop on Deep Learning for Recommender Systems, 7–10.

Clark, K.; Luong, M.-T.; Le, Q. V.; and Manning, C. D. 2020. ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators. In International Conference on Learning Representations. URL https://openreview.net/forum?id=r1xMH1BtvB.

Coenen, A.; Reif, E.; Yuan, A.; Kim, B.; Pearce, A.; Viégas, F.; and Wattenberg, M. 2019. Visualizing and measuring the geometry of BERT. arXiv preprint arXiv:1906.02715.

Cortes, C.; Gonzalvo, X.; Kuznetsov, V.; Mohri, M.; and Yang, S. 2016. AdaNet: Adaptive Structural Learning of Artificial Neural Networks.

De Brébisson, A.; Simon, E.; Auvolat, A.; Vincent, P.; and Bengio, Y. 2015. Artificial Neural Networks Applied to Taxi Destination Prediction. In Proceedings of the 2015th International Conference on ECML PKDD Discovery Challenge - Volume 1526, ECMLPKDDDC'15, 40–51. Aachen, DEU: CEUR-WS.org.

Devlin, J.; Chang, M.-W.; Lee, K.; and Toutanova, K. 2019. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In NAACL-HLT.

Dua, D.; and Graff, C. 2017. UCI Machine Learning Repository. URL http://archive.ics.uci.edu/ml.

Grandvalet, Y.; and Bengio, Y. 2005. Semi-supervised learning by entropy minimization. In Advances in Neural Information Processing Systems, 529–536.

Grandvalet, Y.; and Bengio, Y. 2006. Entropy regularization. Semi-Supervised Learning, 151–168.

Guo, H.; Tang, R.; Ye, Y.; Li, Z.; He, X.; and Dong, Z. 2018. DeepFM: An End-to-End Wide & Deep Learning Framework for CTR Prediction. arXiv:1804.04950 [cs, stat]. URL http://arxiv.org/abs/1804.04950.

Guyon, I.; Sun-Hosoya, L.; Boullé, M.; Escalante, H. J.; Escalera, S.; Liu, Z.; Jajetic, D.; Ray, B.; Saeed, M.; Sebag, M.; Statnikov, A.; Tu, W.; and Viegas, E. 2019. Analysis of the AutoML Challenge series 2015-2018. In AutoML, Springer Series on Challenges in Machine Learning. URL https://www.automl.org/wp-content/uploads/2018/09/chapter10-challenge.pdf.

Howard, J.; and Ruder, S. 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.

Huang, Z.; Xu, W.; and Yu, K. 2015. Bidirectional LSTM-CRF models for sequence tagging. arXiv preprint arXiv:1508.01991.

Iscen, A.; Tolias, G.; Avrithis, Y.; and Chum, O. 2019. Label propagation for deep semi-supervised learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 5070–5079.

Izmailov, P.; Kirichenko, P.; Finzi, M.; and Wilson, A. G. 2019. Semi-Supervised Learning with Normalizing Flows. arXiv:1912.13025 [cs, stat]. URL http://arxiv.org/abs/1912.13025.

Jahrer, M. 2018. Porto Seguro's Safe Driver Prediction. URL https://kaggle.com/c/porto-seguro-safe-driver-prediction.

Jain, S. 2017. Introduction to Pseudo-Labelling: A Semi-Supervised Learning Technique. URL https://www.analyticsvidhya.com/blog/2017/09/pseudo-labelling-semi-supervised-learning-technique/.

Kaggle, Inc. 2017. The State of ML and Data Science 2017. URL https://www.kaggle.com/surveys/2017.

Ke, G.; Meng, Q.; Finley, T.; Wang, T.; Chen, W.; Ma, W.; Ye, Q.; and Liu, T.-Y. 2017. LightGBM: A highly efficient gradient boosting decision tree. In Advances in Neural Information Processing Systems, 3146–3154. URL https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf.

Ke, G.; Zhang, J.; Xu, Z.; Bian, J.; and Liu, T.-Y. 2019. TabNN: A Universal Neural Network Solution for Tabular Data. URL https://openreview.net/forum?id=r1eJssCqY7.

Kendall, A.; and Gal, Y. 2017. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, 5574–5584.

Klambauer, G.; Unterthiner, T.; Mayr, A.; and Hochreiter, S. 2017. Self-normalizing neural networks. In Advances in Neural Information Processing Systems, 971–980.

Lee, D.-H. 2013. Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. In Workshop on Challenges in Representation Learning, ICML, volume 3, 2.

Li, Z.; Cheng, W.; Chen, Y.; Chen, H.; and Wang, W. 2020. Interpretable Click-Through Rate Prediction through Hierarchical Attention. In Proceedings of the 13th International Conference on Web Search and Data Mining, 313–321. Houston, TX, USA: ACM. ISBN 978-1-4503-6822-3. doi:10.1145/3336191.3371785. URL http://dl.acm.org/doi/10.1145/3336191.3371785.

Loshchilov, I.; and Hutter, F. 2017. Decoupled Weight Decay Regularization. In International Conference on Learning Representations. URL https://arxiv.org/abs/1711.05101.

Maaten, L. v. d.; and Hinton, G. 2008. Visualizing data using t-SNE. Journal of Machine Learning Research 9(Nov): 2579–2605.

Mikolov, T.; Kombrink, S.; Burget, L.; Černockỳ, J.; and Khudanpur, S. 2011. Extensions of recurrent neural network language model. In 2011 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 5528–5531. IEEE.

Morcos, A. S.; Yu, H.; Paganini, M.; and Tian, Y. 2019. One ticket to win them all: generalizing lottery ticket initializations across datasets and optimizers. arXiv:1906.02773 [cs, stat]. URL http://arxiv.org/abs/1906.02773.

Nigam, K.; and Ghani, R. 2000. Analyzing the effectiveness and applicability of co-training. In Proceedings of the Ninth International Conference on Information and Knowledge Management, 86–93.

Oliver, A.; Odena, A.; Raffel, C.; Cubuk, E. D.; and Goodfellow, I. J. 2019. Realistic Evaluation of Deep Semi-Supervised Learning Algorithms. arXiv:1804.09170 [cs, stat]. URL http://arxiv.org/abs/1804.09170.

Oliver, A.; Odena, A.; Raffel, C. A.; Cubuk, E. D.; and Goodfellow, I. 2018. Realistic evaluation of deep semi-supervised learning algorithms. In Advances in Neural Information Processing Systems, 3235–3246.

Paszke, A.; Gross, S.; Massa, F.; Lerer, A.; Bradbury, J.; Chanan, G.; Killeen, T.; Lin, Z.; Gimelshein, N.; Antiga, L.; Desmaison, A.; Kopf, A.; Yang, E.; DeVito, Z.; Raison, M.; Tejani, A.; Chilamkurthy, S.; Steiner, B.; Fang, L.; Bai, J.; and Chintala, S. 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Wallach, H.; Larochelle, H.; Beygelzimer, A.; d'Alché Buc, F.; Fox, E.; and Garnett, R., eds., Advances in Neural Information Processing Systems 32, 8024–8035. Curran Associates, Inc. URL http://papers.neurips.cc/paper/9015-pytorch-an-imperative-style-high-performance-deep-learning-library.pdf.

Prokhorenkova, L.; Gusev, G.; Vorobev, A.; Dorogush, A. V.; and Gulin, A. 2018. CatBoost: unbiased boosting with categorical features. In Advances in Neural Information Processing Systems, 6638–6648.

Rong, X. 2014. word2vec parameter learning explained. arXiv preprint arXiv:1411.2738.

Sajjadi, M.; Javanmardi, M.; and Tasdizen, T. 2016. Regularization with stochastic transformations and perturbations for deep semi-supervised learning. In Advances in Neural Information Processing Systems, 1163–1171.

Sandler, M.; Howard, A.; Zhu, M.; Zhmoginov, A.; and Chen, L.-C. 2018. MobileNetV2: Inverted residuals and linear bottlenecks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 4510–4520.

Song, W.; Shi, C.; Xiao, Z.; Duan, Z.; Xu, Y.; Zhang, M.; and Tang, J. 2019. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks. In Proceedings of the 28th ACM International Conference on Information and Knowledge Management (CIKM '19), 1161–1170. doi:10.1145/3357384.3357925. URL http://arxiv.org/abs/1810.11921.

Stretcu, O.; Viswanathan, K.; Movshovitz-Attias, D.; Platanios, E.; Ravi, S.; and Tomkins, A. 2019. Graph Agreement Models for Semi-Supervised Learning. In Advances in Neural Information Processing Systems 32, 8713–8723. Curran Associates, Inc. URL http://papers.nips.cc/paper/9076-graph-agreement-models-for-semi-supervised-learning.pdf.

Sun, Q.; Cheng, Z.; Fu, Y.; Wang, W.; Jiang, Y.-G.; and Xue, X. 2019. DeepEnFM: Deep neural networks with Encoder enhanced Factorization Machine. URL https://openreview.net/forum?id=SJlyta4YPS.

Tanha, J.; Someren, M.; and Afsarmanesh, H. 2017. Semi-supervised self-training for decision tree classifiers. International Journal of Machine Learning and Cybernetics 8: 355–370.

Torpey, E.; and Watson, A. 2014. Education level and jobs: Opportunities by state. URL https://www.bls.gov/careeroutlook/2014/article/education-level-and-jobs.htm.

Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A. N.; Kaiser, Ł.; and Polosukhin, I. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, 5998–6008.

Voulodimos, A.; Doulamis, N.; Doulamis, A.; and Protopapadakis, E. 2018. Deep learning for computer vision: A brief review. Computational Intelligence and Neuroscience 2018.

Wang, R.; Fu, B.; Fu, G.; and Wang, M. 2017. Deep & Cross Network for Ad Click Predictions. In ADKDD@KDD.

Xiao, J.; Ye, H.; He, X.; Zhang, H.; Wu, F.; and Chua, T.-S. 2017. Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks. In Proceedings of the Twenty-Sixth International Joint Conference on Artificial Intelligence, 3119–3125. Melbourne, Australia: International Joint Conferences on Artificial Intelligence Organization. ISBN 978-0-9992411-0-3. doi:10.24963/ijcai.2017/435. URL https://www.ijcai.org/proceedings/2017/435.

Yang, Y.; Morillo, I. G.; and Hospedales, T. M. 2018. Deep neural decision trees. arXiv preprint arXiv:1806.06988.

Zhu, X.; and Ghahramani, Z. 2002. Learning from labeled and unlabeled data with label propagation.
A Appendix: Ablation Studies

We perform a number of ablation studies on various architectural choices and pre-training approaches for our TabTransformer. The first ablation study is on the choice of column embedding. The second and third ablation studies focus on the pre-training approach; specifically, they are on the replacement value k and on dynamic versus static replacement. For the pre-training approach, we use TabTransformer-RTD as our model; that is, the loss in pre-training is the RTD loss. For TabTransformer, the hidden (embedding) dimension, the number of layers, and the number of attention heads in the Transformer are set to 32, 6, and 8, respectively. The MLP layer sizes are set to {4 × l, 2 × l}, where l is the size of its input. To better present the results, we introduce an additional evaluation metric, the relative AUC. More precisely, for each dataset and cross-validation split, the relative AUC of a model is the relative change of its AUC against the mean AUC over all competing models.

Column Embedding. The first study is on the choice of column embedding: shared parameters c_φi across the embeddings of multiple classes in column i, for i ∈ {1, 2, ..., m}. In particular, we study the optimal dimension ℓ of c_φi. An alternative choice is to element-wise add the unique identifier c_φi and the feature-value specific embedding w_φij rather than concatenating them; in that case, the dimensions of both c_φi and w_φij are equal to the embedding dimension d. The goal of having a column embedding is to enable the model to distinguish the classes in one column from those in the other columns. A baseline approach is to not have any shared embedding. Results are presented in Table 5, where "Col Embed-Concat-1/X" indicates that the dimension ℓ is set to d/X. The relative AUC score is calculated over all the models that appear in the rows and columns of the table, which explains why negative scores appear in some of the entries. The results show that not having the shared column embedding performs worst, and that our concatenation column embedding gives a better average performance.

The replacement value k. The second ablation study is on the replacement value k in the pre-training approach. We run experiments for three different choices of k, {15, 30, 50}, on three different datasets, namely Adult, BankMarketing, and 1995 income. The TabTransformer is first pre-trained with a value of k on unlabeled data and then fine-tuned on labeled data. The number of labeled data points is set to 256. The final fine-tuning accuracy is not very sensitive to the value of k, as shown in Table 6. The pre-training curves of training and validation accuracy for the three different replacement values k are shown in Figure 6. Note that a constant prediction model would achieve 85% accuracy for the 15% replacement value.

Dynamic versus Static Replacement. The third ablation study is on dynamic versus static replacement in the pre-training approach. In dynamic replacement, we randomly replace feature values during pre-training over the epochs; that is, the replacement is different in each epoch. In static replacement, the random replacement is chosen once, and then the same replacement is used in all epochs. We combine this study with another ablation on a shared RTD binary classifier (predictor) versus different classifiers for different columns. The results in Table 7 show that our choice of dynamic replacement and un-shared RTD classifiers performs better than static replacement and shared RTD classifiers. Figure 7 shows the pre-training curves of training and validation accuracy for the three choices: dynamic replacement, static replacement, and static replacement with a shared RTD classifier.

B Appendix: Experiment and Model Details

In this section, we discuss the experiment and model details. First, we go through the experiment details and the hyper-parameter search space for HPO in Section B.1. Next, we discuss the feature engineering in Section B.2.

B.1 Experiment Details and Hyper-Parameters

Setup. All experiments were run on an Ubuntu Linux machine with 8 CPUs and 60GB memory, with all models using a single NVIDIA V100 Tensor Core GPU. For the competing models mentioned in the experiments, we re-implemented all of them for consistency of pre-processing. In cases where published results exist for a model, our tested results are close to the published records. The GBDT model is implemented using the LightGBM library (Ke et al. 2017). All the other models are implemented using the PyTorch library (Paszke et al. 2019). To reproduce our experiment results, the models' implementations and the exact values for all hyper-parameters can be found in the separate supplemental material, the Code and Data Appendix.

For each dataset, all of the cross-validation splits and the labeled and unlabeled training data are obtained with a fixed random seed, such that every model tested receives exactly the same training and testing conditions.

As all the datasets are binary classification tasks, the cross-entropy loss was used for both supervised and semi-supervised training (for pre-training, the problem is binary classification in RTD and multi-class classification in MLM). For all deep models, the AdamW optimizer (Loshchilov and Hutter 2017) was used to update the model parameters, and a constant learning rate was applied throughout each training job. All models used early stopping based on the performance on the validation set, and the early stopping patience (the number of epochs) was set to 15.

Hyper-Parameter Search Space. The hyper-parameters tuned for the GBDT model were the number of leaves in the trees with a search space {x ∈ Z | 5 ≤ x ≤ 50}, the minimum number of data points required to split a leaf in the trees with a search space {x ∈ Z | 1 ≤ x ≤ 100}, the boosting learning rate with a search space {x = 5 · 10^u, u ∈ U | −3 ≤ u ≤ −1}, and the number of trees used for boosting with a search space {x ∈ Z | 10 ≤ x ≤ 1000}.

For all of the deep models, the common hyper-parameters include the weight decay factor with a search space {x = 10^u, u ∈ U | −6 ≤ u ≤ −1}, the learning rate with a search space {x = 10^u, u ∈ U | −6 ≤ u ≤ −3}, the dropout probability with a search space {0, 0.1, 0.2, ..., 0.5}, and whether to one-hot encode categorical variables or train learnable embeddings.
Table 5: Performance of TabTransformer with no column embedding, concatenation column embedding, and addition column embedding. The evaluation metric is the mean ± standard deviation of relative AUCs (in percentage) over all 15 datasets. A larger value means better performance. The best model is in bold for each row.
# of Transformer Layers   No Col Embed    Col Embed-Concat-1/4   Col Embed-Concat-1/8   Col Embed-Add
1 -0.59 ± 0.33 -2.01 ± 1.33 -0.27 ± 0.21 -1.11 ± 0.77
2 -0.59 ± 0.22 -0.37 ± 0.20 -0.14 ± 0.19 0.34 ± 0.27
3 -0.37 ± 0.19 0.04 ± 0.18 -0.02 ± 0.21 0.21 ± 0.23
6 0.54 ± 0.22 0.53 ± 0.24 0.70 ± 0.17 0.25 ± 0.23
12 0.66 ± 0.21 1.05 ± 0.31 0.73 ± 0.58 0.42 ± 0.39
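To illustrate the two column-embedding variants compared in Table 5, here is a minimal sketch, assuming an embedding dimension d and a shared column identifier of dimension ℓ; the class and argument names are illustrative, not the paper's implementation.

```python
import torch
import torch.nn as nn

class ColumnEmbedding(nn.Module):
    """Embed value j of column i as either concat([c_phi_i, w_phi_ij])
    or the element-wise sum c_phi_i + w_phi_ij."""
    def __init__(self, cardinalities, d=32, col_dim=4, mode="concat"):
        super().__init__()
        self.mode = mode
        value_dim = d - col_dim if mode == "concat" else d
        self.value_emb = nn.ModuleList(
            [nn.Embedding(card, value_dim) for card in cardinalities])
        col_id_dim = col_dim if mode == "concat" else d
        # one shared identifier vector c_phi_i per column
        self.col_emb = nn.Parameter(torch.randn(len(cardinalities), col_id_dim))

    def forward(self, x_cat):                   # x_cat: (batch, n_columns) int64
        outs = []
        for i, emb in enumerate(self.value_emb):
            w = emb(x_cat[:, i])                # feature-value embedding w_phi_ij
            c = self.col_emb[i].expand(w.size(0), -1)
            outs.append(torch.cat([c, w], dim=1) if self.mode == "concat" else c + w)
        return torch.stack(outs, dim=1)         # (batch, n_columns, d)
```

With d = 32, the default col_dim = 4 corresponds to the ℓ = d/8 setting ("Col Embed-Concat-1/8") selected in this ablation.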
Table 6: Fine-tuning performance of TabTransformer-RTD for different pre-training replacement values k. The number of labeled data points is 256. The evaluation metrics are the mean ± standard deviation of (1) the AUC score over 5 cross-validation splits for each dataset (in percentage) and (2) the relative AUCs over the 3 datasets (in percentage). A larger value means better performance. The best model is in bold for each column.
Table 7: Fine-tuning performance of TabTransformer-RTD for dynamic replacement, static replacement, and static replacement with a shared classifier. The number of labeled data points is 256. The evaluation metrics are the mean ± standard deviation of (1) the AUC score over 5 cross-validation splits for each dataset (in percentage) and (2) the relative AUCs over the 3 datasets (in percentage). A larger value means better performance. The best model is in bold for each column.
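To make the RTD pre-training ablations of Tables 6 and 7 concrete, a minimal sketch of one dynamic-replacement step with un-shared per-column binary classifiers follows; the encoder and head modules are placeholders, and a replacement draw may occasionally coincide with the original value.

```python
import torch
import torch.nn.functional as F

def rtd_step(x_cat, cardinalities, encoder, rtd_heads, k=0.30):
    """One replaced-token-detection step: randomly swap a fraction k of
    values per column (re-sampled on every call, i.e. dynamic replacement)
    and train per-column binary classifiers to detect the swaps."""
    batch, m = x_cat.shape
    corrupted = x_cat.clone()
    is_replaced = torch.zeros(batch, m)
    for i in range(m):
        mask = torch.rand(batch) < k                      # positions to corrupt
        random_vals = torch.randint(0, cardinalities[i], (int(mask.sum()),))
        corrupted[mask, i] = random_vals
        is_replaced[:, i] = mask.float()
    h = encoder(corrupted)                                # (batch, m, d) contextual embeddings
    loss = 0.0
    for i, head in enumerate(rtd_heads):                  # un-shared classifier per column
        logits = head(h[:, i]).squeeze(-1)                # (batch,)
        loss = loss + F.binary_cross_entropy_with_logits(logits, is_replaced[:, i])
    return loss / m
```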
For the MLPs, all layers used SELU activations (Klambauer et al. 2017) followed by batch normalization, and the number of hidden layers was set to 2. The model-specific hyper-parameters tuned were the first hidden layer size with a search space {x = m · l, m ∈ Z | 1 ≤ m ≤ 8}, where l is the input size, and the second hidden layer size with a search space {x = m · l, m ∈ Z | 1 ≤ m ≤ 3}.

For TabTransformer, the hidden (embedding) dimension, the number of layers, and the number of attention heads in the Transformer were fixed to 32, 6, and 8, respectively, during the experiments. The MLP layer sizes were fixed to {4 × l, 2 × l}, where l was the size of its input. However, these parameters were optimally selected based on 50 rounds of HPO run on 5 datasets. The search spaces were the number of attention heads {2, 4, 8}, the hidden dimension {32, 64, 128, 256}, and the number of layers {1, 2, 3, 6, 12}. The search spaces of the first and second hidden layers in the MLP are exactly the same as those in the MLP model setting. The dimension ℓ of c_φi was chosen as d/8 based on the ablation study in Appendix A.

For Sparse MLP (Prune), the implementation was the same as the MLP, except that every k epochs during training the fraction p of weights with the smallest magnitude were permanently set to zero. The model-specific hyper-parameter tuned was the fraction p with a search space {x = 5 · 10^u, u ∈ U | −2 ≤ u ≤ −1}. The number of layers and layer sizes are exactly the same as in the MLP setting. The parameter k is set to 10.

For the TabNet model, we implemented it exactly as described in Arik and Pfister (2019), though we also added the option to use a softmax attention instead of a sparsemax attention, and did not include the sparsification term in the loss function. The model-specific hyper-parameters tuned were the number of layers with a search space {x ∈ Z | 3 ≤ x ≤ 10}, the hidden dimension with a search space {x ∈ Z | 8 ≤ x ≤ 128}, and the sparse coefficient with a search space {x = 10^u, u ∈ U | −6 ≤ u ≤ −2}.

For the VIB model, we implemented it as described in Alemi, Fischer, and Dillon (2018). We used a diagonal covariance, with 10 samples from the variational distribution during training and 20 during testing. The model-specific hyper-parameters tuned were the number of hidden layers and the layer sizes, with exactly the same search spaces as the MLP, and the number of mixture components in the mixture of Gaussians used in the marginal distribution, with a search space {x ∈ Z | 3 ≤ x ≤ 10}.

For MLP (DAE), the pre-training used swap noise as described in Jahrer (2018). The model-specific hyper-parameters were exactly the same as those of the MLP.

For Pseudo Labeling (Lee 2013), since this method was combined with MLP, TabTransformer, and GBDT, the model-specific hyper-parameters were exactly the same as those of the corresponding models mentioned above. The unsupervised coefficient α is chosen with α_f = 3, T1 = 30, T2 = 70.

For Entropy Regularization (Grandvalet and Bengio 2006), the setup is the same as for Pseudo Labeling. The additional model-specific hyper-parameter was the positive Lagrange multiplier λ with a search space {0.1, 0.2, ..., 0.9}.

Figure 7: The pre-training curves of training and validation accuracy for dynamic mask, static mask, and static mask with a shared predictor (classifier), for the datasets Adult, BankMarketing, and 1995 income.

B.2 Feature Engineering

For categorical variables, the processing options include whether to one-hot encode versus learn a parametric embedding, what embedding dimension to use, and how to apply dropout regularization (whether to drop vector elements or whole embeddings). In our experiments we found that learned embeddings nearly always improved performance, as long as the cardinality of the categorical variable is significantly less than the number of data points; otherwise the feature is merely a means for the model to overfit.

For scalar variables, the processing options include how to re-scale the variable (via quantiles, normalization, or log scaling) or whether to quantize the feature and treat it like a categorical variable. While we have not explored this idea fully, the best strategy is likely to use all the different types of encoding in parallel, turning each scalar feature into three re-scaled features and one categorical feature. Unlike learning embeddings for high-cardinality categorical features, adding potentially redundant encodings for scalar variables should not lead to overfitting, but can make the difference between a feature being useful or not.

For text variables, we simply encode the number of words and the number of characters in the text.
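A minimal sketch of the parallel scalar encodings suggested above, assuming a pandas Series; the bin count and the exact transforms are illustrative choices, not a prescribed recipe.

```python
import numpy as np
import pandas as pd

def encode_scalar(s: pd.Series, n_bins=10):
    """Expand one scalar column into three re-scaled features
    (quantile rank, z-score, log scale) plus one quantized categorical."""
    quantile = s.rank(pct=True)                         # quantile re-scaling in [0, 1]
    zscore = (s - s.mean()) / (s.std() + 1e-8)          # normalization
    logscale = np.log1p(s - s.min())                    # log scaling, shifted to be non-negative
    binned = pd.qcut(s, q=n_bins, duplicates="drop")    # quantize -> categorical
    return pd.DataFrame({
        f"{s.name}_quantile": quantile,
        f"{s.name}_zscore": zscore,
        f"{s.name}_log": logscale,
        f"{s.name}_bin": binned.cat.codes,              # integer category codes
    })
```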
C Appendix: Benchmark Dataset Information and Experiment Results

Table 8: Benchmark datasets. All datasets are binary classification tasks. Positive Class% is the fraction of data points that belong to the positive class.

Table 10: AUC scores for semi-supervised learning models on all datasets with 50 fine-tune data points. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.

Dataset N Datapoints N Features Positive Class% Best Model TabTransformer-RTD TabTransformer-MLM MLP (ER)
albert 425240 79 50.0 TabTransformer-MLM 0.644 ± 0.015 0.647 ± 0.019 0.612 ± 0.017
hcdr main 307511 120 8.1 MLP (DAE) 0.592 ± 0.047 0.596 ± 0.047 0.602 ± 0.033
dota2games 92650 117 52.7 TabTransformer-MLM 0.526 ± 0.009 0.538 ± 0.011 0.519 ± 0.007
jannis 83733 55 2.0 TabTransformer-RTD 0.684 ± 0.055 0.665 ± 0.056 0.621 ± 0.022
volkert 58310 181 1.0 TabTransformer-RTD 0.693 ± 0.046 0.689 ± 0.042 0.657 ± 0.028
bank marketing 45211 16 11.7 MLP (PL) 0.771 ± 0.046 0.735 ± 0.040 0.792 ± 0.039
adult 34190 25 85.4 MLP (DAE) 0.580 ± 0.012 0.613 ± 0.014 0.609 ± 0.005
1995 income 32561 14 24.1 TabTransformer-MLM 0.840 ± 0.029 0.862 ± 0.018 0.839 ± 0.034
htru2 17898 8 9.2 MLP (DAE) 0.956 ± 0.007 0.958 ± 0.009 0.969 ± 0.012
online shoppers 12330 17 15.5 MLP (DAE) 0.790 ± 0.013 0.780 ± 0.024 0.855 ± 0.019
shrutime 10000 11 20.4 TabTransformer-RTD 0.752 ± 0.019 0.741 ± 0.019 0.725 ± 0.032
fabert 8237 801 11.3 MLP (PL) 0.535 ± 0.027 0.525 ± 0.019 0.572 ± 0.019
blastchar 7043 20 26.5 TabTransformer-MLM 0.806 ± 0.018 0.822 ± 0.009 0.803 ± 0.021
philippine 5832 309 50.0 TabTransformer-RTD 0.739 ± 0.027 0.729 ± 0.035 0.722 ± 0.031
insurance co 5822 85 6.0 MLP (PL) 0.601 ± 0.056 0.573 ± 0.077 0.575 ± 0.063
sylvine 5124 20 50.0 MLP (PL) 0.872 ± 0.031 0.898 ± 0.030 0.930 ± 0.015
spambase 4601 57 39.4 MLP (ER) 0.949 ± 0.005 0.945 ± 0.011 0.957 ± 0.008
jasmine 2984 145 50.0 TabTransformer-MLM 0.821 ± 0.019 0.837 ± 0.019 0.830 ± 0.022
seismicbumps 2583 18 6.6 TabTransformer (ER) 0.740 ± 0.088 0.738 ± 0.068 0.712 ± 0.074
qsar bio 1055 41 33.7 MLP (DAE) 0.875 ± 0.028 0.869 ± 0.036 0.880 ± 0.022
Table 11: (Continued) AUC scores for semi-supervised learning models on all datasets with 50 fine-tune data points. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.
Dataset MLP (PL) TabTransformer (ER) TabTransformer (PL) MLP (DAE) GBDT (PL)
albert 0.607 ± 0.013 0.580 ± 0.017 0.587 ± 0.012 0.612 ± 0.014 0.547 ± 0.032
hcdr main 0.599 ± 0.038 0.581 ± 0.023 0.570 ± 0.031 0.620 ± 0.028 0.531 ± 0.024
dota2games 0.520 ± 0.006 0.516 ± 0.009 0.519 ± 0.008 0.516 ± 0.004 0.505 ± 0.008
jannis 0.623 ± 0.035 0.582 ± 0.035 0.604 ± 0.013 0.626 ± 0.023 0.519 ± 0.047
volkert 0.653 ± 0.035 0.635 ± 0.024 0.639 ± 0.040 0.629 ± 0.019 0.525 ± 0.018
bank marketing 0.805 ± 0.036 0.744 ± 0.063 0.767 ± 0.058 0.786 ± 0.055 0.688 ± 0.057
adult 0.605 ± 0.021 0.568 ± 0.012 0.582 ± 0.024 0.616 ± 0.010 0.519 ± 0.024
1995 income 0.819 ± 0.042 0.813 ± 0.045 0.822 ± 0.048 0.811 ± 0.042 0.685 ± 0.084
htru2 0.970 ± 0.012 0.955 ± 0.007 0.951 ± 0.009 0.973 ± 0.003 0.919 ± 0.021
online shoppers 0.848 ± 0.021 0.816 ± 0.036 0.818 ± 0.028 0.858 ± 0.019 0.818 ± 0.032
shrutime 0.715 ± 0.044 0.748 ± 0.035 0.739 ± 0.034 0.683 ± 0.055 0.651 ± 0.093
fabert 0.577 ± 0.027 0.504 ± 0.020 0.516 ± 0.020 0.552 ± 0.013 0.534 ± 0.016
blastchar 0.799 ± 0.025 0.799 ± 0.013 0.792 ± 0.025 0.817 ± 0.016 0.729 ± 0.053
philippine 0.725 ± 0.022 0.689 ± 0.046 0.703 ± 0.050 0.717 ± 0.022 0.628 ± 0.085
insurance co 0.601 ± 0.057 0.575 ± 0.066 0.592 ± 0.080 0.522 ± 0.052 0.560 ± 0.081
sylvine 0.939 ± 0.013 0.891 ± 0.022 0.904 ± 0.027 0.925 ± 0.010 0.914 ± 0.021
spambase 0.951 ± 0.010 0.947 ± 0.006 0.948 ± 0.006 0.949 ± 0.012 0.899 ± 0.039
jasmine 0.819 ± 0.021 0.825 ± 0.024 0.819 ± 0.018 0.812 ± 0.029 0.755 ± 0.016
seismicbumps 0.678 ± 0.106 0.745 ± 0.080 0.713 ± 0.090 0.724 ± 0.049 0.601 ± 0.071
qsar bio 0.875 ± 0.015 0.851 ± 0.041 0.835 ± 0.053 0.888 ± 0.022 0.804 ± 0.057
Table 12: AUC scores for semi-supervised learning models on all datasets with 200 fine-tune data points. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.
Dataset N Datapoints N Features Positive Class% Best Model TabTransformer-RTD TabTransformer-MLM MLP (ER)
albert 425240 79 50.0 TabTransformer-MLM 0.699 ± 0.011 0.701 ± 0.014 0.642 ± 0.020
hcdr main 307511 120 8.1 TabTransformer-MLM 0.655 ± 0.040 0.668 ± 0.028 0.639 ± 0.027
dota2games 92650 117 52.7 TabTransformer-MLM 0.536 ± 0.012 0.549 ± 0.008 0.527 ± 0.012
jannis 83733 55 2.0 TabTransformer-RTD 0.713 ± 0.037 0.692 ± 0.024 0.665 ± 0.024
volkert 58310 181 12.7 TabTransformer-RTD 0.753 ± 0.022 0.742 ± 0.023 0.696 ± 0.033
bank marketing 45211 16 11.7 MLP (PL) 0.854 ± 0.020 0.838 ± 0.010 0.860 ± 0.008
adult 34190 25 85.4 MLP (ER) 0.596 ± 0.023 0.614 ± 0.012 0.623 ± 0.017
1995 income 32561 14 24.1 TabTransformer-MLM 0.866 ± 0.014 0.875 ± 0.011 0.868 ± 0.007
htru2 17898 8 9.2 MLP (DAE) 0.961 ± 0.008 0.963 ± 0.009 0.974 ± 0.007
online shoppers 12330 17 15.5 MLP (ER) 0.834 ± 0.015 0.838 ± 0.024 0.876 ± 0.019
shrutime 10000 11 20.4 TabTransformer-RTD 0.805 ± 0.017 0.783 ± 0.024 0.773 ± 0.013
fabert 8237 801 11.3 MLP (ER) 0.556 ± 0.023 0.561 ± 0.028 0.600 ± 0.046
blastchar 7043 20 26.5 TabTransformer-MLM 0.831 ± 0.010 0.841 ± 0.014 0.829 ± 0.010
philippine 5832 309 50.0 TabTransformer-RTD 0.757 ± 0.017 0.754 ± 0.016 0.732 ± 0.024
insurance co 5822 85 6.0 TabTransformer (ER) 0.667 ± 0.062 0.640 ± 0.043 0.601 ± 0.059
sylvine 5124 20 50.0 MLP (PL) 0.939 ± 0.008 0.948 ± 0.006 0.957 ± 0.008
spambase 4601 57 39.4 MLP (ER) 0.957 ± 0.006 0.955 ± 0.010 0.968 ± 0.009
jasmine 2984 145 50.0 TabTransformer-RTD 0.843 ± 0.016 0.843 ± 0.028 0.831 ± 0.019
seismicbumps 2583 18 6.6 TabTransformer-RTD 0.738 ± 0.063 0.708 ± 0.083 0.694 ± 0.088
qsar bio 1055 41 33.7 TabTransformer-RTD 0.896 ± 0.018 0.889 ± 0.030 0.895 ± 0.026
Table 13: (Continued) AUC scores for semi-supervised learning models on all datasets with 200 fine-tune data points. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.
Dataset MLP (PL) TabTransformer (ER) TabTransformer (PL) MLP (DAE) GBDT (PL)
albert 0.638 ± 0.024 0.630 ± 0.025 0.630 ± 0.021 0.646 ± 0.023 0.628 ± 0.015
hcdr main 0.631 ± 0.019 0.611 ± 0.030 0.605 ± 0.021 0.636 ± 0.027 0.579 ± 0.039
dota2games 0.527 ± 0.014 0.528 ± 0.017 0.525 ± 0.011 0.528 ± 0.012 0.506 ± 0.008
jannis 0.667 ± 0.036 0.619 ± 0.024 0.637 ± 0.026 0.659 ± 0.020 0.525 ± 0.030
volkert 0.693 ± 0.028 0.694 ± 0.002 0.689 ± 0.015 0.672 ± 0.015 0.612 ± 0.042
bank marketing 0.866 ± 0.008 0.853 ± 0.016 0.858 ± 0.009 0.863 ± 0.009 0.802 ± 0.012
adult 0.616 ± 0.014 0.582 ± 0.026 0.584 ± 0.017 0.611 ± 0.027 0.572 ± 0.040
1995 income 0.869 ± 0.009 0.848 ± 0.024 0.852 ± 0.015 0.865 ± 0.011 0.822 ± 0.020
htru2 0.974 ± 0.007 0.955 ± 0.007 0.954 ± 0.007 0.974 ± 0.010 0.946 ± 0.022
online shoppers 0.873 ± 0.030 0.857 ± 0.014 0.853 ± 0.017 0.873 ± 0.021 0.846 ± 0.019
shrutime 0.774 ± 0.018 0.803 ± 0.022 0.803 ± 0.024 0.763 ± 0.018 0.750 ± 0.050
fabert 0.595 ± 0.048 0.530 ± 0.027 0.522 ± 0.024 0.580 ± 0.020 0.573 ± 0.026
blastchar 0.829 ± 0.011 0.823 ± 0.011 0.823 ± 0.011 0.832 ± 0.013 0.783 ± 0.017
philippine 0.733 ± 0.018 0.736 ± 0.018 0.739 ± 0.024 0.720 ± 0.020 0.729 ± 0.024
insurance co 0.616 ± 0.045 0.715 ± 0.038 0.680 ± 0.034 0.612 ± 0.024 0.630 ± 0.087
sylvine 0.961 ± 0.004 0.951 ± 0.009 0.950 ± 0.010 0.955 ± 0.009 0.957 ± 0.005
spambase 0.965 ± 0.008 0.962 ± 0.006 0.960 ± 0.008 0.964 ± 0.009 0.957 ± 0.013
jasmine 0.839 ± 0.013 0.824 ± 0.024 0.841 ± 0.016 0.842 ± 0.014 0.826 ± 0.013
seismicbumps 0.684 ± 0.071 0.723 ± 0.080 0.727 ± 0.081 0.673 ± 0.070 0.603 ± 0.023
qsar bio 0.892 ± 0.033 0.871 ± 0.036 0.876 ± 0.032 0.891 ± 0.018 0.855 ± 0.035
Table 14: AUC scores for semi-supervised learning models on all datasets with 500 fine-tune data points. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.
Dataset N Datapoints N Features Positive Class% Best Model TabTransformer-RTD TabTransformer-MLM MLP (ER)
albert 425240 79 50.0 TabTransformer-RTD 0.711 ± 0.004 0.707 ± 0.006 0.666 ± 0.008
hcdr main 307511 120 8.1 TabTransformer-MLM 0.690 ± 0.038 0.698 ± 0.033 0.653 ± 0.019
dota2games 92650 117 52.7 TabTransformer-MLM 0.548 ± 0.008 0.557 ± 0.003 0.543 ± 0.008
jannis 83733 55 2.0 TabTransformer-RTD 0.747 ± 0.015 0.720 ± 0.018 0.707 ± 0.036
volkert 58310 181 12.7 TabTransformer-RTD 0.771 ± 0.016 0.760 ± 0.015 0.723 ± 0.016
bank marketing 45211 16 11.7 TabTransformer-RTD 0.879 ± 0.012 0.866 ± 0.016 0.869 ± 0.012
adult 34190 25 85.4 MLP (PL) 0.625 ± 0.011 0.647 ± 0.008 0.644 ± 0.015
1995 income 32561 14 24.1 MLP (DAE) 0.874 ± 0.008 0.880 ± 0.007 0.878 ± 0.002
htru2 17898 8 9.2 MLP (DAE) 0.964 ± 0.009 0.966 ± 0.009 0.973 ± 0.010
online shoppers 12330 17 15.5 MLP (ER) 0.859 ± 0.009 0.861 ± 0.014 0.888 ± 0.012
shrutime 10000 11 20.4 TabTransformer-RTD 0.831 ± 0.017 0.815 ± 0.004 0.793 ± 0.017
fabert 8237 801 11.3 MLP (ER) 0.618 ± 0.014 0.609 ± 0.019 0.621 ± 0.032
blastchar 7043 20 26.5 TabTransformer-RTD 0.840 ± 0.013 0.839 ± 0.015 0.829 ± 0.013
philippine 5832 309 50.0 TabTransformer-MLM 0.769 ± 0.028 0.772 ± 0.017 0.734 ± 0.024
insurance co 5822 85 6.0 TabTransformer (ER) 0.688 ± 0.039 0.642 ± 0.029 0.659 ± 0.023
sylvine 5124 20 50.0 MLP (PL) 0.955 ± 0.007 0.959 ± 0.006 0.967 ± 0.003
spambase 4601 57 39.4 MLP (ER) 0.966 ± 0.007 0.968 ± 0.008 0.975 ± 0.004
jasmine 2984 145 50.0 TabTransformer-RTD 0.847 ± 0.016 0.844 ± 0.011 0.837 ± 0.019
seismicbumps 2583 18 6.6 TabTransformer-RTD 0.758 ± 0.081 0.729 ± 0.069 0.682 ± 0.123
qsar bio 1055 41 33.7 MLP (DAE) 0.909 ± 0.024 0.889 ± 0.038 0.918 ± 0.023
Table 15: (Continued) AUC scores for semi-supervised learning models on all datasets with 500 fine-tune data points. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.
Dataset MLP (PL) TabTransformer (ER) TabTransformer (PL) MLP (DAE) GBDT (PL)
albert 0.662 ± 0.007 0.664 ± 0.011 0.643 ± 0.029 0.666 ± 0.006 0.653 ± 0.011
hcdr main 0.645 ± 0.022 0.623 ± 0.036 0.636 ± 0.031 0.657 ± 0.033 0.607 ± 0.035
dota2games 0.544 ± 0.010 0.538 ± 0.009 0.541 ± 0.010 0.542 ± 0.012 0.505 ± 0.005
jannis 0.698 ± 0.033 0.662 ± 0.007 0.660 ± 0.024 0.693 ± 0.024 0.521 ± 0.045
volkert 0.722 ± 0.012 0.712 ± 0.016 0.705 ± 0.021 0.712 ± 0.016 0.705 ± 0.016
bank marketing 0.876 ± 0.017 0.863 ± 0.008 0.868 ± 0.016 0.874 ± 0.012 0.838 ± 0.019
adult 0.651 ± 0.012 0.618 ± 0.023 0.618 ± 0.021 0.654 ± 0.016 0.647 ± 0.030
1995 income 0.880 ± 0.003 0.868 ± 0.008 0.869 ± 0.007 0.882 ± 0.001 0.839 ± 0.013
htru2 0.974 ± 0.007 0.960 ± 0.010 0.960 ± 0.008 0.976 ± 0.006 0.949 ± 0.007
online shoppers 0.885 ± 0.021 0.861 ± 0.011 0.860 ± 0.013 0.885 ± 0.019 0.865 ± 0.011
shrutime 0.800 ± 0.015 0.825 ± 0.013 0.822 ± 0.016 0.804 ± 0.015 0.788 ± 0.019
fabert 0.596 ± 0.046 0.573 ± 0.048 0.578 ± 0.033 0.617 ± 0.042 0.585 ± 0.025
blastchar 0.833 ± 0.013 0.834 ± 0.013 0.832 ± 0.011 0.833 ± 0.012 0.795 ± 0.021
philippine 0.740 ± 0.023 0.746 ± 0.020 0.735 ± 0.015 0.739 ± 0.017 0.749 ± 0.026
insurance co 0.646 ± 0.048 0.710 ± 0.040 0.666 ± 0.060 0.612 ± 0.013 0.672 ± 0.037
sylvine 0.968 ± 0.003 0.958 ± 0.005 0.958 ± 0.003 0.967 ± 0.003 0.967 ± 0.006
spambase 0.973 ± 0.005 0.968 ± 0.007 0.967 ± 0.006 0.972 ± 0.006 0.972 ± 0.005
jasmine 0.833 ± 0.009 0.833 ± 0.021 0.838 ± 0.018 0.842 ± 0.011 0.838 ± 0.022
seismicbumps 0.677 ± 0.103 0.687 ± 0.100 0.735 ± 0.081 0.696 ± 0.112 0.666 ± 0.063
qsar bio 0.914 ± 0.032 0.894 ± 0.036 0.895 ± 0.035 0.925 ± 0.034 0.908 ± 0.024
Table 16: AUC scores for supervised learning models on all datasets. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.
Dataset N Datapoints N Features Positive Class% Best Model Logistic Regression GBDT
albert 425240 79 50.0 GBDT 0.726 ± 0.001 0.763 ± 0.001
hcdr main 307511 120 8.1 GBDT 0.747 ± 0.004 0.756 ± 0.004
dota2games 92650 117 52.7 Logistic Regression 0.634 ± 0.003 0.621 ± 0.004
bank marketing 45211 16 11.7 TabTransformer 0.911 ± 0.005 0.933 ± 0.003
adult 34190 25 85.4 GBDT 0.721 ± 0.010 0.756 ± 0.011
1995 income 32561 14 24.1 TabTransformer 0.899 ± 0.002 0.906 ± 0.002
online shoppers 12330 17 15.5 GBDT 0.908 ± 0.015 0.930 ± 0.008
shrutime 10000 11 20.4 GBDT 0.828 ± 0.013 0.859 ± 0.009
blastchar 7043 20 26.5 GBDT 0.844 ± 0.010 0.847 ± 0.016
philippine 5832 309 50.0 TabTransformer 0.725 ± 0.022 0.812 ± 0.013
insurance co 5822 85 6.0 TabTransformer 0.736 ± 0.023 0.732 ± 0.022
spambase 4601 57 39.4 GBDT 0.947 ± 0.008 0.987 ± 0.005
jasmine 2984 145 50.0 GBDT 0.846 ± 0.017 0.862 ± 0.008
seismicbumps 2583 18 6.6 GBDT 0.749 ± 0.068 0.756 ± 0.084
qsar bio 1055 41 33.7 TabTransformer 0.847 ± 0.037 0.913 ± 0.031
Table 17: (Continued) AUC scores for supervised learning models on all datasets. Values are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.