
TabTransformer: Tabular Data Modeling Using Contextual Embeddings

Xin Huang,1 Ashish Khetan,1 Milan Cvitkovic,2 Zohar Karnin1
1 Amazon AWS
2 PostEra
[email protected], [email protected], [email protected], [email protected]

arXiv:2012.06678v1 [cs.LG] 11 Dec 2020

Abstract

We propose TabTransformer, a novel deep tabular data modeling architecture for supervised and semi-supervised learning. The TabTransformer is built upon self-attention based Transformers. The Transformer layers transform the embeddings of categorical features into robust contextual embeddings to achieve higher prediction accuracy. Through extensive experiments on fifteen publicly available datasets, we show that the TabTransformer outperforms the state-of-the-art deep learning methods for tabular data by at least 1.0% on mean AUC, and matches the performance of tree-based ensemble models. Furthermore, we demonstrate that the contextual embeddings learned from TabTransformer are highly robust against both missing and noisy data features, and provide better interpretability. Lastly, for the semi-supervised setting we develop an unsupervised pre-training procedure to learn data-driven contextual embeddings, resulting in an average 2.1% AUC lift over the state-of-the-art methods.

1 Introduction

Tabular data is the most common data type in many real-world applications such as recommender systems (Cheng et al. 2016), online advertising (Song et al. 2019), and portfolio optimization (Ban, El Karoui, and Lim 2018). Many machine learning competitions, such as those on Kaggle and at KDD Cup, are primarily designed to solve problems in the tabular domain.

The state of the art for modeling tabular data is tree-based ensemble methods such as gradient boosted decision trees (GBDT) (Chen and Guestrin 2016; Prokhorenkova et al. 2018). This is in contrast to modeling image and text data, where all the existing competitive models are based on deep learning (Sandler et al. 2018; Devlin et al. 2019). Tree-based ensemble models achieve competitive prediction accuracy, are fast to train, and are easy to interpret. These benefits make them highly favourable among machine learning practitioners. However, tree-based models have several limitations in comparison to deep learning models. (a) They are not suitable for continual training from streaming data, and do not allow efficient end-to-end learning of image/text encoders in the presence of multi-modality along with tabular data. (b) In their basic form they are not suitable for state-of-the-art semi-supervised learning methods, because the basic decision tree learner does not produce reliable probability estimates for its predictions (Tanha, Someren, and Afsarmanesh 2017). (c) The state-of-the-art deep learning techniques for handling missing and noisy data features (Devlin et al. 2019) do not apply to them. Moreover, the robustness of tree-based models has not been studied much in the literature.

A classical and popular model that is trained using gradient descent, and hence allows end-to-end learning of image/text encoders, is the multi-layer perceptron (MLP). MLPs usually learn parametric embeddings to encode categorical data features. But due to their shallow architecture and context-free embeddings, they have the following limitations: (a) neither the model nor the learned embeddings are interpretable; (b) they are not robust against missing and noisy data (Section 3.2); (c) for semi-supervised learning, they do not achieve competitive performance (Section 3.4). Most importantly, MLPs do not match the performance of tree-based models such as GBDT on most of the datasets (Arik and Pfister 2019). To bridge this performance gap between MLPs and GBDT, researchers have proposed various deep learning models (Song et al. 2019; Cheng et al. 2016; Arik and Pfister 2019; Guo et al. 2018). Although these deep learning models achieve comparable prediction accuracy, they do not address all the limitations of GBDT and MLPs. Furthermore, their comparisons are done in a limited setting of a handful of datasets. In particular, in Section 3.3 we show that when compared to standard GBDT on a large collection of datasets, GBDT performs significantly better than these recent models.

In this paper, we propose TabTransformer to address the limitations of MLPs and existing deep learning models, while bridging the performance gap between MLP and GBDT. We establish the performance gain of TabTransformer through extensive experiments on fifteen publicly available datasets.

The TabTransformer is built upon Transformers (Vaswani et al. 2017) to learn efficient contextual embeddings of categorical features. Unlike in the tabular domain, the application of embeddings has been studied extensively in NLP, where the use of embeddings to encode words in a dense, low-dimensional space is prevalent. Beginning with the context-free word embeddings of Word2Vec (Rong 2014) and continuing to the contextual word-token embeddings of BERT (Devlin et al. 2019), embeddings have been widely studied and applied in practice in NLP. In comparison to context-free embeddings, contextual embedding based models (Mikolov et al. 2011; Huang, Xu, and Yu 2015; Devlin et al. 2019) have achieved tremendous success. In particular, self-attention based Transformers (Vaswani et al. 2017) have become a standard component of NLP models for achieving state-of-the-art performance. The effectiveness and interpretability of the contextual embeddings generated by Transformers have also been well studied (Coenen et al. 2019; Brunner et al. 2019).
Motivated by the successful applications of Transformers in NLP, we adapt them to the tabular domain. In particular, TabTransformer applies a sequence of multi-head attention-based Transformer layers on parametric embeddings to transform them into contextual embeddings, bridging the performance gap between baseline MLP and GBDT models. We investigate the effectiveness and interpretability of the resulting contextual embeddings generated by the Transformers. We find that highly correlated features (including feature pairs in the same column and across columns) result in embedding vectors that are close together in Euclidean distance, whereas no such pattern exists in the context-free embeddings learned by a baseline MLP model. We also study the robustness of the TabTransformer against random missing and noisy data. The contextual embeddings make it highly robust in comparison to MLPs.

Furthermore, many existing deep learning models for tabular data are designed for the supervised learning scenario, but few target semi-supervised learning (SSL). Unfortunately, the state-of-the-art SSL models developed in computer vision (Voulodimos et al. 2018; Kendall and Gal 2017) and NLP (Vaswani et al. 2017; Devlin et al. 2019) cannot be easily extended to the tabular domain. Motivated by these challenges, we exploit pre-training methodologies from language models and propose a semi-supervised learning approach for pre-training the Transformers of our TabTransformer model using unlabeled data.

One of the key benefits of our proposed method for semi-supervised learning is the two independent training phases: a costly pre-training phase on unlabeled data and a lightweight fine-tuning phase on labeled data. This differs from many state-of-the-art semi-supervised methods (Chapelle, Scholkopf, and Zien 2009; Oliver et al. 2018; Stretcu et al. 2019) that require a single training job including both the labeled and unlabeled data. The separated training procedure benefits the scenario where the model needs to be pre-trained once but fine-tuned multiple times for multiple target variables. This scenario is in fact quite common in industrial settings, as companies tend to have one large dataset (e.g., describing customers/products) and are interested in applying multiple analyses to this data. To summarize, we provide the following contributions:

1. We propose TabTransformer, an architecture that provides and exploits contextual embeddings of categorical features. We provide extensive empirical evidence showing TabTransformer is superior to both a baseline MLP and recent deep networks for tabular data while matching the performance of tree-based ensemble models (GBDT).
2. We investigate the resulting contextual embeddings and highlight their interpretability, contrasted with the parametric context-free embeddings achieved by existing art.
3. We demonstrate the robustness of TabTransformer against noisy and missing data.
4. We provide and extensively study a two-phase pre-train then fine-tune procedure for tabular data, beating the state-of-the-art performance of semi-supervised learning methods.

2 The TabTransformer

The TabTransformer architecture comprises a column embedding layer, a stack of N Transformer layers, and a multi-layer perceptron. Each Transformer layer (Vaswani et al. 2017) consists of a multi-head self-attention layer followed by a position-wise feed-forward layer. The architecture of TabTransformer is shown in Figure 1.

Figure 1: The architecture of TabTransformer.

Let (x, y) denote a feature-target pair, where x ≡ {xcat, xcont}. Here xcat denotes all the categorical features and xcont ∈ R^c denotes all of the c continuous features. Let xcat ≡ {x1, x2, ..., xm}, with each xi being a categorical feature, for i ∈ {1, ..., m}.

We embed each categorical feature xi into a parametric embedding of dimension d using column embedding, which is explained in detail below. Let eφi(xi) ∈ R^d for i ∈ {1, ..., m} be the embedding of feature xi, and Eφ(xcat) = {eφ1(x1), ..., eφm(xm)} be the set of embeddings for all the categorical features.
Next, these parametric embeddings Eφ(xcat) are input to the first Transformer layer. The output of the first Transformer layer is input to the second Transformer layer, and so forth. Each parametric embedding is transformed into a contextual embedding when output from the top Transformer layer, through successive aggregation of context from the other embeddings. We denote the sequence of Transformer layers as a function fθ. The function fθ operates on the parametric embeddings {eφ1(x1), ..., eφm(xm)} and returns the corresponding contextual embeddings {h1, ..., hm}, where hi ∈ R^d for i ∈ {1, ..., m}.

The contextual embeddings {h1, ..., hm} are concatenated along with the continuous features xcont to form a vector of dimension (d × m + c). This vector is input to an MLP, denoted gψ, to predict the target y. Let H be the cross-entropy loss for classification tasks and the mean squared error for regression tasks. We minimize the following loss function L(x, y) to learn all the TabTransformer parameters end-to-end by first-order gradient methods. The TabTransformer parameters include φ for the column embedding, θ for the Transformer layers, and ψ for the top MLP layer.

L(x, y) ≡ H(gψ(fθ(Eφ(xcat)), xcont), y).    (1)

Below, we explain the Transformer layers and column embedding.
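To make the data flow of Equation (1) concrete, below is a minimal PyTorch sketch of the forward pass, using the stock nn.TransformerEncoder as a stand-in for fθ. The class name, the toy feature counts, and the exact layer composition are illustrative assumptions, not the authors' released implementation.

```python
import torch
import torch.nn as nn

class TabTransformerSketch(nn.Module):
    """Sketch of Eq. (1): y_hat = g_psi(f_theta(E_phi(x_cat)), x_cont)."""

    def __init__(self, cardinalities, num_cont, d=32, n_layers=6, n_heads=8):
        super().__init__()
        # E_phi: one embedding table per categorical column, with one extra
        # row per column reserved for missing values (see "Column embedding").
        self.embeds = nn.ModuleList(
            [nn.Embedding(card + 1, d) for card in cardinalities]
        )
        # f_theta: a stack of Transformer layers over the m column embeddings.
        layer = nn.TransformerEncoderLayer(
            d_model=d, nhead=n_heads, dim_feedforward=4 * d, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        # g_psi: an MLP on [contextual embeddings, continuous features],
        # with hidden sizes {4 x l, 2 x l} as in the experimental setup.
        l = d * len(cardinalities) + num_cont
        self.mlp = nn.Sequential(
            nn.Linear(l, 4 * l), nn.ReLU(),
            nn.Linear(4 * l, 2 * l), nn.ReLU(),
            nn.Linear(2 * l, 1),
        )

    def forward(self, x_cat, x_cont):
        # x_cat: (batch, m) integer codes; x_cont: (batch, c) floats.
        e = torch.stack([emb(x_cat[:, i]) for i, emb in enumerate(self.embeds)], dim=1)
        h = self.transformer(e)                       # contextual embeddings h_1..h_m
        z = torch.cat([h.flatten(1), x_cont], dim=1)  # vector of dimension d*m + c
        return self.mlp(z).squeeze(-1)                # logit for binary classification

# Toy usage: two categorical columns (5 and 3 classes) and 4 continuous features.
model = TabTransformerSketch(cardinalities=[5, 3], num_cont=4)
logits = model(torch.randint(0, 3, (8, 2)), torch.randn(8, 4))
loss = nn.BCEWithLogitsLoss()(logits, torch.randint(0, 2, (8,)).float())
```

In the full model, the per-column tables would additionally carry the shared column identifier described under "Column embedding" below.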
Transformer. A Transformer (Vaswani et al. 2017) consists of a multi-head self-attention layer followed by a position-wise feed-forward layer, with element-wise addition and layer normalization applied after each layer. A self-attention layer comprises three parametric matrices: Key, Query, and Value. Each input embedding is projected onto these matrices to generate its key, query, and value vectors. Formally, let K ∈ R^(m×k), Q ∈ R^(m×k), and V ∈ R^(m×v) be the matrices comprising the key, query, and value vectors of all the embeddings, respectively, where m is the number of embeddings input to the Transformer, and k and v are the dimensions of the key and value vectors. Every input embedding attends to all other embeddings through an attention head, which is computed as follows:

Attention(K, Q, V) = A · V,    (2)

where A = softmax(QK^T / √k). For each embedding, the attention matrix A ∈ R^(m×m) calculates how much it attends to the other embeddings, thus transforming the embedding into a contextual one. The output of the attention head, of dimension v, is projected back to the embedding dimension d through a fully connected layer, which in turn is passed through two position-wise feed-forward layers. The first layer expands the embedding to four times its size and the second layer projects it back to its original size.
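As a quick, self-contained check of Equation (2), a single attention head can be written directly; the shapes below (m embeddings, key/value dimensions k and v) are toy values chosen for illustration.

```python
import math
import torch

def attention_head(K, Q, V):
    # Eq. (2): A = softmax(Q K^T / sqrt(k)),  Attention(K, Q, V) = A . V
    k = Q.shape[-1]                              # key/query dimension
    A = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(k), dim=-1)  # (m, m)
    return A @ V                                 # one v-dimensional output per embedding

m, k, v = 6, 16, 16                              # toy sizes
K, Q, V = torch.randn(m, k), torch.randn(m, k), torch.randn(m, v)
out = attention_head(K, Q, V)                    # shape (m, v)
```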
Column embedding. For each categorical feature (column) i, we have an embedding lookup table eφi(.), for i ∈ {1, 2, ..., m}. For the ith feature with di classes, the embedding table eφi(.) has (di + 1) embeddings, where the additional embedding corresponds to a missing value. The embedding for the encoded value xi = j ∈ [0, 1, 2, ..., di] is eφi(j) = [cφi, wφij], where cφi ∈ R^ℓ and wφij ∈ R^(d−ℓ). The dimension ℓ of cφi is a hyper-parameter. The unique identifier cφi ∈ R^ℓ distinguishes the classes in column i from those in the other columns.

The use of a unique identifier is new and is particularly designed for tabular data. In language modeling, by contrast, embeddings are element-wise added to the positional encoding of the word in the sentence. Since there is no ordering of the features in tabular data, we do not use positional encodings. An ablation study on different embedding strategies is given in Appendix A. The strategies include different choices of ℓ and d, as well as element-wise adding the unique identifier and the feature-value specific embeddings rather than concatenating them.
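The following is a minimal sketch of this column embedding, concatenating the shared per-column identifier cφi with the per-value embedding wφij. The class name is hypothetical, and the default ℓ = d/8 is borrowed from the ablation in Appendix A.

```python
import torch
import torch.nn as nn

class ColumnEmbedding(nn.Module):
    """e_phi_i(j) = [c_phi_i, w_phi_ij]: a shared column identifier of dimension l
    concatenated with a per-value embedding of dimension d - l. The extra
    (d_i + 1)-th row of each table is reserved for missing values."""

    def __init__(self, cardinalities, d=32, l=4):
        super().__init__()
        self.col_id = nn.Parameter(torch.randn(len(cardinalities), l))       # c_phi_i
        self.value_embeds = nn.ModuleList(
            [nn.Embedding(card + 1, d - l) for card in cardinalities]        # w_phi_ij
        )

    def forward(self, x_cat):                        # x_cat: (batch, m) integer codes
        cols = []
        for i, emb in enumerate(self.value_embeds):
            w = emb(x_cat[:, i])                                 # (batch, d - l)
            c = self.col_id[i].expand(w.shape[0], -1)            # (batch, l), shared
            cols.append(torch.cat([c, w], dim=-1))               # (batch, d)
        return torch.stack(cols, dim=1)                          # (batch, m, d)

emb = ColumnEmbedding(cardinalities=[5, 3])
print(emb(torch.tensor([[0, 2], [5, 3]])).shape)     # index d_i encodes "missing"
```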
Pre-training the Embeddings. The contextual embeddings explained above are learned in end-to-end supervised training using labeled examples. For the scenario where there are a few labeled examples and a large number of unlabeled examples, we introduce a pre-training procedure to train the Transformer layers using unlabeled data. This is followed by fine-tuning of the pre-trained Transformer layers along with the top MLP layer using the labeled data. For fine-tuning, we use the supervised loss defined in Equation (1).

We explore two different types of pre-training procedures: masked language modeling (MLM) (Devlin et al. 2019) and replaced token detection (RTD) (Clark et al. 2020). Given an input xcat = {x1, x2, ..., xm}, MLM randomly selects k% of the features from index 1 to m and masks them as missing. The Transformer layers along with the column embeddings are trained by minimizing the cross-entropy loss of a multi-class classifier that tries to predict the original values of the masked features from the contextual embeddings output by the top-layer Transformer.

Instead of masking features, RTD replaces the original feature value with a random value of that feature. Here, the loss is minimized for a binary classifier that tries to predict whether or not the feature has been replaced. The RTD procedure as proposed in (Clark et al. 2020) uses an auxiliary generator for sampling the subset of features that a feature should be replaced with. The reason they use an auxiliary encoder network as the generator is that there are tens of thousands of tokens in language data, and a uniformly random token is too easy to detect. In contrast, in our setting (a) the number of classes within each categorical feature is typically limited; (b) a different binary classifier is defined for each column rather than a shared one, as each column has its own embedding lookup table. We name the two pre-training methods TabTransformer-MLM and TabTransformer-RTD. In our experiments, the replacement value k is set to 30. An ablation study on k is given in Appendix A.
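Below is a hedged sketch of one RTD pre-training step as described above: in-column random replacement of roughly k% of the values, followed by one binary classification head per column on the contextual embeddings. The encoder stand-in and helper names are hypothetical; the actual training loop may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def rtd_pretrain_step(encoder, heads, x_cat, cardinalities, k=0.30):
    """One RTD step: corrupt ~k% of the cells, predict which cells were replaced."""
    replace = torch.rand(x_cat.shape) < k                        # cells to replace
    random_vals = torch.stack(
        [torch.randint(0, card, (x_cat.shape[0],)) for card in cardinalities], dim=1
    )                                                            # a draw may repeat the original value
    corrupted = torch.where(replace, random_vals, x_cat)
    h = encoder(corrupted)                                       # (batch, m, d) contextual embeddings
    loss = 0.0
    for i, head in enumerate(heads):                             # un-shared, per-column classifiers
        logits = head(h[:, i, :]).squeeze(-1)
        loss = loss + F.binary_cross_entropy_with_logits(logits, replace[:, i].float())
    return loss / len(heads)

# Toy usage with a stand-in encoder (a plain embedding lookup, no attention):
cards = [5, 3]
table = nn.Embedding(sum(c + 1 for c in cards), 32)
offsets = torch.tensor([0, cards[0] + 1])
encoder = lambda x: table(x + offsets)                           # (batch, m, 32)
heads = nn.ModuleList([nn.Linear(32, 1) for _ in cards])
x = torch.stack([torch.randint(0, c, (16,)) for c in cards], dim=1)
print(rtd_pretrain_step(encoder, heads, x, cards))
```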
3 Experiments

Data. We evaluate TabTransformer and baseline models on 15 publicly available binary classification datasets from the UCI repository (Dua and Graff 2017), the AutoML Challenge (Guyon et al. 2019), and Kaggle (Kaggle, Inc. 2017), for both supervised and semi-supervised learning. Each dataset is divided into five cross-validation splits. The training/validation/testing proportions of the data for each split are 65/15/20%. The number of categorical features across datasets ranges from 2 to 136. In the semi-supervised experiments, for each dataset and split, the first p observations in the training data are marked as the labeled data and the remaining training data as the unlabeled set. The value of p is chosen as 50, 200, and 500, corresponding to 3 different scenarios. In the supervised experiments, each training dataset is fully labeled. Summary statistics of all the datasets are provided in Tables 8 and 9 in Appendix C.

Setup. For the TabTransformer, the hidden (embedding) dimension, the number of layers, and the number of attention heads are fixed to 32, 6, and 8 respectively. The MLP layer sizes are set to {4 × l, 2 × l}, where l is the size of its input. For hyperparameter optimization (HPO), each model is given 20 HPO rounds for each cross-validation split. As the evaluation metric, we use the area under the ROC curve (AUC) (Bradley 1997). Note that pre-training is only applied in the semi-supervised scenario. We do not find much benefit in using it when the entire data is labeled. Its benefit is evident when there are a large number of unlabeled examples and a few labeled examples, since in this scenario the pre-training provides a representation of the data that could not have been learned based only on the labeled examples.

The experiment section is organized as follows. In Section 3.1, we first demonstrate the effectiveness of the attention-based Transformer by comparing our model with one without the Transformer layers (equivalently, an MLP model). In Section 3.2, we illustrate the robustness of TabTransformer against noisy and missing data. Finally, extensive evaluations of various methods are conducted in Section 3.3 for supervised learning and in Section 3.4 for semi-supervised learning.

Table 1: Comparison between TabTransformer and the baseline MLP. The evaluation metric is AUC in percentage.

Dataset            Baseline MLP    TabTransformer    Gain (%)
albert             74.0            75.7               1.7
1995 income        90.5            90.6               0.1
dota2games         63.1            63.3               0.2
hcdr main          74.3            75.1               0.8
adult              72.5            73.7               1.2
bank marketing     92.9            93.4               0.5
blastchar          83.9            83.5              -0.4
insurance co       69.7            74.4               4.7
jasmine            85.1            85.3               0.2
online shoppers    91.9            92.7               0.8
philippine         82.1            83.4               1.3
qsar bio           91.0            91.8               0.8
seismicbumps       73.5            75.1               1.6
shrutime           84.6            85.6               1.0
spambase           98.4            98.5               0.1
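A small sketch of the data protocol above — a 65/15/20% train/validation/test split, with the first p training observations kept as the labeled subset and the rest treated as unlabeled. The helper name and the seeding are our own illustrative choices.

```python
import numpy as np

def make_split(n, p_labeled, seed=0):
    """One cross-validation split: 65/15/20% train/val/test, then the first
    p_labeled training rows as labeled data and the remainder as unlabeled."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)
    n_train, n_val = int(0.65 * n), int(0.15 * n)
    train, val, test = np.split(idx, [n_train, n_train + n_val])
    return {"labeled": train[:p_labeled], "unlabeled": train[p_labeled:],
            "val": val, "test": test}

split = make_split(n=10_000, p_labeled=200)      # p in {50, 200, 500} in the experiments
print({name: len(rows) for name, rows in split.items()})
```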
3.1 The Effectiveness of the Transformer Layers

First, a comparison between TabTransformer and the baseline MLP is conducted in the supervised learning scenario. We remove the Transformer layers fθ from the architecture, fix the rest of the components, and compare the result with the original TabTransformer. The model without the attention-based Transformer layers is equivalent to an MLP. The dimension d of the embeddings for categorical features is set to 32 for both models. The comparison results over the 15 datasets are presented in Table 1. The TabTransformer with the Transformer layers outperforms the baseline MLP on 14 out of 15 datasets, with an average 1.0% gain in AUC.

Next, we take contextual embeddings from different layers of the Transformer and compute a t-SNE plot (Maaten and Hinton 2008) to visualize their similarity in function space. More precisely, for each dataset we take its test data, pass the categorical features into a trained TabTransformer, and extract all contextual embeddings (across all columns) from a certain layer of the Transformer. The t-SNE algorithm is then used to reduce each embedding to a 2D point in the t-SNE plot. Figure 2 (left) shows the 2D visualization of embeddings from the last layer of the Transformer for the dataset bank marketing. Each marker in the plot represents an average of 2D points over the test data points for a certain class. We can see that semantically similar classes are close to each other and form clusters in the embedding space. Each cluster is annotated by a set of labels. For example, we find that all of the client-based features (color markers) such as job, education level, and marital status stay close in the center, while non-client based features (gray markers) such as month (last contact month of the year) and day (last contact day of the week) lie outside the central area; in the bottom cluster, the embedding of owning a housing loan stays close to that of being in default; in the left cluster, the embeddings of being a student, having marital status single, not having a housing loan, and having tertiary education group together; and in the right cluster, education levels are closely associated with occupation types (Torpey and Watson 2014). In Figure 2, the center and right plots are t-SNE plots of the embeddings before being passed through the Transformer and of the context-free embeddings from the MLP, respectively. For the embeddings before being passed into the Transformer, the plot starts to distinguish the non-client based features (gray markers) from the client-based features (color markers). For the embeddings from the MLP, we do not observe such a pattern, and many categorical features which are not semantically similar are grouped together, as indicated by the annotations in the plot.

Figure 2: t-SNE plots of learned embeddings for categorical features on the dataset BankMarketing. Left: TabTransformer — the embeddings generated from the last layer of the attention-based Transformer. Center: TabTransformer — the embeddings before being passed into the attention-based Transformer. Right: the embeddings learned from the MLP.

In addition, to demonstrate the effectiveness of the Transformer layers, on the test data we take all of the contextual embeddings from each Transformer layer of a trained TabTransformer, use the embeddings from each layer along with the continuous variables as features, and separately fit a linear model with target y. Since all of the experimental datasets are for binary classification, the linear model is logistic regression. The motivation for this evaluation is to use the success of a simple linear model as a measure of quality for the learned embeddings.

For each dataset and each layer, an average cross-validation AUC score on the test data is computed. The evaluation is conducted on the entire test data, with more than 9000 data points. Figure 3 presents results for the datasets BankMarketing, Adult, and QSAR Bio. For each line, each prediction score is normalized by the "best score" from an end-to-end trained TabTransformer for the corresponding dataset. We also explore average and maximum pooling strategies (Howard and Ruder 2018) rather than concatenation of embeddings as the features for the linear model. The upward pattern clearly shows that the embeddings become more effective as the Transformer layer progresses. In contrast, the embeddings from the MLP (the single black markers) perform worse with a linear model. Furthermore, the last value in each line being close to 1.0 indicates that a linear model with the last layer of embeddings as features can achieve reliable accuracy, which confirms our assumption.

Figure 3: Predictions of linear models using as features the embeddings extracted from different Transformer layers in TabTransformer. Layer 0 corresponds to the embeddings before being passed into the Transformer layers. For each dataset, each prediction score is normalized by the "best score" from an end-to-end trained TabTransformer.
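The layer-wise probe described above amounts to a logistic regression per layer; the sketch below uses scikit-learn with random arrays standing in for the embeddings extracted from a trained model, so the shapes and names are placeholders.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

def probe_auc(embeddings, x_cont, y, train_mask):
    """Fit logistic regression on [flattened layer embeddings, continuous
    features] and report test AUC (embeddings: array of shape (n, m, d))."""
    X = np.concatenate([embeddings.reshape(len(embeddings), -1), x_cont], axis=1)
    clf = LogisticRegression(max_iter=1000).fit(X[train_mask], y[train_mask])
    return roc_auc_score(y[~train_mask], clf.predict_proba(X[~train_mask])[:, 1])

# Toy stand-ins for embeddings taken from two different Transformer layers.
rng = np.random.default_rng(0)
n, m, d, c = 400, 4, 32, 3
y = rng.integers(0, 2, n)
train_mask = np.arange(n) < 300
for name, h in {"layer_0": rng.normal(size=(n, m, d)),
                "layer_6": rng.normal(size=(n, m, d))}.items():
    print(name, round(probe_auc(h, rng.normal(size=(n, c)), y, train_mask), 3))
```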
3.2 The Robustness of TabTransformer

We further demonstrate the robustness of TabTransformer on noisy data and on data with missing values, against the baseline MLP. We consider these two scenarios only for categorical features, to specifically probe the robustness of the contextual embeddings from the Transformer layers.

Noisy Data. On the test examples, we first contaminate the data by replacing a certain number of values with values randomly generated from the corresponding columns (features). Next, the noisy data are passed into a trained TabTransformer to compute a prediction AUC score. Results on a set of 3 different datasets are presented in Figure 4. As the noise rate increases, TabTransformer performs better in prediction accuracy and thus is more robust than MLP. In particular, notice the Blastchar dataset, where the performance is nearly identical with no noise, yet as the noise increases, TabTransformer becomes significantly more performant than the baseline. We conjecture that the robustness comes from the contextual property of the embeddings: even when a feature is noisy, it draws information from the correct features, allowing for a certain amount of correction.

Figure 4: Performance of TabTransformer and MLP with noisy data. For each dataset, each prediction score is normalized by the score of TabTransformer at zero noise.

Data with Missing Values. Similarly, on the test data we artificially select a number of values to be missing and send the data with missing values to a trained TabTransformer to compute the prediction score. There are two options to handle the embeddings of missing values: (1) use the average learned embedding over all classes in the corresponding column; (2) use the embedding for the missing-value class, i.e., the additional embedding for each column mentioned in Section 2. Since the benchmark datasets do not contain enough missing values to effectively train the embedding in option (2), we use the average embedding in option (1) for imputation. Results on the same 3 datasets are presented in Figure 5. We observe the same pattern as in the noisy data case, i.e., TabTransformer shows better stability than MLP in handling missing values.

Figure 5: Performance of TabTransformer and MLP under the missing data scenario. For each dataset, each prediction score is normalized by the score of TabTransformer trained without missing values.
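A sketch of the contamination step used in the noisy-data protocol — replacing a given fraction of categorical test values with draws from the same column. The function name and the fixed seed are assumptions; the commented sweep only indicates how such a helper would be used with a trained model.

```python
import torch

def corrupt_categoricals(x_cat, cardinalities, noise_rate, seed=0):
    """Replace a `noise_rate` fraction of values with uniform draws from the
    corresponding column, mirroring the noisy-data evaluation in Section 3.2."""
    g = torch.Generator().manual_seed(seed)
    mask = torch.rand(x_cat.shape, generator=g) < noise_rate
    random_vals = torch.stack(
        [torch.randint(0, card, (x_cat.shape[0],), generator=g)
         for card in cardinalities], dim=1
    )
    return torch.where(mask, random_vals, x_cat)

x = torch.tensor([[0, 1], [2, 0], [4, 2]])
print(corrupt_categoricals(x, cardinalities=[5, 3], noise_rate=0.5))

# Sweep noise rates against a trained model (model and scorer assumed):
# for rate in (0.0, 0.1, 0.2, 0.3):
#     auc = score(model(corrupt_categoricals(x_test, cards, rate), x_cont_test), y_test)
```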
3.3 Supervised Learning

Here we compare the performance of TabTransformer against the following four categories of methods: (a) logistic regression and GBDT; (b) MLP and a sparse MLP following (Morcos et al. 2019); (c) the TabNet model of Arik and Pfister (2019); (d) the Variational Information Bottleneck model (VIB) of Alemi et al. (2017).

Results are summarized in Table 2. TabTransformer, MLP, and GBDT are the top 3 performers. The TabTransformer outperforms the baseline MLP with an average 1.0% gain and performs comparably with GBDT. Furthermore, the TabTransformer is significantly better than TabNet and VIB, the recent deep networks for tabular data. For experiment and model details, see Appendix B. The models' performances on each individual dataset are presented in Tables 16 and 17 in Appendix C.

Table 2: Model performance in supervised learning. The evaluation metric is the mean ± standard deviation of the AUC score over the 15 datasets for each model. The larger the number, the better the result. The top 2 numbers are in bold.

Model Name             Mean AUC (%)
TabTransformer         82.8 ± 0.4
MLP                    81.8 ± 0.4
GBDT                   82.9 ± 0.4
Sparse MLP             81.4 ± 0.4
Logistic Regression    80.4 ± 0.4
TabNet                 77.1 ± 0.5
VIB                    80.5 ± 0.4

3.4 Semi-supervised Learning

Lastly, we evaluate the TabTransformer under the semi-supervised learning scenario, where few labeled training examples are available together with a significant number of unlabeled samples. Specifically, we compare our pre-trained and then fine-tuned TabTransformer-RTD/MLM against the following semi-supervised models: (a) entropy regularization (ER) (Grandvalet and Bengio 2006) combined with MLP and TabTransformer; (b) pseudo labeling (PL) (Lee 2013) combined with MLP, TabTransformer, and GBDT (Jain 2017); (c) MLP (DAE), an unsupervised pre-training method designed for deep models on tabular data: the swap-noise denoising autoencoder (Jahrer 2018).

The pre-training models TabTransformer-MLM, TabTransformer-RTD, and MLP (DAE) are first pre-trained on the entire unlabeled training data and then fine-tuned on labeled data. The semi-supervised learning methods, pseudo labeling and entropy regularization, are trained on the mix of labeled and unlabeled training data. To better present the results, we split the set of 15 datasets into two subsets. The first set includes 6 datasets with more than 30K data points and the second set includes the remaining 9 datasets.

Table 3: Semi-supervised learning results for 8 datasets each with more than 30K data points, for different numbers of labeled data points. The evaluation metric is mean AUC in percentage. The larger the number, the better the result.

# Labeled data points     50            200           500
TabTransformer-RTD        66.6 ± 0.6    70.9 ± 0.6    73.1 ± 0.6
TabTransformer-MLM        66.8 ± 0.6    71.0 ± 0.6    72.9 ± 0.6
MLP (ER)                  65.6 ± 0.6    69.0 ± 0.6    71.0 ± 0.6
MLP (PL)                  65.4 ± 0.6    68.8 ± 0.6    71.0 ± 0.6
TabTransformer (ER)       62.7 ± 0.6    67.1 ± 0.6    69.3 ± 0.6
TabTransformer (PL)       63.6 ± 0.6    67.3 ± 0.7    69.3 ± 0.6
MLP (DAE)                 65.2 ± 0.5    68.5 ± 0.6    71.0 ± 0.6
GBDT (PL)                 56.5 ± 0.5    63.1 ± 0.6    66.5 ± 0.7

Table 4: Semi-supervised learning results for 12 datasets each with less than 30K data points, for different numbers of labeled data points. The evaluation metric is mean AUC in percentage. The larger the number, the better the result.

# Labeled data points     50            200           500
TabTransformer-RTD        78.6 ± 0.6    81.6 ± 0.5    83.4 ± 0.5
TabTransformer-MLM        78.5 ± 0.6    81.0 ± 0.6    82.4 ± 0.5
MLP (ER)                  79.4 ± 0.6    81.1 ± 0.6    82.3 ± 0.6
MLP (PL)                  79.1 ± 0.6    81.1 ± 0.6    82.0 ± 0.6
TabTransformer (ER)       77.9 ± 0.6    81.2 ± 0.6    82.1 ± 0.6
TabTransformer (PL)       77.8 ± 0.6    81.0 ± 0.6    82.1 ± 0.6
MLP (DAE)                 78.5 ± 0.7    80.7 ± 0.6    82.2 ± 0.6
GBDT (PL)                 73.4 ± 0.7    78.8 ± 0.6    81.3 ± 0.6

The results are presented in Table 3 and Table 4. When the number of unlabeled data points is large, Table 3 shows that our TabTransformer-RTD and TabTransformer-MLM significantly outperform all the other competitors. In particular, TabTransformer-RTD/MLM improves over all the other competitors by at least 1.2%, 2.0%, and 2.1% in mean AUC for the scenarios of 50, 200, and 500 labeled data points respectively. The Transformer-based semi-supervised learning methods TabTransformer (ER) and TabTransformer (PL) and the tree-based semi-supervised learning method GBDT (PL) perform worse than the average of all the models. When the number of unlabeled data points becomes smaller, as shown in Table 4, TabTransformer-RTD still outperforms most of its competitors, but with only a marginal improvement.

Furthermore, we observe that when the number of unlabeled data points is small, as shown in Table 4, TabTransformer-RTD performs better than TabTransformer-MLM, thanks to its easier pre-training task (a binary classification) compared to that of MLM (a multi-class classification). This is consistent with the findings of the ELECTRA paper (Clark et al. 2020). In Table 4, with only 50 labeled data points, MLP (ER) and MLP (PL) beat our TabTransformer-RTD/MLM. This can be attributed to the fact that there is room for improvement in our fine-tuning procedure. In particular, our approach allows us to obtain informative embeddings but does not allow the weights of the classifier itself to be trained with unlabeled data. Since this issue does not occur for ER and PL, they obtain an advantage on extremely small labeled sets. We point out, however, that this only means that the methods are complementary, and a possible follow-up could combine the best of all approaches.

Both evaluation results, Table 3 and Table 4, show that our TabTransformer-RTD and TabTransformer-MLM models are promising in extracting useful information from unlabeled data to help supervised training, and are particularly useful when the size of the unlabeled data is large. For model performance on each individual dataset see Tables 10, 11, 12, 13, 14, and 15 in Appendix C.
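For reference, the PL and ER baselines above follow Lee (2013) and Grandvalet and Bengio (2006). The sketch below shows how such unlabeled-data terms are commonly added to a binary classifier's supervised loss; it is a generic illustration rather than the exact training code used in these experiments, and the confidence threshold is an assumption.

```python
import torch
import torch.nn.functional as F

def pseudo_label_loss(logits_unlabeled, threshold=0.95):
    """Pseudo labeling (binary case): treat confident predictions on unlabeled
    rows as if they were true labels in the cross-entropy loss."""
    probs = torch.sigmoid(logits_unlabeled)
    pseudo = (probs > 0.5).float()
    confident = torch.maximum(probs, 1 - probs) > threshold
    if not confident.any():
        return logits_unlabeled.new_zeros(())
    return F.binary_cross_entropy_with_logits(
        logits_unlabeled[confident], pseudo[confident]
    )

def entropy_regularization(logits_unlabeled, eps=1e-7):
    """Entropy regularization: average per-sample prediction entropy on the
    unlabeled rows, added to the supervised loss."""
    p = torch.sigmoid(logits_unlabeled)
    return -(p * (p + eps).log() + (1 - p) * (1 - p + eps).log()).mean()

# total_loss = supervised_bce + lambda_u * entropy_regularization(logits_u)
logits_u = torch.randn(32)
print(pseudo_label_loss(logits_u), entropy_regularization(logits_u))
```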
4 Related Work

Supervised learning. Standard MLPs have been applied to tabular data for many years (De Brébisson et al. 2015). Among deep models designed specifically for tabular data, there are deep versions of factorization machines (Guo et al. 2018; Xiao et al. 2017), Transformer-based methods (Song et al. 2019; Li et al. 2020; Sun et al. 2019), and deep versions of decision-tree-based algorithms (Ke et al. 2019; Yang, Morillo, and Hospedales 2018). In particular, (Song et al. 2019) applies one layer of multi-head attention on embeddings to learn higher-order features. The higher-order features are concatenated and input to a fully connected layer to make the final prediction. (Li et al. 2020) use self-attention layers and track the attention scores to obtain feature importance scores. (Sun et al. 2019) combine the factorization machine model with the Transformer mechanism. All three papers focus on recommendation systems, which makes a clear comparison with this paper difficult. Other models have been designed around the purported properties of tabular data, such as low-order and sparse feature interactions. These include Deep & Cross Networks (Wang et al. 2017), Wide & Deep Networks (Cheng et al. 2016), TabNet (Arik and Pfister 2019), and AdaNet (Cortes et al. 2016).

Semi-supervised learning. (Izmailov et al. 2019) give a semi-supervised method based on density estimation and evaluate their approach on tabular data. Pseudo labeling (Lee 2013) is a simple, efficient, and popular baseline method. Pseudo labeling uses the current network to infer pseudo-labels for unlabeled examples by choosing the most confident class. These pseudo-labels are treated like human-provided labels in the cross-entropy loss. Label propagation (Zhu and Ghahramani 2002; Iscen et al. 2019) is a similar approach, where a node's labels propagate to all nodes according to their proximity and are used by the training model as if they were the true labels. Another standard method in semi-supervised learning is entropy regularization (Grandvalet and Bengio 2005; Sajjadi, Javanmardi, and Tasdizen 2016). It adds the average per-sample entropy of the unlabeled examples to the original loss function for the labeled examples. Another classical approach to semi-supervised learning is co-training (Nigam and Ghani 2000). However, the more recent approaches, entropy regularization and pseudo labeling, are typically better and more popular. A succinct review of semi-supervised learning methods in general can be found in (Oliver et al. 2019; Chappelle, Schölkopf, and Zien 2010).

5 Conclusion

We proposed TabTransformer, a novel deep tabular data modeling architecture for supervised and semi-supervised learning. We provide extensive empirical evidence showing that TabTransformer significantly outperforms MLP and recent deep networks for tabular data while matching the performance of tree-based ensemble models (GBDT). We provide and extensively study a two-phase pre-train then fine-tune procedure for tabular data, beating the state-of-the-art performance of semi-supervised learning methods. TabTransformer shows promising results for robustness against noisy and missing data, and for the interpretability of the contextual embeddings. For future work, it would be interesting to investigate these properties in more detail.
References

Alemi, A. A.; Fischer, I.; and Dillon, J. V. 2018. Uncertainty in the Variational Information Bottleneck. arXiv:1807.00906 [cs, stat]. URL http://arxiv.org/abs/1807.00906.
Alemi, A. A.; Fischer, I.; Dillon, J. V.; and Murphy, K. 2017. Deep Variational Information Bottleneck. In International Conference on Learning Representations. URL https://arxiv.org/abs/1612.00410.
Arik, S. O.; and Pfister, T. 2019. TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442. URL https://arxiv.org/abs/1908.07442.
Ban, G.-Y.; El Karoui, N.; and Lim, A. E. 2018. Machine learning and portfolio optimization. Management Science 64(3): 1136–1154.
Bradley, A. P. 1997. The use of the area under the ROC curve in the evaluation of machine learning algorithms. Pattern Recognition 30(7): 1145–1159.
Brunner, G.; Liu, Y.; Pascual, D.; Richter, O.; and Wattenhofer, R. 2019. On the validity of self-attention as explanation in transformer models. arXiv preprint arXiv:1908.04211.
Chapelle, O.; Scholkopf, B.; and Zien, A. 2009. Semi-supervised learning. IEEE Transactions on Neural Networks 20(3): 542–542.
Chappelle, O.; Schölkopf, B.; and Zien, A. 2010. Semi-supervised learning. Adaptive Computation and Machine Learning.
Chen, T.; and Guestrin, C. 2016. XGBoost: A scalable tree boosting system. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 785–794.
Cheng, H.-T.; Koc, L.; Harmsen, J.; Shaked, T.; Chandra, T.; Aradhye, H.; Anderson, G.; Corrado, G.; Chai, W.; Ispir, M.; et al. 2016. Wide & deep learning for recommender systems. In Proceedings of the 1st Workshop on Deep Learning for Recommender Systems, 7–10.
Clark, K.; Luong, M.-T.; Le, Q. V.; and Manning, C. D. 2020. ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators. In International Conference on Learning Representations. URL https://openreview.net/forum?id=r1xMH1BtvB.
Coenen, A.; Reif, E.; Yuan, A.; Kim, B.; Pearce, A.; Viégas, F.; and Wattenberg, M. 2019. Visualizing and measuring the geometry of BERT. arXiv preprint arXiv:1906.02715.
Cortes, C.; Gonzalvo, X.; Kuznetsov, V.; Mohri, M.; and Yang, S. 2016. AdaNet: Adaptive Structural Learning of Artificial Neural Networks.
De Brébisson, A.; Simon, E.; Auvolat, A.; Vincent, P.; and Bengio, Y. 2015. Artificial Neural Networks Applied to Taxi Destination Prediction. In Proceedings of the 2015 International Conference on ECML PKDD Discovery Challenge - Volume 1526, ECMLPKDDDC'15, 40–51. Aachen, DEU: CEUR-WS.org.
Devlin, J.; Chang, M.-W.; Lee, K.; and Toutanova, K. 2019. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In NAACL-HLT.
Dua, D.; and Graff, C. 2017. UCI Machine Learning Repository. URL http://archive.ics.uci.edu/ml.
Grandvalet, Y.; and Bengio, Y. 2005. Semi-supervised learning by entropy minimization. In Advances in Neural Information Processing Systems, 529–536.
Grandvalet, Y.; and Bengio, Y. 2006. Entropy regularization. Semi-Supervised Learning, 151–168.
Guo, H.; Tang, R.; Ye, Y.; Li, Z.; He, X.; and Dong, Z. 2018. DeepFM: An End-to-End Wide & Deep Learning Framework for CTR Prediction. arXiv:1804.04950 [cs, stat]. URL http://arxiv.org/abs/1804.04950.
Guyon, I.; Sun-Hosoya, L.; Boullé, M.; Escalante, H. J.; Escalera, S.; Liu, Z.; Jajetic, D.; Ray, B.; Saeed, M.; Sebag, M.; Statnikov, A.; Tu, W.; and Viegas, E. 2019. Analysis of the AutoML Challenge series 2015–2018. In AutoML, Springer Series on Challenges in Machine Learning. URL https://www.automl.org/wp-content/uploads/2018/09/chapter10-challenge.pdf.
Howard, J.; and Ruder, S. 2018. Universal language model fine-tuning for text classification. arXiv preprint arXiv:1801.06146.
Huang, Z.; Xu, W.; and Yu, K. 2015. Bidirectional LSTM-CRF models for sequence tagging. arXiv preprint arXiv:1508.01991.
Iscen, A.; Tolias, G.; Avrithis, Y.; and Chum, O. 2019. Label propagation for deep semi-supervised learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 5070–5079.
Izmailov, P.; Kirichenko, P.; Finzi, M.; and Wilson, A. G. 2019. Semi-Supervised Learning with Normalizing Flows. arXiv:1912.13025 [cs, stat]. URL http://arxiv.org/abs/1912.13025.
Jahrer, M. 2018. Porto Seguro's Safe Driver Prediction. URL https://kaggle.com/c/porto-seguro-safe-driver-prediction.
Jain, S. 2017. Introduction to Pseudo-Labelling: A Semi-Supervised Learning Technique. https://www.analyticsvidhya.com/blog/2017/09/pseudo-labelling-semi-supervised-learning-technique/.
Kaggle, Inc. 2017. The State of ML and Data Science 2017. URL https://www.kaggle.com/surveys/2017.
Ke, G.; Meng, Q.; Finley, T.; Wang, T.; Chen, W.; Ma, W.; Ye, Q.; and Liu, T.-Y. 2017. LightGBM: A highly efficient gradient boosting decision tree. In Advances in Neural Information Processing Systems, 3146–3154. URL https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf.
Ke, G.; Zhang, J.; Xu, Z.; Bian, J.; and Liu, T.-Y. 2019. TabNN: A Universal Neural Network Solution for Tabular Data. URL https://openreview.net/forum?id=r1eJssCqY7.
Kendall, A.; and Gal, Y. 2017. What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems, 5574–5584.
Klambauer, G.; Unterthiner, T.; Mayr, A.; and Hochreiter, S. 2017. Self-normalizing neural networks. In Advances in Neural Information Processing Systems, 971–980.
Lee, D.-H. 2013. Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. In Workshop on Challenges in Representation Learning, ICML, volume 3, 2.
Li, Z.; Cheng, W.; Chen, Y.; Chen, H.; and Wang, W. 2020. Interpretable Click-Through Rate Prediction through Hierarchical Attention. In Proceedings of the 13th International Conference on Web Search and Data Mining, 313–321. Houston, TX, USA: ACM. ISBN 978-1-4503-6822-3. doi:10.1145/3336191.3371785. URL http://dl.acm.org/doi/10.1145/3336191.3371785.
Loshchilov, I.; and Hutter, F. 2017. Decoupled Weight Decay Regularization. In International Conference on Learning Representations. URL https://arxiv.org/abs/1711.05101.
Maaten, L. v. d.; and Hinton, G. 2008. Visualizing data using t-SNE. Journal of Machine Learning Research 9(Nov): 2579–2605.
Mikolov, T.; Kombrink, S.; Burget, L.; Černockỳ, J.; and Khudanpur, S. 2011. Extensions of recurrent neural network language model. In 2011 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 5528–5531. IEEE.
Morcos, A. S.; Yu, H.; Paganini, M.; and Tian, Y. 2019. One ticket to win them all: generalizing lottery ticket initializations across datasets and optimizers. arXiv:1906.02773 [cs, stat]. URL http://arxiv.org/abs/1906.02773.
Nigam, K.; and Ghani, R. 2000. Analyzing the effectiveness and applicability of co-training. In Proceedings of the Ninth International Conference on Information and Knowledge Management, 86–93.
Oliver, A.; Odena, A.; Raffel, C.; Cubuk, E. D.; and Goodfellow, I. J. 2019. Realistic Evaluation of Deep Semi-Supervised Learning Algorithms. arXiv:1804.09170 [cs, stat]. URL http://arxiv.org/abs/1804.09170.
Oliver, A.; Odena, A.; Raffel, C. A.; Cubuk, E. D.; and Goodfellow, I. 2018. Realistic evaluation of deep semi-supervised learning algorithms. In Advances in Neural Information Processing Systems, 3235–3246.
Paszke, A.; Gross, S.; Massa, F.; Lerer, A.; Bradbury, J.; Chanan, G.; Killeen, T.; Lin, Z.; Gimelshein, N.; Antiga, L.; Desmaison, A.; Kopf, A.; Yang, E.; DeVito, Z.; Raison, M.; Tejani, A.; Chilamkurthy, S.; Steiner, B.; Fang, L.; Bai, J.; and Chintala, S. 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Wallach, H.; Larochelle, H.; Beygelzimer, A.; d'Alché-Buc, F.; Fox, E.; and Garnett, R., eds., Advances in Neural Information Processing Systems 32, 8024–8035. Curran Associates, Inc. URL http://papers.neurips.cc/paper/9015-pytorch-an-imperative-style-high-performance-deep-learning-library.pdf.
Prokhorenkova, L.; Gusev, G.; Vorobev, A.; Dorogush, A. V.; and Gulin, A. 2018. CatBoost: unbiased boosting with categorical features. In Advances in Neural Information Processing Systems, 6638–6648.
Rong, X. 2014. word2vec parameter learning explained. arXiv preprint arXiv:1411.2738.
Sajjadi, M.; Javanmardi, M.; and Tasdizen, T. 2016. Regularization with stochastic transformations and perturbations for deep semi-supervised learning. In Advances in Neural Information Processing Systems, 1163–1171.
Sandler, M.; Howard, A.; Zhu, M.; Zhmoginov, A.; and Chen, L.-C. 2018. MobileNetV2: Inverted residuals and linear bottlenecks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 4510–4520.
Song, W.; Shi, C.; Xiao, Z.; Duan, Z.; Xu, Y.; Zhang, M.; and Tang, J. 2019. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks. In Proceedings of the 28th ACM International Conference on Information and Knowledge Management (CIKM '19), 1161–1170. doi:10.1145/3357384.3357925. URL http://arxiv.org/abs/1810.11921.
Stretcu, O.; Viswanathan, K.; Movshovitz-Attias, D.; Platanios, E.; Ravi, S.; and Tomkins, A. 2019. Graph Agreement Models for Semi-Supervised Learning. In Advances in Neural Information Processing Systems 32, 8713–8723. Curran Associates, Inc. URL http://papers.nips.cc/paper/9076-graph-agreement-models-for-semi-supervised-learning.pdf.
Sun, Q.; Cheng, Z.; Fu, Y.; Wang, W.; Jiang, Y.-G.; and Xue, X. 2019. DeepEnFM: Deep neural networks with Encoder enhanced Factorization Machine. URL https://openreview.net/forum?id=SJlyta4YPS.
Tanha, J.; Someren, M.; and Afsarmanesh, H. 2017. Semi-supervised self-training for decision tree classifiers. International Journal of Machine Learning and Cybernetics 8: 355–370.
Torpey, E.; and Watson, A. 2014. Education level and jobs: Opportunities by state. URL https://www.bls.gov/careeroutlook/2014/article/education-level-and-jobs.htm.
Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A. N.; Kaiser, Ł.; and Polosukhin, I. 2017. Attention is all you need. In Advances in Neural Information Processing Systems, 5998–6008.
Voulodimos, A.; Doulamis, N.; Doulamis, A.; and Protopapadakis, E. 2018. Deep learning for computer vision: A brief review. Computational Intelligence and Neuroscience 2018.
Wang, R.; Fu, B.; Fu, G.; and Wang, M. 2017. Deep & Cross Network for Ad Click Predictions. In ADKDD@KDD.
Xiao, J.; Ye, H.; He, X.; Zhang, H.; Wu, F.; and Chua, T.-S. 2017. Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks. In Proceedings of the Twenty-Sixth International Joint Conference on Artificial Intelligence, 3119–3125. Melbourne, Australia: International Joint Conferences on Artificial Intelligence Organization. ISBN 978-0-9992411-0-3. doi:10.24963/ijcai.2017/435. URL https://www.ijcai.org/proceedings/2017/435.
Yang, Y.; Morillo, I. G.; and Hospedales, T. M. 2018. Deep neural decision trees. arXiv preprint arXiv:1806.06988.
Zhu, X.; and Ghahramani, Z. 2002. Learning from labeled and unlabeled data with label propagation.
A Appendix: Ablation Studies

We perform a number of ablation studies on various architectural choices and pre-training approaches for our TabTransformer. The first ablation study is on the choice of column embedding. The second and third ablation studies focus on the pre-training approach; specifically, they are on the replacement value k and on the dynamic versus static replacement strategy. For the pre-training approach, we use TabTransformer-RTD as our model; that is, the loss in pre-training is the RTD loss. For TabTransformer, the hidden (embedding) dimension, the number of layers, and the number of attention heads in the Transformer are set to 32, 6, and 8 respectively. The MLP layer sizes are set to {4 × l, 2 × l}, where l is the size of its input. To better present the results, we introduce an additional evaluation metric, the relative AUC. More precisely, for each dataset and cross-validation split, the relative AUC for a model is the relative change of its AUC against the mean AUC over all competing models.

Column Embedding. The first study is on the choice of column embedding – shared parameters cφi across the embeddings of multiple classes in column i, for i ∈ {1, 2, ..., m}. In particular, we study the optimal dimension ℓ of cφi. An alternative choice is to element-wise add the unique identifier cφi and the feature-value specific embedding wφij rather than concatenating them. In that case, the dimensions of both cφi and wφij are equal to the embedding dimension d. The goal of having column embedding is to enable the model to distinguish the classes in one column from those in the other columns. A baseline approach is to not have any shared embedding. Results are presented in Table 5, where "Col Embed-Concat-1/X" indicates that the dimension ℓ is set as d/X. The relative AUC score is calculated over all the models that appear in the rows and columns of the table, which explains why negative scores appear in some of the entries. The results show that not having the shared column embedding performs worst, and that our concatenation column embedding gives better performance on average.

The replacement value k. The second ablation study is on the replacement value k in the pre-training approach. We run experiments for three different choices of k – {15, 30, 50} – on three different datasets, namely Adult, BankMarketing, and 1995 income. The TabTransformer is first pre-trained with a value of k on unlabeled data and then fine-tuned on labeled data. The number of labeled data points is set to 256. The final fine-tuning accuracy is not very sensitive to the value of k, as shown in Table 6. The pre-training curves of training and validation accuracy for the three different replacement values k are shown in Figure 6. Note that a constant prediction model would achieve 85% accuracy for the 15% replacement value.

Dynamic versus Static Replacement. The third ablation study is on dynamic versus static replacement in the pre-training approach. In dynamic replacement, we randomly replace feature values during pre-training over the epochs; that is, the replacement is different in each epoch. In static replacement, the random replacement is chosen once and then the same replacement is used in all the epochs. We combine this study with another ablation on a shared RTD binary classifier (predictor) versus different classifiers for different columns. The results in Table 7 show that our choice of dynamic replacement and un-shared RTD classifiers performs better than static replacement and shared RTD classifiers. Figure 7 shows the pre-training curves of training and validation accuracy for the three choices – dynamic replacement, static replacement, and static replacement with a shared RTD classifier.
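As a concrete reading of the relative AUC metric used in Tables 5–7 below, the following helper computes, for one dataset and split, each model's relative change against the mean AUC of all competing models; the function name and example numbers are illustrative.

```python
import numpy as np

def relative_auc(auc_by_model):
    """Relative change (in percent) of each model's AUC against the mean AUC
    over all competing models, for a single dataset and cross-validation split."""
    mean = np.mean(list(auc_by_model.values()))
    return {name: 100.0 * (auc - mean) / mean for name, auc in auc_by_model.items()}

print(relative_auc({"No Col Embed": 0.820, "Concat-1/8": 0.828, "Add": 0.824}))
```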
Table 5: Performance of TabTransformer with no column embedding, concatenation column embedding, and addition column
embedding. The evaluation metric is mean ± standard deviation of relative AUCs (in percentage) over all 15 datasets. Larger
value means better performance. The best model is bold for each row.

# of Transformers Layers No Col Embed Col Embed-Concat-1/4 Col Embed-Concat-1/8 Col Embed-Add
1 -0.59 ± 0.33 -2.01 ± 1.33 -0.27 ± 0.21 -1.11 ± 0.77
2 -0.59 ± 0.22 -0.37 ± 0.20 -0.14 ± 0.19 0.34 ± 0.27
3 -0.37 ± 0.19 0.04 ± 0.18 -0.02 ± 0.21 0.21 ± 0.23
6 0.54 ± 0.22 0.53 ± 0.24 0.70 ± 0.17 0.25 ± 0.23
12 0.66 ± 0.21 1.05 ± 0.31 0.73 ± 0.58 0.42 ± 0.39

Table 6: Fine-tuning performance of TabTransformer-RTD for different pre-training replacement value k. The number of labeled
data points is 256. The evaluation metrics are mean ± standard deviation of (1) AUC score over 5 cross-validation splits for
each dataset (in percentage); (2) relative AUCs over the 3 datasets (in percentage). Larger value means better performance. The
best model is bold for each column.

Replacement value k% Adult BankMarketing 1995 income relative AUC (%)


15 58.1 ± 3.52 85.9 ± 1.62 86.8 ± 1.35 0.02 ± 0.10
30 58.1 ± 3.15 86.1 ± 1.58 86.7 ± 1.41 0.08 ± 0.10
50 57.9 ± 3.21 85.7 ± 1.93 86.7 ± 1.38 -0.10 ± 0.11

Table 7: Fine-tuning performance of TabTransformer-RTD for dynamic replacement, static replacement, and static replacement
with a shared classifier. The number of labeled data points is 256. The evaluation metrics are mean ± standard deviation of (1)
AUC score over 5 cross-validation splits for each dataset (in percentage) ; (2) relative AUCs over the 3 datasets (in percentage).
Larger value means better performance. The best model is bold for each column.

Adult BankMarketing 1995 income relative AUC (%)


Dynamic Replacement (Un-shared RTD classifiers) 58.1 ± 3.52 85.9 ± 1.62 86.8 ± 1.35 0.81 ± 0.19
Static Replacement (Un-shared RTD classifiers) 57.9 ± 2.93 83.9 ± 1.18 85.9 ± 1.60 -0.33 ± 0.15
Static Replacement (Shared RTD Classifiers) 57.5 ± 2.74 84.2 ± 1.46 86.0 ± 1.69 -0.49 ± 0.11
Figure 6: The pre-training curves of training and validation accuracy for the three different replacement values k on the Adult, BankMarketing, and 1995 income datasets.
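The dynamic and static replacement schemes compared in Table 7 and Figure 7 differ only in how often the random corruption is re-drawn. The following is a minimal sketch, assuming the RTD-style corruption in which a replaced entry is substituted by a value sampled from the same column; the function names and NumPy usage are ours, not the original training code.

```python
import numpy as np

def corrupt(X, replace_prob, rng):
    """RTD-style corruption: replace a random subset of entries with values
    drawn from the same column's empirical distribution (sketch)."""
    n, m = X.shape
    mask = rng.random((n, m)) < replace_prob          # which entries get replaced
    X_corrupt = X.copy()
    for j in range(m):
        rows = np.where(mask[:, j])[0]
        X_corrupt[rows, j] = X[rng.integers(0, n, size=rows.size), j]
    return X_corrupt, mask                            # mask doubles as the RTD target

rng = np.random.default_rng(0)
X = rng.integers(0, 5, size=(1000, 8))                # toy categorical feature matrix

# Static replacement: the corruption is drawn once and reused in every epoch.
X_static, y_static = corrupt(X, replace_prob=0.15, rng=rng)

for epoch in range(3):
    # Dynamic replacement: a fresh corruption is drawn in every epoch.
    X_dynamic, y_dynamic = corrupt(X, replace_prob=0.15, rng=rng)
    # ... run one pre-training epoch on the chosen (inputs, replaced-or-not labels)
```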

For MLPs, all models used SELU activations (Klambauer et al. 2017) followed by batch normalization in each layer, and the number of hidden layers was set to 2. The model-specific hyper-parameters tuned were the size of the first hidden layer with a search space {x = m ∗ l, m ∈ Z | 1 ≤ m ≤ 8}, where l is the input size, and the size of the second hidden layer with a search space {x = m ∗ l, m ∈ Z | 1 ≤ m ≤ 3}.
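As a concrete example of this baseline, a two-hidden-layer MLP with SELU followed by batch normalization could be assembled as below. This is a PyTorch sketch under our own naming; the multipliers m1 and m2 correspond to the layer-size search spaces above.

```python
import torch.nn as nn

def make_mlp(input_size, n_classes, m1=4, m2=2):
    """Two-hidden-layer MLP baseline: each hidden layer is Linear -> SELU ->
    BatchNorm; m1 and m2 are the layer-size multipliers searched over
    {1, ..., 8} and {1, ..., 3} respectively (sketch)."""
    h1, h2 = m1 * input_size, m2 * input_size
    return nn.Sequential(
        nn.Linear(input_size, h1), nn.SELU(), nn.BatchNorm1d(h1),
        nn.Linear(h1, h2), nn.SELU(), nn.BatchNorm1d(h2),
        nn.Linear(h2, n_classes),
    )

model = make_mlp(input_size=32, n_classes=2)
```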
For TabTransformer, the hidden (embedding) dimension, the number of layers, and the number of attention heads in the Transformer were fixed to 32, 6, and 8, respectively, during the experiments. The MLP layer sizes were fixed to {4 × l, 2 × l}, where l was the size of its input. These values were selected based on 50 rounds of HPO run on 5 datasets; the search spaces were the number of attention heads {2, 4, 8}, the hidden dimension {32, 64, 128, 256}, and the number of layers {1, 2, 3, 6, 12}. The search spaces for the first and second hidden layers of the MLP are exactly the same as those in the MLP model setting. The dimension ℓ of the column embedding c_φi was chosen as d/8 based on the ablation study in Appendix A.

For Sparse MLP (Prune), the implementation was the same as the MLP except that every k epochs during training the fraction p of weights with the smallest magnitude was permanently set to zero. The model-specific hyper-parameter tuned was the fraction p with a search space {x = 5 · 10^u, u ∈ U | −2 ≤ u ≤ −1}. The number of layers and layer sizes are exactly the same as in the MLP setting. The parameter k is set to 10.
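The pruning step of the Sparse MLP baseline can be realized as a global magnitude threshold applied every k epochs. The sketch below is our own illustrative implementation, not the code used in the experiments.

```python
import torch

@torch.no_grad()
def prune_smallest_weights(model, p):
    """Set the fraction p of Linear-layer weights with the smallest magnitude to zero
    (sketch of the Sparse MLP (Prune) step described above)."""
    weights = [m.weight for m in model.modules() if isinstance(m, torch.nn.Linear)]
    magnitudes = torch.cat([w.abs().flatten() for w in weights])
    k = int(p * magnitudes.numel())
    if k == 0:
        return
    threshold = magnitudes.kthvalue(k).values         # global magnitude cut-off
    for w in weights:
        w[w.abs() <= threshold] = 0.0
    # To keep pruned weights permanently at zero, the same zero-mask would also be
    # re-applied after every optimizer step between pruning rounds.

# applied every k = 10 epochs during training, e.g.:
# if (epoch + 1) % 10 == 0:
#     prune_smallest_weights(model, p=0.05)
```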
Figure 7: The pre-training curves of training and validation accuracy for dynamic mask, static mask, and static mask with a shared predictor (classifier) on the Adult, BankMarketing, and 1995 income datasets.

For TabNet, we implemented the model exactly as described in Arik and Pfister (2019), though we also added the option to use a softmax attention instead of a sparsemax attention, and did not include the sparsification term in the loss function. The model-specific hyper-parameters tuned were the number of layers with a search space {x ∈ Z | 3 ≤ x ≤ 10}, the hidden dimension with a search space {x ∈ Z | 8 ≤ x ≤ 128}, and the sparse coefficient with a search space {x = 10^u, u ∈ U | −6 ≤ u ≤ −2}.

For the VIB model, we implemented it as described in Alemi, Fischer, and Dillon (2018). We used a diagonal covariance, with 10 samples from the variational distribution during training and 20 during testing. The model-specific hyper-parameters tuned were the number of hidden layers and the layer sizes, with exactly the same search spaces as the MLP, and the number of mixture components in the mixture of Gaussians used in the marginal distribution, with a search space {x ∈ Z | 3 ≤ x ≤ 10}.

For MLP (DAE), the pre-training used swap noise as described in Jahrer (2018). The model-specific hyper-parameters were exactly the same as those of the MLP.

For Pseudo Labeling (Lee 2013), since this method was combined with deep models such as the MLP, TabTransformer, and GBDT, the model-specific hyper-parameters were exactly the same as those of the corresponding base models mentioned above. The unsupervised coefficient α is scheduled with αf = 3, T1 = 30, and T2 = 70.
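The schedule of the unsupervised coefficient α follows the linear ramp-up of Lee (2013); a minimal sketch with the values quoted above (αf = 3, T1 = 30, T2 = 70):

```python
def pseudo_label_alpha(epoch, alpha_f=3.0, t1=30, t2=70):
    """Weight of the pseudo-label (unsupervised) loss at a given epoch, using the
    linear ramp-up schedule of Lee (2013) with alpha_f = 3, T1 = 30, T2 = 70."""
    if epoch < t1:
        return 0.0                                  # labeled loss only at the start
    if epoch < t2:
        return alpha_f * (epoch - t1) / (t2 - t1)   # linear ramp-up between T1 and T2
    return alpha_f                                  # constant weight afterwards

# total_loss = supervised_loss + pseudo_label_alpha(epoch) * unsupervised_loss
```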
For Entropy Regularization (Grandvalet and Bengio 2006), the setup is the same as for Pseudo Labeling. The additional model-specific hyper-parameter was the positive Lagrange multiplier λ with a search space {0.1, 0.2, ..., 0.9}.

B.2 Feature Engineering
For categorical variables, the processing options include whether to one-hot encode versus learn a parametric embedding, what embedding dimension to use, and how to apply dropout regularization (whether to drop vector elements or whole embeddings). In our experiments we found that learned embeddings nearly always improved performance, as long as the cardinality of the categorical variable is significantly less than the number of data points; otherwise the feature is merely a means for the model to overfit.

For scalar variables, the processing options include how to re-scale the variable (via quantiles, normalization, or log scaling) or whether to quantize the feature and treat it like a categorical variable. While we have not explored this idea fully, the best strategy is likely to use all the different types of encoding in parallel, turning each scalar feature into three re-scaled features and one categorical feature, as sketched below. Unlike learning embeddings for high-cardinality categorical features, adding potentially-redundant encodings for scalar variables should not lead to overfitting, but it can make the difference between a feature being useful or not.

For text variables, we simply encode the number of words and the number of characters in the text.
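A sketch of the parallel-encoding strategy for a single scalar feature described above (quantile rank, z-score, and log re-scalings plus one quantized categorical copy); the function name and the choice of 10 bins are illustrative, not part of the original pipeline.

```python
import numpy as np

def encode_scalar(x, n_bins=10):
    """Turn one scalar feature into three re-scaled features and one categorical
    (quantized) feature, as suggested above (sketch)."""
    x = np.asarray(x, dtype=float)
    quantile_rank = np.argsort(np.argsort(x)) / max(len(x) - 1, 1)   # rank re-scaled to [0, 1]
    z_score = (x - x.mean()) / (x.std() + 1e-8)                      # standard normalization
    log_scaled = np.sign(x) * np.log1p(np.abs(x))                    # signed log scaling
    edges = np.quantile(x, np.linspace(0, 1, n_bins + 1)[1:-1])      # interior quantile edges
    quantized = np.digitize(x, edges)                                # categorical copy: bin index
    return quantile_rank, z_score, log_scaled, quantized

q, z, lg, cat = encode_scalar(np.random.default_rng(0).lognormal(size=1000))
```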
C Appendix: Benchmark Dataset Information and Experiment Results
Table 8: Benchmark datasets. All datasets are binary classification tasks. Positive Class% is the fraction of data points that
belong to the positive class.

Dataset Name N Datapoints N Features Positive Class%


1995 income 32561 14 24.1
adult 34190 25 85.4
albert 425240 79 50.0
bank marketing 45211 16 11.7
blastchar 7043 20 26.5
dota2games 92650 117 52.7
fabert 8237 801 11.3
hcdr main 307511 120 8.1
htru2 17898 8 9.2
insurance co 5822 85 6.0
jannis 83733 55 2.0
jasmine 2984 145 50.0
online shoppers 12330 17 15.5
philippine 5832 309 50.0
qsar bio 1055 41 33.7
seismicbumps 2583 18 6.6
shrutime 10000 11 20.4
spambase 4601 57 39.4
sylvine 5124 20 50.0
volkert 58310 181 12.7

Table 9: Benchmark Dataset Links.

Dataset Name URL


1995 income https://www.kaggle.com/lodetomasi1995/income-classification
adult http://automl.chalearn.org/data
albert http://automl.chalearn.org/data
bank marketing https://archive.ics.uci.edu/ml/datasets/bank+marketing
blastchar https://www.kaggle.com/blastchar/telco-customer-churn
dota2games https://archive.ics.uci.edu/ml/datasets/Dota2+Games+Results
fabert http://automl.chalearn.org/data
hcdr main https://www.kaggle.com/c/home-credit-default-risk
htru2 https://archive.ics.uci.edu/ml/datasets/HTRU2
insurance co https://archive.ics.uci.edu/ml/datasets/Insurance+Company+Benchmark+%28COIL+2000%29
jannis http://automl.chalearn.org/data
jasmine http://automl.chalearn.org/data
online shoppers https://archive.ics.uci.edu/ml/datasets/Online+Shoppers+Purchasing+Intention+Dataset
philippine http://automl.chalearn.org/data
qsar bio https://archive.ics.uci.edu/ml/datasets/QSAR+biodegradation
seismicbumps https://archive.ics.uci.edu/ml/datasets/seismic-bumps
shrutime https://www.kaggle.com/shrutimechlearn/churn-modelling
spambase https://archive.ics.uci.edu/ml/datasets/Spambase
sylvine http://automl.chalearn.org/data
volkert http://automl.chalearn.org/data
Table 10: AUC score for semi-supervised learning models on all datasets with 50 fine-tune data points. Values are the mean
over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.

Dataset N Datapoints N Features Positive Class% Best Model TabTransformer-RTD TabTransformer-MLM MLP (ER)
albert 425240 79 50.0 TabTransformer-MLM 0.644 ± 0.015 0.647 ± 0.019 0.612 ± 0.017
hcdr main 307511 120 8.1 MLP (DAE) 0.592 ± 0.047 0.596 ± 0.047 0.602 ± 0.033
dota2games 92650 117 52.7 TabTransformer-MLM 0.526 ± 0.009 0.538 ± 0.011 0.519 ± 0.007
jannis 83733 55 2.0 TabTransformer-RTD 0.684 ± 0.055 0.665 ± 0.056 0.621 ± 0.022
volkert 58310 181 12.7 TabTransformer-RTD 0.693 ± 0.046 0.689 ± 0.042 0.657 ± 0.028
bank marketing 45211 16 11.7 MLP (PL) 0.771 ± 0.046 0.735 ± 0.040 0.792 ± 0.039
adult 34190 25 85.4 MLP (DAE) 0.580 ± 0.012 0.613 ± 0.014 0.609 ± 0.005
1995 income 32561 14 24.1 TabTransformer-MLM 0.840 ± 0.029 0.862 ± 0.018 0.839 ± 0.034
htru2 17898 8 9.2 MLP (DAE) 0.956 ± 0.007 0.958 ± 0.009 0.969 ± 0.012
online shoppers 12330 17 15.5 MLP (DAE) 0.790 ± 0.013 0.780 ± 0.024 0.855 ± 0.019
shrutime 10000 11 20.4 TabTransformer-RTD 0.752 ± 0.019 0.741 ± 0.019 0.725 ± 0.032
fabert 8237 801 11.3 MLP (PL) 0.535 ± 0.027 0.525 ± 0.019 0.572 ± 0.019
blastchar 7043 20 26.5 TabTransformer-MLM 0.806 ± 0.018 0.822 ± 0.009 0.803 ± 0.021
philippine 5832 309 50.0 TabTransformer-RTD 0.739 ± 0.027 0.729 ± 0.035 0.722 ± 0.031
insurance co 5822 85 6.0 MLP (PL) 0.601 ± 0.056 0.573 ± 0.077 0.575 ± 0.063
sylvine 5124 20 50.0 MLP (PL) 0.872 ± 0.031 0.898 ± 0.030 0.930 ± 0.015
spambase 4601 57 39.4 MLP (ER) 0.949 ± 0.005 0.945 ± 0.011 0.957 ± 0.008
jasmine 2984 145 50.0 TabTransformer-MLM 0.821 ± 0.019 0.837 ± 0.019 0.830 ± 0.022
seismicbumps 2583 18 6.6 TabTransformer (ER) 0.740 ± 0.088 0.738 ± 0.068 0.712 ± 0.074
qsar bio 1055 41 33.7 MLP (DAE) 0.875 ± 0.028 0.869 ± 0.036 0.880 ± 0.022

Table 11: (Continued) AUC score for semi-supervised learning models on all datasets with 50 fine-tune data points. Values are
the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.

Dataset MLP (PL) TabTransformer (ER) TabTransformer (PL) MLP (DAE) GBDT (PL)
albert 0.607 ± 0.013 0.580 ± 0.017 0.587 ± 0.012 0.612 ± 0.014 0.547 ± 0.032
hcdr main 0.599 ± 0.038 0.581 ± 0.023 0.570 ± 0.031 0.620 ± 0.028 0.531 ± 0.024
dota2games 0.520 ± 0.006 0.516 ± 0.009 0.519 ± 0.008 0.516 ± 0.004 0.505 ± 0.008
jannis 0.623 ± 0.035 0.582 ± 0.035 0.604 ± 0.013 0.626 ± 0.023 0.519 ± 0.047
volkert 0.653 ± 0.035 0.635 ± 0.024 0.639 ± 0.040 0.629 ± 0.019 0.525 ± 0.018
bank marketing 0.805 ± 0.036 0.744 ± 0.063 0.767 ± 0.058 0.786 ± 0.055 0.688 ± 0.057
adult 0.605 ± 0.021 0.568 ± 0.012 0.582 ± 0.024 0.616 ± 0.010 0.519 ± 0.024
1995 income 0.819 ± 0.042 0.813 ± 0.045 0.822 ± 0.048 0.811 ± 0.042 0.685 ± 0.084
htru2 0.970 ± 0.012 0.955 ± 0.007 0.951 ± 0.009 0.973 ± 0.003 0.919 ± 0.021
online shoppers 0.848 ± 0.021 0.816 ± 0.036 0.818 ± 0.028 0.858 ± 0.019 0.818 ± 0.032
shrutime 0.715 ± 0.044 0.748 ± 0.035 0.739 ± 0.034 0.683 ± 0.055 0.651 ± 0.093
fabert 0.577 ± 0.027 0.504 ± 0.020 0.516 ± 0.020 0.552 ± 0.013 0.534 ± 0.016
blastchar 0.799 ± 0.025 0.799 ± 0.013 0.792 ± 0.025 0.817 ± 0.016 0.729 ± 0.053
philippine 0.725 ± 0.022 0.689 ± 0.046 0.703 ± 0.050 0.717 ± 0.022 0.628 ± 0.085
insurance co 0.601 ± 0.057 0.575 ± 0.066 0.592 ± 0.080 0.522 ± 0.052 0.560 ± 0.081
sylvine 0.939 ± 0.013 0.891 ± 0.022 0.904 ± 0.027 0.925 ± 0.010 0.914 ± 0.021
spambase 0.951 ± 0.010 0.947 ± 0.006 0.948 ± 0.006 0.949 ± 0.012 0.899 ± 0.039
jasmine 0.819 ± 0.021 0.825 ± 0.024 0.819 ± 0.018 0.812 ± 0.029 0.755 ± 0.016
seismicbumps 0.678 ± 0.106 0.745 ± 0.080 0.713 ± 0.090 0.724 ± 0.049 0.601 ± 0.071
qsar bio 0.875 ± 0.015 0.851 ± 0.041 0.835 ± 0.053 0.888 ± 0.022 0.804 ± 0.057

Table 12: AUC score for semi-supervised learning models on all datasets with 200 fine-tune data points. Values are the mean
over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.

Dataset N Datapoints N Features Positive Class% Best Model TabTransformer-RTD TabTransformer-MLM MLP (ER)
albert 425240 79 50.0 TabTransformer-MLM 0.699 ± 0.011 0.701 ± 0.014 0.642 ± 0.020
hcdr main 307511 120 8.1 TabTransformer-MLM 0.655 ± 0.040 0.668 ± 0.028 0.639 ± 0.027
dota2games 92650 117 52.7 TabTransformer-MLM 0.536 ± 0.012 0.549 ± 0.008 0.527 ± 0.012
jannis 83733 55 2.0 TabTransformer-RTD 0.713 ± 0.037 0.692 ± 0.024 0.665 ± 0.024
volkert 58310 181 12.7 TabTransformer-RTD 0.753 ± 0.022 0.742 ± 0.023 0.696 ± 0.033
bank marketing 45211 16 11.7 MLP (PL) 0.854 ± 0.020 0.838 ± 0.010 0.860 ± 0.008
adult 34190 25 85.4 MLP (ER) 0.596 ± 0.023 0.614 ± 0.012 0.623 ± 0.017
1995 income 32561 14 24.1 TabTransformer-MLM 0.866 ± 0.014 0.875 ± 0.011 0.868 ± 0.007
htru2 17898 8 9.2 MLP (DAE) 0.961 ± 0.008 0.963 ± 0.009 0.974 ± 0.007
online shoppers 12330 17 15.5 MLP (ER) 0.834 ± 0.015 0.838 ± 0.024 0.876 ± 0.019
shrutime 10000 11 20.4 TabTransformer-RTD 0.805 ± 0.017 0.783 ± 0.024 0.773 ± 0.013
fabert 8237 801 11.3 MLP (ER) 0.556 ± 0.023 0.561 ± 0.028 0.600 ± 0.046
blastchar 7043 20 26.5 TabTransformer-MLM 0.831 ± 0.010 0.841 ± 0.014 0.829 ± 0.010
philippine 5832 309 50.0 TabTransformer-RTD 0.757 ± 0.017 0.754 ± 0.016 0.732 ± 0.024
insurance co 5822 85 6.0 TabTransformer (ER) 0.667 ± 0.062 0.640 ± 0.043 0.601 ± 0.059
sylvine 5124 20 50.0 MLP (PL) 0.939 ± 0.008 0.948 ± 0.006 0.957 ± 0.008
spambase 4601 57 39.4 MLP (ER) 0.957 ± 0.006 0.955 ± 0.010 0.968 ± 0.009
jasmine 2984 145 50.0 TabTransformer-RTD 0.843 ± 0.016 0.843 ± 0.028 0.831 ± 0.019
seismicbumps 2583 18 6.6 TabTransformer-RTD 0.738 ± 0.063 0.708 ± 0.083 0.694 ± 0.088
qsar bio 1055 41 33.7 TabTransformer-RTD 0.896 ± 0.018 0.889 ± 0.030 0.895 ± 0.026
Table 13: (Continued) AUC score for semi-supervised learning models on all datasets with 200 fine-tune data points. Values
are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.

Dataset MLP (PL) TabTransformer (ER) TabTransformer (PL) MLP (DAE) GBDT (PL)
albert 0.638 ± 0.024 0.630 ± 0.025 0.630 ± 0.021 0.646 ± 0.023 0.628 ± 0.015
hcdr main 0.631 ± 0.019 0.611 ± 0.030 0.605 ± 0.021 0.636 ± 0.027 0.579 ± 0.039
dota2games 0.527 ± 0.014 0.528 ± 0.017 0.525 ± 0.011 0.528 ± 0.012 0.506 ± 0.008
jannis 0.667 ± 0.036 0.619 ± 0.024 0.637 ± 0.026 0.659 ± 0.020 0.525 ± 0.030
volkert 0.693 ± 0.028 0.694 ± 0.002 0.689 ± 0.015 0.672 ± 0.015 0.612 ± 0.042
bank marketing 0.866 ± 0.008 0.853 ± 0.016 0.858 ± 0.009 0.863 ± 0.009 0.802 ± 0.012
adult 0.616 ± 0.014 0.582 ± 0.026 0.584 ± 0.017 0.611 ± 0.027 0.572 ± 0.040
1995 income 0.869 ± 0.009 0.848 ± 0.024 0.852 ± 0.015 0.865 ± 0.011 0.822 ± 0.020
htru2 0.974 ± 0.007 0.955 ± 0.007 0.954 ± 0.007 0.974 ± 0.010 0.946 ± 0.022
online shoppers 0.873 ± 0.030 0.857 ± 0.014 0.853 ± 0.017 0.873 ± 0.021 0.846 ± 0.019
shrutime 0.774 ± 0.018 0.803 ± 0.022 0.803 ± 0.024 0.763 ± 0.018 0.750 ± 0.050
fabert 0.595 ± 0.048 0.530 ± 0.027 0.522 ± 0.024 0.580 ± 0.020 0.573 ± 0.026
blastchar 0.829 ± 0.011 0.823 ± 0.011 0.823 ± 0.011 0.832 ± 0.013 0.783 ± 0.017
philippine 0.733 ± 0.018 0.736 ± 0.018 0.739 ± 0.024 0.720 ± 0.020 0.729 ± 0.024
insurance co 0.616 ± 0.045 0.715 ± 0.038 0.680 ± 0.034 0.612 ± 0.024 0.630 ± 0.087
sylvine 0.961 ± 0.004 0.951 ± 0.009 0.950 ± 0.010 0.955 ± 0.009 0.957 ± 0.005
spambase 0.965 ± 0.008 0.962 ± 0.006 0.960 ± 0.008 0.964 ± 0.009 0.957 ± 0.013
jasmine 0.839 ± 0.013 0.824 ± 0.024 0.841 ± 0.016 0.842 ± 0.014 0.826 ± 0.013
seismicbumps 0.684 ± 0.071 0.723 ± 0.080 0.727 ± 0.081 0.673 ± 0.070 0.603 ± 0.023
qsar bio 0.892 ± 0.033 0.871 ± 0.036 0.876 ± 0.032 0.891 ± 0.018 0.855 ± 0.035

Table 14: AUC score for semi-supervised learning models on all datasets with 500 fine-tune data points. Values are the mean
over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.

Dataset N Datapoints N Features Positive Class% Best Model TabTransformer-RTD TabTransformer-MLM MLP (ER)
albert 425240 79 50.0 TabTransformer-RTD 0.711 ± 0.004 0.707 ± 0.006 0.666 ± 0.008
hcdr main 307511 120 8.1 TabTransformer-MLM 0.690 ± 0.038 0.698 ± 0.033 0.653 ± 0.019
dota2games 92650 117 52.7 TabTransformer-MLM 0.548 ± 0.008 0.557 ± 0.003 0.543 ± 0.008
jannis 83733 55 2.0 TabTransformer-RTD 0.747 ± 0.015 0.720 ± 0.018 0.707 ± 0.036
volkert 58310 181 12.7 TabTransformer-RTD 0.771 ± 0.016 0.760 ± 0.015 0.723 ± 0.016
bank marketing 45211 16 11.7 TabTransformer-RTD 0.879 ± 0.012 0.866 ± 0.016 0.869 ± 0.012
adult 34190 25 85.4 MLP (PL) 0.625 ± 0.011 0.647 ± 0.008 0.644 ± 0.015
1995 income 32561 14 24.1 MLP (DAE) 0.874 ± 0.008 0.880 ± 0.007 0.878 ± 0.002
htru2 17898 8 9.2 MLP (DAE) 0.964 ± 0.009 0.966 ± 0.009 0.973 ± 0.010
online shoppers 12330 17 15.5 MLP (ER) 0.859 ± 0.009 0.861 ± 0.014 0.888 ± 0.012
shrutime 10000 11 20.4 TabTransformer-RTD 0.831 ± 0.017 0.815 ± 0.004 0.793 ± 0.017
fabert 8237 801 11.3 MLP (ER) 0.618 ± 0.014 0.609 ± 0.019 0.621 ± 0.032
blastchar 7043 20 26.5 TabTransformer-RTD 0.840 ± 0.013 0.839 ± 0.015 0.829 ± 0.013
philippine 5832 309 50.0 TabTransformer-MLM 0.769 ± 0.028 0.772 ± 0.017 0.734 ± 0.024
insurance co 5822 85 6.0 TabTransformer (ER) 0.688 ± 0.039 0.642 ± 0.029 0.659 ± 0.023
sylvine 5124 20 50.0 MLP (PL) 0.955 ± 0.007 0.959 ± 0.006 0.967 ± 0.003
spambase 4601 57 39.4 MLP (ER) 0.966 ± 0.007 0.968 ± 0.008 0.975 ± 0.004
jasmine 2984 145 50.0 TabTransformer-RTD 0.847 ± 0.016 0.844 ± 0.011 0.837 ± 0.019
seismicbumps 2583 18 6.6 TabTransformer-RTD 0.758 ± 0.081 0.729 ± 0.069 0.682 ± 0.123
qsar bio 1055 41 33.7 MLP (DAE) 0.909 ± 0.024 0.889 ± 0.038 0.918 ± 0.023

Table 15: (Continued) AUC score for semi-supervised learning models on all datasets with 500 fine-tune data points. Values
are the mean over 5 cross-validation splits, plus or minus the standard deviation. Larger values mean better results.

Dataset MLP (PL) TabTransformer (ER) TabTransformer (PL) MLP (DAE) GBDT (PL)
albert 0.662 ± 0.007 0.664 ± 0.011 0.643 ± 0.029 0.666 ± 0.006 0.653 ± 0.011
hcdr main 0.645 ± 0.022 0.623 ± 0.036 0.636 ± 0.031 0.657 ± 0.033 0.607 ± 0.035
dota2games 0.544 ± 0.010 0.538 ± 0.009 0.541 ± 0.010 0.542 ± 0.012 0.505 ± 0.005
jannis 0.698 ± 0.033 0.662 ± 0.007 0.660 ± 0.024 0.693 ± 0.024 0.521 ± 0.045
volkert 0.722 ± 0.012 0.712 ± 0.016 0.705 ± 0.021 0.712 ± 0.016 0.705 ± 0.016
bank marketing 0.876 ± 0.017 0.863 ± 0.008 0.868 ± 0.016 0.874 ± 0.012 0.838 ± 0.019
adult 0.651 ± 0.012 0.618 ± 0.023 0.618 ± 0.021 0.654 ± 0.016 0.647 ± 0.030
1995 income 0.880 ± 0.003 0.868 ± 0.008 0.869 ± 0.007 0.882 ± 0.001 0.839 ± 0.013
htru2 0.974 ± 0.007 0.960 ± 0.010 0.960 ± 0.008 0.976 ± 0.006 0.949 ± 0.007
online shoppers 0.885 ± 0.021 0.861 ± 0.011 0.860 ± 0.013 0.885 ± 0.019 0.865 ± 0.011
shrutime 0.800 ± 0.015 0.825 ± 0.013 0.822 ± 0.016 0.804 ± 0.015 0.788 ± 0.019
fabert 0.596 ± 0.046 0.573 ± 0.048 0.578 ± 0.033 0.617 ± 0.042 0.585 ± 0.025
blastchar 0.833 ± 0.013 0.834 ± 0.013 0.832 ± 0.011 0.833 ± 0.012 0.795 ± 0.021
philippine 0.740 ± 0.023 0.746 ± 0.020 0.735 ± 0.015 0.739 ± 0.017 0.749 ± 0.026
insurance co 0.646 ± 0.048 0.710 ± 0.040 0.666 ± 0.060 0.612 ± 0.013 0.672 ± 0.037
sylvine 0.968 ± 0.003 0.958 ± 0.005 0.958 ± 0.003 0.967 ± 0.003 0.967 ± 0.006
spambase 0.973 ± 0.005 0.968 ± 0.007 0.967 ± 0.006 0.972 ± 0.006 0.972 ± 0.005
jasmine 0.833 ± 0.009 0.833 ± 0.021 0.838 ± 0.018 0.842 ± 0.011 0.838 ± 0.022
seismicbumps 0.677 ± 0.103 0.687 ± 0.100 0.735 ± 0.081 0.696 ± 0.112 0.666 ± 0.063
qsar bio 0.914 ± 0.032 0.894 ± 0.036 0.895 ± 0.035 0.925 ± 0.034 0.908 ± 0.024
Table 16: AUC score for supervised learning models on all datasets. Values are the mean over 5 cross-validation splits, plus or
minus the standard deviation. Larger values mean better results.

Dataset N Datapoints N Features Positive Class% Best Model Logistic Regression GBDT
albert 425240 79 50.0 GBDT 0.726 ± 0.001 0.763 ± 0.001
hcdr main 307511 120 8.1 GBDT 0.747 ± 0.004 0.756 ± 0.004
dota2games 92650 117 52.7 Logistic Regression 0.634 ± 0.003 0.621 ± 0.004
bank marketing 45211 16 11.7 TabTransformer 0.911 ± 0.005 0.933 ± 0.003
adult 34190 25 85.4 GBDT 0.721 ± 0.010 0.756 ± 0.011
1995 income 32561 14 24.1 TabTransformer 0.899 ± 0.002 0.906 ± 0.002
online shoppers 12330 17 15.5 GBDT 0.908 ± 0.015 0.930 ± 0.008
shrutime 10000 11 20.4 GBDT 0.828 ± 0.013 0.859 ± 0.009
blastchar 7043 20 26.5 GBDT 0.844 ± 0.010 0.847 ± 0.016
philippine 5832 309 50.0 TabTransformer 0.725 ± 0.022 0.812 ± 0.013
insurance co 5822 85 6.0 TabTransformer 0.736 ± 0.023 0.732 ± 0.022
spambase 4601 57 39.4 GBDT 0.947 ± 0.008 0.987 ± 0.005
jasmine 2984 145 50.0 GBDT 0.846 ± 0.017 0.862 ± 0.008
seismicbumps 2583 18 6.6 GBDT 0.749 ± 0.068 0.756 ± 0.084
qsar bio 1055 41 33.7 TabTransformer 0.847 ± 0.037 0.913 ± 0.031

Table 17: (Continued) AUC score for supervised learning models on all datasets. Values are the mean over 5 cross-validation
splits, plus or minus the standard deviation. Larger values mean better results.

Dataset MLP Sparse MLP TabTransformer TabNet VIB


albert 0.740 ± 0.001 0.741 ± 0.001 0.757 ± 0.002 0.705 ± 0.005 0.737 ± 0.001
hcdr main 0.743 ± 0.004 0.753 ± 0.004 0.751 ± 0.004 0.711 ± 0.006 0.745 ± 0.005
dota2games 0.631 ± 0.002 0.633 ± 0.004 0.633 ± 0.002 0.529 ± 0.025 0.628 ± 0.003
bank marketing 0.929 ± 0.003 0.926 ± 0.007 0.934 ± 0.004 0.885 ± 0.017 0.920 ± 0.005
adult 0.725 ± 0.010 0.740 ± 0.007 0.737 ± 0.009 0.663 ± 0.016 0.733 ± 0.009
1995 income 0.905 ± 0.003 0.904 ± 0.004 0.906 ± 0.003 0.875 ± 0.006 0.904 ± 0.003
online shoppers 0.919 ± 0.010 0.922 ± 0.011 0.927 ± 0.010 0.888 ± 0.020 0.907 ± 0.012
shrutime 0.846 ± 0.013 0.828 ± 0.007 0.856 ± 0.005 0.785 ± 0.024 0.833 ± 0.011
blastchar 0.839 ± 0.010 0.842 ± 0.015 0.835 ± 0.014 0.816 ± 0.014 0.842 ± 0.012
philippine 0.821 ± 0.020 0.764 ± 0.018 0.834 ± 0.018 0.721 ± 0.008 0.757 ± 0.018
insurance co 0.697 ± 0.027 0.705 ± 0.054 0.744 ± 0.009 0.630 ± 0.061 0.647 ± 0.028
spambase 0.984 ± 0.004 0.980 ± 0.009 0.985 ± 0.005 0.975 ± 0.008 0.983 ± 0.004
jasmine 0.851 ± 0.015 0.856 ± 0.013 0.853 ± 0.015 0.816 ± 0.017 0.847 ± 0.017
seismicbumps 0.735 ± 0.028 0.699 ± 0.074 0.751 ± 0.096 0.701 ± 0.051 0.681 ± 0.084
qsar bio 0.910 ± 0.037 0.916 ± 0.036 0.918 ± 0.038 0.860 ± 0.038 0.914 ± 0.028
