Trompt: Towards a Better Deep Neural Network for Tabular Data
Kuan-Yu Chen 1 Ping-Han Chiang 1 Hsin-Rung Chou 1 Ting-Wei Chen 1 Darby Tien-Hao Chang 1 2
Abstract
Tabular data is arguably one of the most com-
between different algorithms, a standard benchmark for tabular data was proposed by (Grinsztajn et al., 2022). The benchmark, denoted as Grinsztajn45 in this work, consists of 45 curated datasets from various domains.

In this paper, we propose a novel prompt-inspired architecture, Trompt, which abbreviates Tabular Prompt. Prompt learning has played an important role in the recent development of language models. For example, GPT-3 can handle a wide range of tasks well with appropriate prompt engineering (Radford et al., 2018; Brown et al., 2020). In Trompt, prompts are utilized to derive feature importances that vary across samples. Trompt consists of multiple Trompt Cells and a shared Trompt Downstream, as shown in Figure 2. Each Trompt Cell is responsible for feature extraction, while the Trompt Downstream is for prediction.

The performance of Trompt is evaluated on the Grinsztajn45 benchmark and compared with three deep learning models and five tree-based models. Figure 1 illustrates the overall evaluation results on Grinsztajn45. The x-axis is the number of hyperparameter search iterations and the y-axis is the normalized performance. In Figure 1, Trompt is consistently better than state-of-the-art deep learning models (SAINT and FT-Transformer) and the gap between deep learning models and tree-based models is narrowed.

Our key contributions are summarized as follows:

• The experiments are conducted on a recognized tabular benchmark, Grinsztajn45. Additionally, we add two well-performed tree-based models, LightGBM (Ke et al., 2017) and CatBoost (Prokhorenkova et al., 2018), to the baselines.

• Trompt achieves state-of-the-art performance among deep learning models and narrows the performance gap between deep learning models and tree-based models.

• Thorough empirical studies and ablation tests were conducted to verify the design of Trompt. The results further shed light on future research directions for the architecture design of tabular neural networks.

2. Related Work

In this section, we first discuss the prompt learning of language models. Secondly, we discuss two research branches of tabular neural networks, transformer-based models and inductive bias investigation. Lastly, we discuss the differences between Trompt and the related works and highlight the uniqueness of our work.

2.1. Prompt Learning

The purpose of prompt learning is to transform the input and output of downstream tasks to the original task used to build a pre-trained model. Unlike fine-tuning, which changes the task and usually involves updating model weights, a pre-trained model with prompts can dedicate itself to one task. With prompt learning, a small amount of data, or even a zero-shot setting, can achieve good results (Radford et al., 2018; Brown et al., 2020). The emergence of prompt learning substantially improves the application versatility of pre-trained models that are too large for common users to fine-tune.

To prompt a language model, one can insert a task-specific prompt before a sentence and hint the model to adjust its responses for different tasks (Brown et al., 2020). Prompts can either be discrete or soft. The former are composed of discrete tokens from the vocabulary of natural languages (Radford et al., 2018; Brown et al., 2020), while the latter are learned representations (Li & Liang, 2021; Lester et al., 2021).

2.2. Tabular Neural Network

Transformer. Self-attention has revolutionized NLP since 2017 (Vaswani et al., 2017) and was soon adopted by other domains such as computer vision, reinforcement learning and speech recognition (Dosovitskiy et al., 2020; Chen et al., 2021; Zhang et al., 2020). The intention of transformer blocks is to capture the relationships among features, which can be applied to tabular data as well.

TabTransformer (Huang et al., 2020) is the first transformer-based tabular neural network. However, TabTransformer only fed categorical features to transformer blocks and ignored the potential relationships among categorical and numerical features. FT-Transformer (Gorishniy et al., 2021) fixed this issue by feeding both categorical and numerical features to transformer blocks. SAINT (Somepalli et al., 2021) further improved FT-Transformer by applying attention not only on the feature dimension but also on the sample dimension.

Inductive Bias Investigation. Deep neural networks perform well on tasks with a clear inductive bias. For example, the Convolutional Neural Network (CNN) works well on images. The kernel of a CNN is designed to capture local patterns since neighboring pixels usually relate to each other (LeCun et al., 1995). The Recurrent Neural Network (RNN) is widely used in language understanding because the causal relationship among words is well encapsulated through recurrent units (Rumelhart et al., 1986). However, unlike other popular tasks, the inductive bias of tabular data has not been well discovered.
Figure 2. The overall architecture of Trompt. The input passes through L stacked Trompt Cells, each followed by a shared Trompt Downstream that produces a per-cell prediction and loss; the per-cell predictions are averaged into the final prediction and the per-cell losses are summed into the overall loss.
Given the fact that tree-based models have been the solid state of the art for tabular data (Borisov et al., 2021; Gorishniy et al., 2021; Shwartz-Ziv & Armon, 2022), Net-DNF (Katzir et al., 2020) and TabNet (Arik & Pfister, 2021) hypothesized that the inductive bias for tabular data might be the learning strategy of tree-based models. The strategy is to find the optimal root-to-leaf decision paths by selecting a portion of the features and deriving the optimal split from the selected features in non-leaf nodes. To emulate this learning strategy, TabNet utilized sequential attention and sparsity regularization. On the other hand, Net-DNF theoretically proved that a decision tree is equivalent to some disjunctive normal form (DNF) and proposed the disjunctive neural normal form to emulate a DNF formula.

2.3. The Uniqueness of Trompt

We argue that the column importances of tabular data are not invariant across samples and can be grouped into multiple modalities. Since prompt learning is born to adapt a model to multiple tasks, the concept is used in Trompt to handle multiple modalities. To this end, Trompt separates the learning strategy for tabular data into two parts. The first part, analogous to pre-trained models, focuses on learning the intrinsic column information of a table. The second part, analogous to prompts, focuses on diversifying the feature importances of different samples.

To the best of our understanding, Trompt is the first prompt-inspired tabular neural network. Compared to transformer-based models, Trompt learns separated column importances instead of focusing on the interactions among columns. Compared to TabNet and Net-DNF, Trompt handles multiple modalities by emulating prompt learning instead of the branch splits of decision trees.

3. Trompt

In this section, we elaborate on the architecture design of Trompt. As Figure 2 shows, Trompt consists of multiple Trompt Cells and a shared Trompt Downstream. Each Trompt Cell is responsible for feature extraction and providing diverse representations, while the Trompt Downstream is for prediction. The details of the Trompt Cell and the Trompt Downstream are discussed in Section 3.1 and Section 3.2, respectively. In Section 3.3, we further discuss the prompt learning of Trompt.

3.1. Trompt Cell

Figure 3 illustrates the architecture of a Trompt Cell, which can be divided into three parts. The first part derives feature importances (M_importance) based on column embeddings (E_column), the previous cell's output (O_prev) and prompt embeddings (E_prompt). The second part transforms the input into feature embeddings (E_feature) with two paths for categorical and numerical columns, respectively. The third part expands E_feature for the later multiplication.

The details of the first part are described in Section 3.1.1 and the details of the second and third parts are described in Section 3.1.2. Lastly, the generation of the output of a Trompt Cell is described in Section 3.1.3.

3.1.1. Derive Feature Importances

Let E_column ∈ R^{C×d} be the column embeddings and E_prompt ∈ R^{P×d} be the prompt embeddings. C is the number of columns of a table defined by the dataset, while P and d are hyperparameters for the number of prompts and the hidden dimension, respectively. Both E_column and E_prompt are input independent and trainable. Let O_prev ∈ R^{B×P×d} be the previous cell's output and B be the batch size.

O_prev is fused with the prompt embeddings as in Equations (1) and (2). Since E_prompt is input independent and lacks a batch dimension, E_prompt is expanded to SE_prompt through the stack operation as in Equation (1). Later, we concatenate SE_prompt and O_prev and then reduce the dimension of the concatenated tensor back to R^{B×P×d} for the final addition as in Equation (2).
Figure 3. The architecture of a Trompt Cell. Part 1 derives the feature importances; Part 2 constructs the feature embeddings, with an embedding path for categorical features and a dense path for numerical features; Part 3 expands the feature embeddings to accommodate multiple prompts.
For the same reason as E_prompt, E_column is expanded to SE_column as in Equation (3). Subsequently, the feature importances are derived through Equation (4), where ⊗ is the batch matrix multiplication, ⊺ is the batch transpose, and the softmax is applied to the column axis.

SE_prompt = stack(E_prompt) ∈ R^{B×P×d}    (1)

ŜE_prompt = dense(concat(SE_prompt, O_prev)) + SE_prompt + O_prev ∈ R^{B×P×d}    (2)

SE_column = stack(E_column) ∈ R^{B×C×d}    (3)

M_importance = softmax(ŜE_prompt ⊗ SE_column^⊺) ∈ R^{B×P×C}    (4)

The output of the first part is M_importance ∈ R^{B×P×C}, which accommodates the feature importances yielded by the P prompts. Notice that the column embeddings are not connected to the input and the prompt embeddings are fused with the previous cell's output. In Section 3.3, we further discuss these designs and their connections to the prompt learning of NLP.

3.1.2. Construct and Expand Feature Embeddings

In Trompt, categorical features are embedded through an embedding layer and numerical features are embedded through a dense layer, as in previous works (Somepalli et al., 2021; Gorishniy et al., 2021). The embedding construction procedure is illustrated in part two of Figure 3, where E_feature ∈ R^{B×C×d} is the feature embeddings of the batch.

The shapes of M_importance and E_feature are R^{B×P×C} and R^{B×C×d}, respectively. Since E_feature lacks the prompt dimension, Trompt expands E_feature into Ê_feature ∈ R^{B×P×C×d} to accommodate the P prompts by a dense layer in part three of Figure 3.

3.1.3. Generate Output

The output of a Trompt Cell is the column-wise sum of the element-wise multiplication of Ê_feature and M_importance as in Equation (5), where ⊙ is the element-wise multiplication. Notice that, during the element-wise multiplication, the shape of M_importance is considered R^{B×P×C×1}. In addition, since column is the third axis, the shape is reduced from R^{B×P×C×d} to R^{B×P×d} after the column-wise summation.

O = Σ_{i=1}^{C} (Ê_feature ⊙ M_importance)_{:,:,i,:} ∈ R^{B×P×d}    (5)

3.2. Trompt Downstream

A Trompt Downstream makes a prediction based on a Trompt Cell's output, which contains representations corresponding to the P prompt embeddings. To aggregate these representations, the weight for each prompt is first derived through a dense layer and a softmax activation function as in Equation (6). Afterwards, the weighted sum is calculated as in Equation (7).

The prediction is subsequently made through two dense layers as in Equation (8), where T is the target dimension.
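To make the tensor shapes concrete, the following is a minimal PyTorch sketch of a Trompt Cell implementing Equations (1) to (5). It is a simplified reading of the text rather than the official implementation: the normalization layers and activations shown in Figure 3 are omitted, and the per-column embedding layers and the prompt-expansion layer are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TromptCellSketch(nn.Module):
    """Sketch of one Trompt Cell; B = batch, C = columns, P = prompts, d = hidden dim."""

    def __init__(self, n_num, cat_cardinalities, P=128, d=128):
        super().__init__()
        C = n_num + len(cat_cardinalities)
        self.E_prompt = nn.Parameter(torch.randn(P, d))  # input-independent prompt embeddings
        self.E_column = nn.Parameter(torch.randn(C, d))  # input-independent column embeddings
        self.fuse = nn.Linear(2 * d, d)                  # reduces concat(SE_prompt, O_prev) back to d (Eq. 2)
        self.num_embed = nn.ModuleList([nn.Linear(1, d) for _ in range(n_num)])          # numerical path
        self.cat_embed = nn.ModuleList([nn.Embedding(c, d) for c in cat_cardinalities])  # categorical path
        self.expand = nn.Linear(1, P)                    # expands E_feature to the P prompts (Part 3)

    def forward(self, x_num, x_cat, O_prev):
        B = x_num.shape[0]
        # Eqs. (1) and (3): stack the trainable embeddings along the batch dimension.
        SE_prompt = self.E_prompt.expand(B, -1, -1)          # (B, P, d)
        SE_column = self.E_column.expand(B, -1, -1)          # (B, C, d)
        # Eq. (2): fuse the previous cell's output with the prompt embeddings.
        SE_prompt_hat = self.fuse(torch.cat([SE_prompt, O_prev], dim=-1)) + SE_prompt + O_prev
        # Eq. (4): prompt-to-column scores with softmax over the column axis.
        M_importance = F.softmax(SE_prompt_hat @ SE_column.transpose(1, 2), dim=-1)   # (B, P, C)
        # Part 2: per-column feature embeddings (numerical columns first, then categorical).
        num = torch.stack([f(x_num[:, i:i + 1]) for i, f in enumerate(self.num_embed)], dim=1)
        cat = torch.stack([f(x_cat[:, i]) for i, f in enumerate(self.cat_embed)], dim=1)
        E_feature = torch.cat([num, cat], dim=1)             # (B, C, d)
        # Part 3: expand the feature embeddings to accommodate the P prompts.
        E_feature_hat = self.expand(E_feature.unsqueeze(-1)).permute(0, 3, 1, 2)      # (B, P, C, d)
        # Eq. (5): importance-weighted, column-wise sum.
        O = (E_feature_hat * M_importance.unsqueeze(-1)).sum(dim=2)                   # (B, P, d)
        return O, M_importance
```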
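Equations (6) to (8) are not reproduced in this excerpt, so the sketch below follows only the prose of Section 3.2 and Figure 2: a shared downstream derives per-prompt weights with a dense layer and a softmax, aggregates the prompt representations with a weighted sum, and predicts through two dense layers; the L per-cell predictions are averaged at inference time and the per-cell losses are summed during training. The exact layer arrangement is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TromptDownstreamSketch(nn.Module):
    """Sketch of the Trompt Downstream; T is the target dimension."""

    def __init__(self, d=128, T=1):
        super().__init__()
        self.prompt_weight = nn.Linear(d, 1)                                     # Eq. (6): one weight per prompt
        self.head = nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, T))   # Eq. (8): two dense layers

    def forward(self, O):                                   # O: (B, P, d), the output of one Trompt Cell
        W_prompt = F.softmax(self.prompt_weight(O), dim=1)  # (B, P, 1), softmax over the prompt axis
        pooled = (W_prompt * O).sum(dim=1)                  # Eq. (7): weighted sum over prompts, (B, d)
        return self.head(pooled), W_prompt                  # per-cell prediction (B, T) and prompt weights


class TromptSketch(nn.Module):
    """L Trompt Cells followed by one shared Trompt Downstream (Figure 2)."""

    def __init__(self, cells, d=128, T=1):
        super().__init__()
        self.cells = nn.ModuleList(cells)
        self.downstream = TromptDownstreamSketch(d, T)      # shared across all cells

    def forward(self, x_num, x_cat):
        P, d = self.cells[0].E_prompt.shape
        O_prev = x_num.new_zeros(x_num.shape[0], P, d)      # zero tensor: there is no previous cell
        preds = []
        for cell in self.cells:
            O_prev, _ = cell(x_num, x_cat, O_prev)
            preds.append(self.downstream(O_prev)[0])
        # Training sums the per-cell losses; inference averages the per-cell predictions.
        return torch.stack(preds)                            # (L, B, T)
```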
Section 4.2. Third, ablation studies regarding the hyperparameters and the architecture of Trompt are presented in Section 4.3. Lastly, the interpretability of Trompt is investigated using synthetic and real-world datasets in Section 4.4.

4.1. Setup

The performance and ablation study of Trompt primarily focus on the Grinsztajn45 benchmark (Grinsztajn et al., 2022)¹. This benchmark comprises datasets from various domains and follows a unified methodology for evaluating different models, providing a fair and comprehensive assessment.

Furthermore, we evaluate the performance of Trompt on datasets selected by FT-Transformer and SAINT to compare it with state-of-the-art tabular neural networks.

For interpretability analysis, we follow the experimental settings of TabNet (Arik & Pfister, 2021). This involves using two synthetic datasets (Syn2 and Syn4) and a real-world dataset (mushroom) to visualize attention masks.

The settings of Grinsztajn45 are presented in Section 4.1.1 and the implementation details of Trompt are presented in Section 4.1.2. Furthermore, the settings of the datasets chosen by FT-Transformer and SAINT are provided in Appendix B.2 and Appendix B.3, respectively.

4.1.1. Settings of Grinsztajn45

To fairly evaluate the performance, we follow the configurations of Grinsztajn45, including the train-test data split, data preprocessing and evaluation metrics. Grinsztajn45 comprises two kinds of tasks, classification tasks and regression tasks. Please see Appendix A.1 and Appendix A.2 for the dataset selection criteria and the dataset normalization process of Grinsztajn45. The tasks are further grouped according to (i) the size of the datasets (medium-sized and large-sized) and (ii) the inclusion of categorical features (numerical only and heterogeneous).

In addition, we make the following adjustments: (i) models with incomplete experimental results in (Grinsztajn et al., 2022) are omitted, (ii) two well-performed tree-based models are added for comparison, and (iii) Trompt uses a hyperparameter search space smaller than its opponents. The details of the adjustments are described in Appendix A.3 and Appendix A.4.

4.1.2. Implementation Details

Trompt is implemented using PyTorch. The default hyperparameters are shown in Table 2. The size of the embeddings and the hidden dimension of the dense layers are configured as d. Note that only the sizes of the column and prompt embeddings must be the same by the architecture design. The hidden dimension of the dense layers is set to d to reduce hyperparameters and save computing resources. On the other hand, the number of prompts and the number of Trompt Cells are set to P and L. Please refer to Appendix F for the hyperparameter search spaces of all baselines and Trompt.

Table 2. Default hyperparameters of Trompt.

Hyperparameter                                                 Symbol  Value
Embedding size and hidden dimension (feature/prompt/column)    d       128
Number of prompts                                              P       128
Number of Trompt Cells (layers)                                L       6

¹ https://github.com/LeoGrin/tabular-benchmark
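As a purely illustrative usage of the sketches above, the Table 2 defaults (d = 128, P = 128, L = 6) could be wired up as follows; the table shape and batch size are made up for the example.

```python
import torch

# Hypothetical table: 5 numerical and 2 categorical columns (cardinalities 3 and 7).
n_num, cat_cards = 5, [3, 7]
cells = [TromptCellSketch(n_num, cat_cards, P=128, d=128) for _ in range(6)]  # L = 6
model = TromptSketch(cells, d=128, T=1)

x_num = torch.randn(32, n_num)                                                # batch of 32 samples
x_cat = torch.stack([torch.randint(0, c, (32,)) for c in cat_cards], dim=1)
per_cell_preds = model(x_num, x_cat)                                          # (6, 32, 1)
final_pred = per_cell_preds.mean(dim=0)                                       # averaged final prediction
```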
4.2. Evaluation Results

The results of the classification tasks are discussed in Section 4.2.1 and the results of the regression tasks are discussed in Section 4.2.2. The evaluation metrics are accuracy and r2-score for classification and regression tasks, respectively. In this section, we report an overall result and leave the results of individual datasets to Appendix B.1. In addition, the evaluation results on the datasets chosen by FT-Transformer and SAINT are provided in Appendix B.2 and Appendix B.3, respectively.

4.2.1. Classification

On the medium-sized classification tasks, Figure 5 shows that Trompt outperforms the other DNN models. The curve of Trompt is consistently above those of the deep neural networks (SAINT, FT-Transformer and ResNet) on tasks with and without categorical features. Additionally, Trompt narrows the gap between deep neural networks and tree-based models, especially on the tasks with heterogeneous features. In Figure 5b, Trompt appears to be a member of the leading cluster together with four tree-based models. GradientBoostingTree starts slow but catches up with the leading cluster at the end of the search. The other deep neural networks form a second cluster with a gap to the leading one.

On the large-sized classification tasks, tree-based models remain in the leading positions but the gap to deep neural networks is less pronounced. This echoes the observation that deep neural networks require more samples for training (LeCun et al., 2015). Figure 6a shows that Trompt outperforms all models on the task with numerical features only and Figure 6b shows that Trompt achieves a comparable performance to FT-Transformer on tasks with heterogeneous features.

With the small hyperparameter search space, the curve of Trompt is relatively flat. The flat curve also suggests that Trompt performs well with its default hyperparameters.
Figure 5. Benchmark on medium-sized classification datasets: (a) numerical features only; (b) heterogeneous features.

Figure 6. Benchmark on large-sized classification datasets.

Figure 7. Benchmark on medium-sized regression datasets: (a) numerical features only; (b) heterogeneous features.

Figure 8. Benchmark on large-sized regression datasets.
4.4. Interpretability

Besides outstanding performance, tree-based models are well-known for their interpretability. Here we explore whether Trompt can also provide concise feature importances that highlight salient features. To investigate this, we conduct experiments on both synthetic and real-world datasets, following the experimental design of TabNet (Arik & Pfister, 2021). To derive the feature importances of Trompt for each sample, M_importance ∈ R^{B×P×C} is reduced to M̂_importance ∈ R^{B×C} as in Equation (10), where the weight of M_importance is the W_prompt of Equation (6).

Notice that all Trompt Cells derive separate feature importances. We demonstrate the averaged results of all cells here and leave the results of each cell to Appendix E.1.

M̂_importance = Σ_{i=1}^{P} (W_prompt ⊙ M_importance)_{:,i,:} ∈ R^{B×C}    (10)
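A small sketch of the reduction in Equation (10), using random stand-in tensors with the shapes defined above; in practice W_prompt would come from Equation (6) and M_importance from Equation (4).

```python
import torch

B, P, C = 4, 128, 10                              # small stand-in shapes
M_importance = torch.rand(B, P, C)                # per-prompt importances from Equation (4)
W_prompt = torch.rand(B, P, 1).softmax(dim=1)     # prompt weights as produced by Equation (6)

# Equation (10): weight each prompt's importance vector and sum over the prompt axis.
M_hat = (W_prompt * M_importance).sum(dim=1)      # (B, C) per-sample feature importances
```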
Synthetic datasets. The Syn2 and Syn4 datasets are used to study the feature importances learned by each model (Chen et al., 2018). A model is trained on an oversampled training set (10k to 100k) using default hyperparameters and evaluated on 20 randomly picked testing samples. The configuration is identical to that in TabNet (Arik & Pfister, 2021).

Figure 10. Attention masks on the Syn4 dataset (synthetic): (a) important features; (b) feature importances of Trompt.

Real-world datasets. The mushroom dataset (Dua & Graff, 2017) is used as the real-world dataset for visualization, as in TabNet (Arik & Pfister, 2021). With only the Odor feature, most machine learning models can achieve > 95% test accuracy (Arik & Pfister, 2021). As a result, a high feature importance is expected on Odor.

Table 5 shows the three most important features of Trompt and five tree-based models. As shown, all models place Odor in their top three. The second and third places of Trompt, gill-size and gill-color, also appear in the top three of the other models. Actually, cap-color is selected only by XGBoost. If it is excluded, the union of the top important features of all models comes down to four features. The one Trompt missed is spore-print-color, which is in the fifth place of Trompt. Overall, the important features selected by Trompt are consistent with those of tree-based models, and can therefore be used in various analyses that are familiar in the field of machine learning.

To further demonstrate that the experimental results were not ad-hoc, we repeat the experiments on additional real-world datasets. Please see Appendix E.2 for the details and
experimental results.

5. Discussion

In this section, we further explore the "prompt" mechanism of Trompt. Section 5.1 clarifies the underlying hypothesis of how the prompt learning of Trompt fits tabular data. In addition, as Trompt is partially inspired by the learning strategy of tree-based models, we further discuss the differences between Trompt and tree-based models in Section 5.2.

5.1. Further exploration of the "prompt" mechanism in Trompt

The "prompt" mechanism in Trompt is realized as Equation (4). This equation involves a matrix multiplication of the expanded prompt embeddings (ŜE_prompt ∈ R^{B×P×d}) and the transpose of the expanded column embeddings (SE_column ∈ R^{B×C×d}). It results in M_importance ∈ R^{P×C}, which represents prompt-to-column feature importances. The matrix multiplication calculates the cosine-based distance between ŜE_prompt and SE_column, and favors high similarity between the sample-specific representations and the sample-invariant intrinsic properties.

To make it clearer, ŜE_prompt consists of P embeddings that are specific to individual samples, except for the first Trompt Cell, where O_prev is a zero tensor since there is no previous Trompt Cell, as stated in Equations (1) and (2). On the other hand, SE_column consists of C embeddings that represent intrinsic properties specific to a tabular dataset, as stated in Equation (3).

Unlike self-attention, which calculates the distance between queries and keys and derives token-to-token similarity measures, Trompt calculates the distance between ŜE_prompt and SE_column in Equation (4) to derive sample-to-intrinsic-property similarity measures. The underlying idea of the calculation is to capture the distance between each sample and the intrinsic properties of a tabular dataset, and we hypothesize that incorporating the intrinsic properties into the modeling of a tabular neural network might help making good predictions.
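The following toy snippet contrasts the two similarity maps on random tensors: self-attention scores relate tokens to tokens within a sample, whereas Equation (4) relates sample-specific prompt representations to the sample-invariant column embeddings. The shapes are arbitrary and chosen only to show the axes involved.

```python
import torch
import torch.nn.functional as F

B, C, P, d = 2, 10, 128, 16
tokens = torch.randn(B, C, d)          # per-sample token (column) representations
SE_prompt_hat = torch.randn(B, P, d)   # sample-specific prompt representations (Eq. 2)
SE_column = torch.randn(B, C, d)       # stacked sample-invariant column embeddings (Eq. 3)

self_attention_map = F.softmax(tokens @ tokens.transpose(1, 2), dim=-1)           # (B, C, C): token-to-token
prompt_column_map = F.softmax(SE_prompt_hat @ SE_column.transpose(1, 2), dim=-1)  # (B, P, C): Eq. (4)
```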
5.2. The Differences between Trompt and Tree-based Models

As discussed in Section 3.3, the idea of using prompt learning to derive feature importances is inspired by the learning algorithm of tree-based models and the intrinsic properties of tabular data. As a result, Trompt and tree-based models share a common characteristic in that they enable sample-dependent feature importances. However, there are two main differences between them. First, to incorporate the intrinsic properties of tabular data, Trompt uses column embeddings to share the column information across samples, while tree-based models learn column information through their node-split nature. Second, Trompt and tree-based models use different techniques to learn feature importances. Trompt derives feature importances explicitly through prompt learning, while tree-based models vary the feature importances implicitly along the root-to-leaf path.

6. Conclusion

In this study, we introduce Trompt, a novel network architecture for tabular data analysis. Trompt utilizes prompt learning to determine varying feature importances in individual samples. Our evaluation shows that Trompt outperforms state-of-the-art deep neural networks (SAINT and FT-Transformer) and closes the performance gap between deep neural networks and tree-based models.

The emergence of prompt learning in deep learning is promising. While the design of Trompt may not be intuitive or perfect for language model prompts, it demonstrates the potential of leveraging prompts in tabular data analysis. This work introduces a new strategy for deep neural networks to challenge tree-based models, and future research in this direction can explore more prompt-inspired architectures.
References

Chen, T., He, T., Benesty, M., Khotilovich, V., Tang, Y., Cho, H., Chen, K., et al. XGBoost: extreme gradient boosting. R package version 0.4-2, 1(4):1-4, 2015.

Cortez, P., Cerdeira, A., Almeida, F., Matos, T., and Reis, J. Modeling wine preferences by data mining from physicochemical properties. Decision Support Systems, 47(4):547-553, 2009.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.

Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.

Katzir, L., Elidan, G., and El-Yaniv, R. Net-DNF: Effective deep modeling of tabular data. In International Conference on Learning Representations, 2020.

Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., Ye, Q., and Liu, T.-Y. LightGBM: A highly efficient gradient boosting decision tree. Advances in Neural Information Processing Systems, 30, 2017.

LeCun, Y., Bengio, Y., et al. Convolutional networks for images, speech, and time series. The Handbook of Brain Theory and Neural Networks, 3361(10):1995, 1995.

LeCun, Y., Bengio, Y., and Hinton, G. Deep learning. Nature, 521(7553):436-444, 2015.

Lester, B., Al-Rfou, R., and Constant, N. The power of scale for parameter-efficient prompt tuning. arXiv preprint arXiv:2104.08691, 2021.

Li, X. L. and Liang, P. Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190, 2021.

Prokhorenkova, L., Gusev, G., Vorobev, A., Dorogush, A. V., and Gulin, A. CatBoost: unbiased boosting with categorical features. Advances in Neural Information Processing Systems, 31, 2018.

Radford, A., Narasimhan, K., Salimans, T., Sutskever, I., et al. Improving language understanding by generative pre-training. 2018.

Ramachandram, D. and Taylor, G. W. Deep multimodal learning: A survey on recent advances and trends. IEEE Signal Processing Magazine, 34(6):96-108, 2017.

Redmon, J., Divvala, S., Girshick, R., and Farhadi, A. You only look once: Unified, real-time object detection. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 779-788, 2016.

Sahoo, D., Pham, Q., Lu, J., and Hoi, S. C. Online deep learning: Learning deep neural networks on the fly. arXiv preprint arXiv:1711.03705, 2017.

Vanschoren, J., van Rijn, J. N., Bischl, B., and Torgo, L. OpenML: networked science in machine learning. SIGKDD Explorations, 15(2):49-60, 2013. doi: 10.1145/2641190.2641198. URL http://doi.acm.org/10.1145/2641190.2641198.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.

Zhang, Q., Lu, H., Sak, H., Tripathi, A., McDermott, E., Koo, S., and Kumar, S. Transformer transducer: A streamable speech recognition model with transformer encoders and RNN-T loss. In ICASSP 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 7829-7833. IEEE, 2020.
A. Settings of Grinsztajn45
In this section, we provide brief summaries with regard to dataset selection criteria in Appendix A.1, dataset normalization
in Appendix A.2, baseline models in Appendix A.3 and hyperparameter search mechanism in Appendix A.4.
B.1. Grinsztajn45
In the main paper, we discussed the overall performance of Trompt using the learning curves during hyperparameter
optimization. In this section, we present quantitative evaluation results of both default and optimized hyperparameters. In
addition, we provide the figures of individual datasets for reference.
The quantitative evaluation results of classification and regression tasks are discussed in Appendix B.1.1 and Appendix B.1.2
respectively. For classification datasets, we use accuracy as the evaluation metric. For regression datasets, we use r2-score
as the evaluation metric. As a result, in both categories, the higher the number, the better the result. Besides evaluation
metrics, the ranking of each model is also provided. To derive ranking, we calculate the mean and standard deviation of all
rankings on the datasets of a task. Notice that since the names of some datasets are long, we first assign each dataset a notation in Tables 6 to 8 and use these notations in the following tables.
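For illustration, the rankings could be derived as in the following sketch: models are ranked within each dataset by the metric (higher is better), and the mean and standard deviation of each model's ranks are reported. The scores and model names below are copied from Table 9 only as an example, and the exact tie-breaking used in the paper is not specified here.

```python
import pandas as pd

# Scores of three models on two datasets, copied from Table 9 as an example.
scores = pd.DataFrame(
    {"A1": [78.91, 78.56, 79.90], "A2": [78.59, 73.43, 74.22]},
    index=["Trompt", "FT-Transformer", "CatBoost"],
)
ranks = scores.rank(axis=0, ascending=False)   # rank the models within each dataset (1 = best)
summary = pd.DataFrame({"mean": ranks.mean(axis=1), "std": ranks.std(axis=1)})
print(summary)
```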
Notation Dataset
A1 KDDCup09 upselling
A2 compass
A3 covertype
A4 electricity
A5 eye movements
A6 rl
A7 road-safety
B1 Higgs
B2 MagicTelescope
B3 MiniBooNE
B4 bank-marketing
B5 california
B6 covertype
B7 credit
B8 electricity
B9 eye movements
B10 house 16H
B11 jannis
B12 kdd ipums la 97-small
B13 phoneme
B14 pol
B15 wine
Notation Dataset
A1 covertype
A2 road-safety
B1 covertype
B2 Higgs
B3 MiniBooNE
B4 jannis
C1 black friday
C2 diamonds
C3 nyc-taxi-green-dec-2016
C4 particulate-matter-ukair-2017
C5 SGEMM GPU kernel performance
D1 diamonds
D2 nyc-taxi-green-dec-2016
D3 year
B.1.1. Classification
The evaluation results for medium-sized classification tasks are presented in Table 9 for heterogeneous features, and in
Tables 10 and 11 for numerical features only.
For large-sized classification tasks, the results can be found in Table 12 for heterogeneous features, and in Table 13 for
numerical features only.
Furthermore, individual figures illustrating the performance of Trompt on medium-sized tasks are provided in Figure 11 for
heterogeneous features, and in Figure 12 for numerical features only. The individual figures for large-sized tasks can be
found in Figure 13 for heterogeneous features, and in Figure 14 for numerical features only.
The evaluation results consistently demonstrate that Trompt outperforms state-of-the-art deep neural networks (FT-
Transformer and SAINT) across all classification tasks (refer to Tables 9 to 13). Moreover, Trompt’s default rankings
consistently yield better performance than the searched rankings, indicating its strength in default configurations without
tuning. Remarkably, in a large-sized task with numerical features only, Trompt even surpasses tree-based models (refer to
Table 13).
Table 9. The performance of medium-sized classification task (heterogeneous features).
A1 A2 A3 A4 A5 A6 A7 Ranking
Default
Trompt (ours) 78.91% 78.59% 87.29% 84.50% 64.25% 75.13% 75.80% 3.71 ± 1.78
FT-Transformer 78.56% 73.43% 85.57% 82.71% 58.79% 71.52% 73.90% 6.29 ± 2.00
ResNet 74.24% 73.78% 82.49% 81.99% 57.14% 66.51% 73.45% 8.29 ± 2.92
SAINT 79.00% 70.09% 83.04% 82.42% 58.62% 67.69% 75.89% 6.86 ± 2.55
CatBoost 79.90% 74.22% 83.69% 85.01% 64.62% 75.29% 76.80% 2.93 ± 2.55
LightGBM 78.70% 73.63% 83.23% 86.37% 64.48% 77.04% 76.43% 3.86 ± 1.93
XGBoost 78.39% 74.46% 84.13% 87.86% 64.77% 78.42% 75.94% 3.00 ± 2.92
RandomForest 79.38% 79.28% 84.75% 86.24% 63.62% 73.82% 75.45% 3.71 ± 1.86
GradientBoostingTree 79.90% 72.01% 78.92% 82.94% 61.81% 69.60% 75.00% 6.36 ± 2.51
Searched
Trompt (ours) 79.00% 79.55% 88.29% 85.13% 64.29% 76.02% 76.38% 4.43 ± 2.20
FT-Transformer 78.00% 75.30% 86.64% 84.01% 59.85% 70.38% 76.86% 5.57 ± 2.19
ResNet 76.87% 74.35% 85.17% 82.68% 57.82% 69.59% 75.85% 8.43 ± 2.73
SAINT 77.80% 71.87% 84.95% 83.32% 58.54% 68.20% 76.43% 7.86 ± 2.64
CatBoost 80.50% 76.87% 87.48% 87.73% 66.48% 78.67% 77.16% 2.43 ± 2.71
LightGBM 79.81% 78.15% 86.62% 88.64% 66.14% 77.69% 76.43% 3.14 ± 2.05
XGBoost 79.69% 76.83% 86.25% 88.52% 66.57% 77.18% 76.69% 3.57 ± 1.93
RandomForest 79.38% 79.28% 85.89% 87.76% 65.70% 79.79% 75.88% 4.29 ± 2.27
GradientBoostingTree 80.01% 73.77% 85.55% 87.85% 63.30% 77.58% 76.23% 5.29 ± 2.17
Table 10. The performance of medium-sized classification task (numerical features only) (1).
B1 B2 B3 B4 B5 B6 B7 B8
Default
Trompt (ours) 69.26% 86.30% 93.82% 79.36% 89.09% 82.68% 75.84% 82.89%
FT-Transformer 66.94% 84.42% 92.80% 80.09% 87.40% 80.42% 74.32% 81.24%
ResNet 65.39% 85.11% 93.10% 78.68% 86.90% 79.09% 74.99% 80.91%
SAINT 69.29% 85.16% 93.18% 79.18% 87.69% 78.05% 76.49% 81.25%
CatBoost 71.30% 86.14% 93.64% 80.45% 90.21% 80.16% 76.95% 84.48%
LightGBM 70.79% 85.47% 93.16% 80.33% 90.06% 79.50% 77.17% 84.34%
XGBoost 69.25% 85.31% 93.29% 79.81% 90.30% 79.87% 75.91% 86.11%
RandomForest 70.12% 85.56% 92.09% 79.46% 88.80% 81.35% 76.64% 84.79%
GradientBoostingTree 70.49% 84.44% 92.16% 80.27% 88.00% 76.85% 77.52% 82.16%
Searched
Trompt (ours) 69.60% 86.35% 93.74% 79.30% 89.28% 83.73% 76.52% 83.12%
FT-Transformer 70.67% 85.26% 93.59% 80.22% 88.61% 81.22% 76.50% 81.94%
ResNet 69.02% 85.62% 93.69% 79.13% 87.28% 80.21% 76.28% 80.98%
SAINT 70.73% 84.85% 93.54% 79.29% 88.92% 80.27% 76.24% 81.84%
CatBoost 71.46% 85.92% 93.84% 80.39% 90.32% 82.98% 77.59% 86.33%
LightGBM 71.01% 85.70% 93.71% 80.15% 90.13% 81.81% 77.13% 85.94%
XGBoost 71.36% 86.05% 93.66% 80.34% 90.12% 81.75% 77.26% 86.94%
RandomForest 70.76% 85.41% 92.65% 79.82% 89.21% 82.73% 77.25% 86.14%
GradientBoostingTree 71.00% 85.57% 93.22% 80.26% 89.68% 81.72% 77.27% 86.24%
Table 11. The performance of medium-sized classification task (numerical features only) (2).
B9 B10 B11 B12 B13 B14 B15 Ranking
Default
Trompt (ours) 61.60% 88.05% 76.89% 86.61% 88.67% 98.49% 79.07% 4.07 ± 2.61
FT-Transformer 58.62% 87.16% 72.94% 87.16% 85.67% 98.08% 77.21% 6.93 ± 2.06
ResNet 56.06% 86.48% 70.70% 86.94% 85.37% 94.87% 77.06% 8.20 ± 2.05
SAINT 57.18% 88.19% 76.04% 88.32% 85.28% 97.04% 75.90% 6.20 ± 2.13
CatBoost 63.87% 88.59% 77.85% 87.98% 87.44% 98.46% 78.58% 2.47 ± 2.03
LightGBM 64.39% 88.43% 77.27% 87.43% 86.90% 98.38% 79.81% 3.27 ± 1.82
XGBoost 64.75% 88.16% 76.00% 87.31% 87.05% 98.35% 79.78% 4.13 ± 1.97
RandomForest 63.16% 87.92% 76.34% 88.32% 88.01% 98.10% 80.30% 3.93 ± 2.16
GradientBoostingTree 62.33% 87.68% 76.17% 88.32% 84.26% 96.71% 77.09% 5.80 ± 2.52
Searched
Trompt (ours) 62.71% 88.46% 76.99% 87.25% 88.67% 98.38% 78.58% 4.80 ± 2.47
FT-Transformer 58.30% 88.15% 76.43% 89.12% 85.66% 98.45% 76.74% 6.47 ± 2.41
ResNet 57.03% 87.54% 74.63% 88.23% 85.87% 94.86% 77.41% 7.73 ± 2.50
SAINT 58.90% 88.27% 77.22% 89.05% 85.39% 98.12% 76.87% 6.93 ± 2.22
CatBoost 65.07% 88.54% 77.95% 88.02% 88.83% 98.47% 79.89% 1.93 ± 2.36
LightGBM 65.43% 88.62% 77.70% 88.18% 87.60% 98.21% 79.55% 3.53 ± 1.44
XGBoost 65.83% 88.83% 77.83% 88.12% 86.81% 98.09% 79.46% 3.20 ± 2.22
RandomForest 65.04% 87.80% 77.27% 87.95% 88.45% 98.20% 78.96% 5.33 ± 1.88
GradientBoostingTree 63.04% 88.22% 77.17% 88.32% 86.68% 98.06% 78.56% 5.07 ± 1.77
Table 12. The performance of large-sized classification task (heterogeneous features).
A1 A2 Ranking
Default
Trompt (ours) 92.76% 78.36% 1.50 ± 4.36
FT-Transformer 93.17% 76.09% 4.50 ± 3.61
ResNet 89.45% 76.53% 6.00 ± 2.25
SAINT 91.23% 77.31% 4.50 ± 1.73
CatBoost 88.27% 78.21% 4.50 ± 1.73
LightGBM 84.76% 77.97% 6.00 ± 2.84
XGBoost 87.81% 78.22% 4.50 ± 2.65
RandomForest 90.66% 77.67% 4.50 ± 1.00
GradientBoostingTree 79.46% 75.19% 9.00 ± 4.62
Searched
Trompt (ours) 93.95% 78.44% 3.50 ± 3.40
FT-Transformer 93.61% 78.92% 3.50 ± 2.36
ResNet 92.27% 78.40% 8.00 ± 3.61
SAINT 92.54% 77.96% 8.50 ± 4.36
CatBoost 93.70% 80.15% 1.50 ± 4.36
LightGBM 93.25% 79.75% 4.00 ± 1.32
XGBoost 93.07% 79.91% 4.00 ± 2.18
RandomForest 93.30% 78.13% 6.00 ± 2.47
GradientBoostingTree 92.99% 78.59% 6.00 ± 1.76
Table 13. The performance of large-sized classification task (numerical features only).
B1 B2 B3 B4 Ranking
Default
Trompt (ours) 72.13% 94.68% 90.04% 79.54% 1.38 ± 3.44
FT-Transformer 69.60% 94.03% 89.83% 75.86% 6.00 ± 2.96
ResNet 69.88% 94.09% 88.01% 73.58% 6.00 ± 2.78
SAINT 71.81% 94.36% 86.94% 78.60% 3.75 ± 1.82
CatBoost 72.61% 94.32% 83.77% 79.54% 2.88 ± 3.01
LightGBM 72.12% 93.71% 80.71% 78.70% 5.00 ± 2.17
XGBoost 71.64% 93.67% 83.61% 78.28% 6.00 ± 1.50
RandomForest 71.58% 93.08% 87.67% 77.97% 6.00 ± 1.80
GradientBoostingTree 71.03% 92.25% 76.98% 77.18% 8.00 ± 3.29
Searched
Trompt (ours) 72.86% 94.36% 91.27% 79.88% 3.25 ± 2.97
FT-Transformer 72.86% 94.42% 90.57% 79.59% 3.25 ± 2.07
ResNet 72.29% 94.46% 89.36% 78.11% 6.75 ± 3.49
SAINT 72.65% 94.45% 89.53% 79.30% 5.50 ± 1.67
CatBoost 72.99% 94.55% 90.19% 79.89% 1.75 ± 3.49
LightGBM 72.55% 94.39% 89.71% 79.32% 6.00 ± 0.89
XGBoost 72.81% 94.40% 89.32% 79.67% 5.25 ± 2.30
RandomForest 71.98% 93.53% 90.59% 78.85% 7.00 ± 3.96
GradientBoostingTree 72.49% 94.07% 89.79% 79.34% 6.25 ± 1.95
Figure 11. Benchmark on every medium-sized classification dataset with heterogeneous features.
Figure 12. Benchmark on every medium-sized classification dataset with numerical features only.
Figure 13. Benchmark on every large-sized classification dataset with heterogeneous features.
Figure 14. Benchmark on every large-sized classification dataset with numerical features only.
B.1.2. Regression
The evaluation results for medium-sized regression datasets are presented in Tables 14 and 15 for heterogeneous features,
and in Tables 16 to 18 for numerical features only.
For large-sized regression datasets, the results can be found in Table 19 for heterogeneous features, and in Table 20 for
numerical features only.
Furthermore, individual figures illustrating the performance of Trompt on medium-sized regression tasks are provided in
Figure 15 for heterogeneous features, and in Figure 16 for numerical features only. The individual figures for large-sized
tasks can be found in Figure 17 for heterogeneous features, and in Figure 18 for numerical features only.
The evaluation results consistently demonstrate that Trompt outperforms state-of-the-art deep neural networks (SAINT and
FT-Transformer) on medium-sized regression tasks (refer to Tables 14 to 18). However, Trompt’s performance is slightly
inferior to other deep neural networks on large-sized datasets (refer to Tables 19 and 20). Nevertheless, it is worth noting
that the performance of Trompt remains consistently competitive when considering all benchmark results.
Table 14. The performance of medium-sized regression task (heterogeneous features) (1).
C1 C2 C3 C4 C5 C6 C7 C8
Default
Trompt (ours) 93.93% 99.63% 54.09% 8.71% 99.96% 94.70% 57.94% 98.88%
FT-Transformer 93.21% 88.00% 54.24% 0.00% 99.96% 93.99% 31.46% 98.84%
ResNet 89.90% 87.47% 51.99% 0.00% 99.72% 91.09% 10.79% 98.46%
SAINT 92.50% 99.20% 54.25% 11.23% 99.96% 95.10% 40.72% 98.47%
CatBoost 94.21% 99.59% 56.33% 15.16% 99.97% 98.01% 61.70% 99.11%
LightGBM 94.02% 99.38% 54.77% 14.41% 99.97% 98.23% 61.68% 99.01%
XGBoost 93.93% 99.76% 49.71% 6.64% 99.97% 97.59% 58.93% 98.96%
RandomForest 93.61% 99.30% 50.78% 13.16% 99.98% 98.00% 55.85% 98.79%
GradientBoostingTree 84.15% 99.62% 57.17% 15.30% 99.97% 98.27% 61.34% 98.42%
Searched
Trompt (ours) 94.50% 99.75% 56.87% 13.05% 99.96% 97.93% 60.17% 98.99%
FT-Transformer 93.58% 88.12% 54.90% 14.05% 99.97% 97.63% 37.93% 98.96%
ResNet 93.65% 87.83% 54.47% 12.95% 99.96% 97.83% 35.56% 98.79%
SAINT 93.89% 99.51% 55.14% 13.90% 99.97% 94.59% 58.72% 98.72%
CatBoost 94.87% 99.60% 57.74% 16.54% 99.97% 98.33% 61.79% 99.18%
LightGBM 94.37% 99.42% 55.58% 14.41% 99.97% 98.23% 61.68% 99.07%
XGBoost 94.62% 99.76% 56.87% 16.21% 99.98% 98.30% 61.88% 99.12%
RandomForest 93.79% 99.34% 57.55% 14.94% 99.98% 98.07% 60.91% 98.79%
GradientBoostingTree 94.07% 99.46% 57.53% 15.27% 99.98% 98.13% 61.54% 98.98%
Table 15. The performance of medium-sized regression task (heterogeneous features) (2).
Default
Trompt (ours) 89.02% 9.61% 64.94% 99.95% 0.64% 5.38 ± 2.02
FT-Transformer 87.38% 12.38% 65.43% 99.94% 0.00% 6.88 ± 1.74
ResNet 86.45% 0.00% 65.23% 98.70% 0.00% 8.35 ± 2.13
SAINT 88.01% 17.48% 64.80% 99.98% 0.00% 6.31 ± 1.54
CatBoost 89.75% 54.63% 69.16% 99.99% 4.97% 2.15 ± 2.13
LightGBM 89.05% 54.48% 68.74% 99.99% 4.91% 3.00 ± 1.74
XGBoost 88.34% 56.99% 66.16% 100.00% 0.00% 4.08 ± 2.46
RandomForest 87.44% 56.18% 65.44% 100.00% 5.92% 4.23 ± 2.27
GradientBoostingTree 86.93% 46.90% 67.17% 99.94% 0.00% 4.62 ± 2.92
Searched
Trompt (ours) 89.16% 48.04% 66.33% 99.99% 3.59% 5.77 ± 1.98
FT-Transformer 88.85% 50.44% 67.18% 99.90% 3.18% 7.23 ± 1.70
ResNet 88.10% 42.42% 65.50% 99.76% 2.11% 8.31 ± 2.08
SAINT 89.18% 36.42% 66.93% 99.99% 1.21% 7.00 ± 1.98
CatBoost 89.84% 56.79% 69.33% 100.00% 9.08% 2.00 ± 2.24
LightGBM 89.33% 54.48% 68.74% 100.00% 4.91% 4.31 ± 1.18
XGBoost 89.65% 57.82% 69.08% 100.00% 8.01% 2.15 ± 1.74
RandomForest 87.50% 58.48% 67.44% 100.00% 9.52% 4.31 ± 2.80
GradientBoostingTree 89.05% 57.29% 68.30% 100.00% 5.54% 3.92 ± 1.35
Table 16. The performance of medium-sized regression task (numerical features only) (1).
D1 D2 D3 D4 D5 D6 D7
Default
Trompt (ours) 84.80% 68.29% 99.70% 92.75% 81.17% 97.23% 94.15%
FT-Transformer 83.80% 66.92% 99.71% 91.87% 79.20% 96.85% 93.85%
ResNet 82.54% 64.52% 99.57% 91.41% 75.06% 96.75% nan
SAINT 0.00% 67.85% 99.39% 91.46% 82.04% 98.33% 94.24%
CatBoost 85.76% 69.93% 99.60% 93.56% 86.16% 98.56% 94.57%
LightGBM 84.68% 69.28% 99.38% 92.25% 84.33% 98.46% 94.49%
XGBoost 82.58% 67.93% 99.76% 92.03% 84.04% 98.25% 94.09%
RandomForest 83.71% 67.32% 99.29% 91.41% 81.54% 98.23% 93.96%
GradientBoostingTree 83.95% 67.58% 99.62% 89.42% 80.46% 98.34% 94.41%
Searched
Trompt (ours) 85.08% 68.57% 99.62% 92.80% 84.53% 98.61% 94.31%
FT-Transformer 83.90% 67.17% 99.77% 91.87% 83.00% 97.87% 94.34%
ResNet 83.21% 66.71% 99.69% 91.36% 82.03% 98.07% nan
SAINT 78.31% 68.44% 99.41% 92.10% 83.67% 98.39% 94.42%
CatBoost 85.92% 70.31% 99.62% 93.78% 86.90% 98.67% 94.59%
LightGBM 84.68% 69.28% 99.28% 93.33% 84.80% 98.31% 94.49%
XGBoost 84.58% 69.43% 99.76% 93.59% 85.64% 98.61% 94.55%
RandomForest 83.75% 68.69% 99.33% 92.42% 83.02% 98.28% 94.53%
GradientBoostingTree 84.25% 68.94% 99.60% 92.43% 84.48% 98.51% 94.47%
Table 17. The performance of medium-sized regression task (numerical features only) (2).
Default
Trompt (ours) 89.69% 62.96% 54.53% 88.04% 83.52% 97.88% 16.99%
FT-Transformer 91.01% 63.03% 48.90% 87.42% 81.10% 97.82% 5.86%
ResNet 88.77% 62.01% 47.62% 84.71% 75.92% 97.80% 22.34%
SAINT 87.30% 64.59% 50.30% 87.34% 81.59% 97.81% 46.65%
CatBoost 91.17% 66.18% 51.01% 88.73% 84.72% 97.82% 52.91%
LightGBM 88.59% 66.49% 51.95% 88.12% 83.51% 97.85% 53.06%
XGBoost 88.48% 64.75% 48.14% 87.43% 83.74% 97.73% 54.87%
RandomForest 83.37% 63.58% 51.12% 86.87% 82.99% 97.67% 54.54%
GradientBoostingTree 80.22% 66.31% 47.33% 86.16% 78.74% 97.94% 45.15%
Searched
Trompt (ours) 90.69% 65.13% 46.50% 88.27% 83.57% 97.92% 45.57%
FT-Transformer 91.37% 64.69% 48.67% 87.56% 83.05% 97.92% 47.43%
ResNet 90.82% 64.19% 48.16% 86.72% 82.08% 97.91% 46.78%
SAINT 92.27% 65.06% 49.40% 87.87% 82.03% 97.94% 49.58%
CatBoost 91.56% 66.39% 41.22% 88.89% 85.53% 97.93% 54.06%
LightGBM 88.59% 66.49% 51.60% 88.45% 85.33% 97.85% 53.06%
XGBoost 90.67% 66.79% 54.63% 88.76% 84.95% 97.87% 55.23%
RandomForest 83.82% 65.47% 49.15% 87.10% 82.77% 97.89% 56.04%
GradientBoostingTree 85.84% 66.32% 52.49% 88.32% 84.07% 97.94% 55.21%
Table 18. The performance of medium-sized regression task (numerical features only) (3).
Default
Trompt (ours) 95.13% 80.96% 87.91% 31.68% 18.41% 4.68 ± 2.29
FT-Transformer 94.16% 82.70% 88.01% 26.98% 0.00% 6.21 ± 2.29
ResNet 84.68% 74.54% 87.14% 26.86% 8.13% 8.06 ± 2.08
SAINT 99.04% 80.52% 89.22% 36.25% 25.92% 5.32 ± 1.86
CatBoost 98.63% 86.85% 90.51% 45.00% 27.34% 2.05 ± 2.16
LightGBM 98.70% 81.43% 89.79% 42.86% 25.50% 3.05 ± 1.89
XGBoost 98.50% 83.49% 89.55% 42.37% 16.33% 4.47 ± 2.04
RandomForest 98.67% 84.47% 90.20% 48.28% 20.69% 5.26 ± 2.40
GradientBoostingTree 93.49% 81.04% 85.62% 37.57% 24.21% 5.84 ± 2.60
Searched
Trompt (ours) 99.58% 84.15% 89.49% 40.91% 26.03% 5.11 ± 1.97
FT-Transformer 99.44% 84.26% 88.26% 36.07% 23.96% 6.37 ± 2.48
ResNet 94.99% 81.45% 89.22% 36.11% 21.73% 7.61 ± 2.31
SAINT 99.56% 78.81% 89.37% 37.38% 26.45% 5.79 ± 2.46
CatBoost 99.24% 86.84% 90.94% 50.11% 28.26% 2.26 ± 2.46
LightGBM 98.70% 81.31% 90.48% 42.86% 25.50% 4.84 ± 2.25
XGBoost 98.97% 86.03% 91.02% 50.06% 28.04% 2.79 ± 2.11
RandomForest 98.87% 85.64% 90.89% 50.43% 24.09% 5.58 ± 2.33
GradientBoostingTree 98.91% 81.31% 90.36% 45.55% 26.94% 4.58 ± 1.69
Table 19. The performance of large-sized regression task (heterogeneous features).
C1 C2 C3 C4 C5 Ranking
Default
Trompt (ours) 99.96% 60.97% 99.17% 40.35% 70.48% 5.20 ± 1.50
FT-Transformer 99.94% 35.14% 99.23% 40.61% 67.61% 5.80 ± 2.48
ResNet 98.95% 33.70% 98.16% 39.71% 66.60% 8.00 ± 2.86
SAINT 99.97% 38.91% 99.18% 54.80% 68.74% 4.80 ± 0.75
CatBoost 99.98% 63.32% 99.28% 60.50% 70.68% 1.80 ± 2.25
LightGBM 99.98% 63.24% 99.16% 57.69% 70.37% 3.60 ± 1.67
XGBoost 99.98% 63.45% 99.22% 62.44% 70.60% 1.60 ± 2.73
RandomForest − − − − − −
GradientBoostingTree 99.98% 61.65% 98.57% 48.09% 67.73% 5.20 ± 1.36
Searched
Trompt (ours) 99.98% 62.86% 99.18% 54.79% 70.73% 6.20 ± 2.26
FT-Transformer 99.98% 39.00% 99.26% 57.02% 70.45% 5.80 ± 1.75
ResNet 99.98% 39.38% 99.23% 54.30% 68.71% 6.40 ± 2.88
SAINT 99.98% 39.53% 99.26% 56.58% 69.73% 5.20 ± 1.55
CatBoost 99.98% 63.62% 99.33% 62.64% 71.17% 2.60 ± 2.25
LightGBM 99.98% 63.24% 99.24% 57.69% 70.99% 4.60 ± 1.86
XGBoost 99.98% 63.90% 99.32% 64.79% 71.22% 1.20 ± 2.80
RandomForest − − − − − −
GradientBoostingTree 99.98% 63.06% 99.18% 63.62% 70.58% 4.00 ± 2.07
Table 20. The performance of large-sized regression task (numerical features only).
D1 D2 D3 Ranking
Default
Trompt (ours) 94.58% 33.79% 24.98% 5.67 ± 1.41
FT-Transformer 94.52% 11.98% 11.72% 7.33 ± 3.07
ResNet 94.10% 24.69% 11.88% 7.33 ± 2.95
SAINT 94.45% 53.44% 28.87% 4.33 ± 2.06
CatBoost 94.76% 58.47% 30.20% 1.33 ± 3.37
LightGBM 94.75% 56.07% 28.10% 2.67 ± 2.22
XGBoost 94.74% 60.87% 25.12% 3.00 ± 2.22
RandomForest − − − −
GradientBoostingTree 94.59% 46.35% 25.74% 4.33 ± 0.48
Searched
Trompt (ours) 94.61% 52.42% 29.71% 7.33 ± 3.30
FT-Transformer 94.63% 53.82% 30.51% 5.67 ± 1.83
ResNet 94.64% 52.84% 28.01% 7.00 ± 2.63
SAINT 94.65% 54.94% 30.46% 5.00 ± 0.50
CatBoost 94.80% 59.97% 31.30% 2.00 ± 2.63
LightGBM 94.75% 56.07% 28.10% 4.67 ± 1.71
XGBoost 94.80% 62.36% 30.75% 1.33 ± 3.37
RandomForest − − − −
GradientBoostingTree 94.72% 61.72% 30.73% 3.00 ± 1.71
Figure 15. Benchmark on every medium-sized regression dataset with heterogeneous features.
Figure 16. Benchmark on every medium-sized regression dataset with numerical features only.
Figure 17. Benchmark on every large-sized regression dataset with heterogeneous features.
Figure 18. Benchmark on every large-sized regression dataset with numerical features only.
It’s important to note that due to limited computing resources, Trompt did not undergo hyperparameter search. Instead, we
obtained the performance of FT-Transformer from its original paper. In terms of the learning strategy, Trompt was trained
for 100 epochs, and the performance was evaluated using the checkpoint at the 100th epoch. This approach was adopted as
we observed that the datasets chosen by FT-Transformer are often large, making overfitting less likely.
As shown in Table 21, Trompt generally achieves comparable or slightly inferior performance when compared to the default
hyperparameter settings of FT-Transformer on the datasets specifically chosen by FT-Transformer. It is important to note
that the reported performance is an average result based on three random seeds.
Dataset Metric Trompt (ours) FT (Default) FT (Tune) #Parameters (Trompt) #Parameters (FT)
Table 23. Average r2-score of Trompt using different target normalizations on Grinsztajn45 regression tasks.
D.1. Hyperparameters
Ablations on the size of hidden dimension.
The hidden dimension (d) parameter in Trompt plays a crucial role in configuring various parts of the model, such as the size
of dense layers and embeddings. To evaluate the impact of different values of d, we conducted experiments using Trompt
with six different values of d.
The results presented in Table 24 demonstrate that Trompt achieves good performance when the hidden dimension is sufficiently large, particularly when d is larger than 32. This suggests that a larger hidden dimension allows Trompt to
capture and represent more complex patterns and relationships in the data, leading to improved performance.
D.2. Architecture
Ablations on whether the output of the previous Trompt Cell is connected to the current Trompt Cell.
The connection between the output of the previous Trompt Cell and the current Trompt Cell is crucial, as it allows for the
fusion of prompt embeddings with input-related representations. This fusion results in sample-wise feature importances,
providing valuable insights into the importance of each feature. Without this connection, the feature importances of each
Trompt Cell would become deterministic and lose their variability. As illustrated in Table 26, connecting the output of the
previous Trompt Cell yields improved performance in both regression and classification tasks.
Table 26. Performance with and without connecting the output of the previous Trompt Cell to the current Trompt Cell.
Table 27. Performance depending on whether the column embeddings are input independent.
Table 28. The top-3 importance score ratio on the red wine quality dataset.
Table 29. The top-3 importance score ratio on the white wine quality dataset.
RandomForest alcohol (24.22%) volatile acidity (12.44%) free sulfur dioxide (11.78%)
XGBoost alcohol (31.87%) free sulfur dioxide (11.38%) volatile acidity (10.05%)
LightGBM alcohol (24.02%) volatile acidity (12.47%) free sulfur dioxide (11.45%)
CatBoost alcohol (17.34%) volatile acidity (12.07%) free sulfur dioxide (11.47%)
GradientBoostingTree alcohol (27.84%) volatile acidity (13.59%) free sulfur dioxide (12.87%)
Trompt (ours) fixed acidity (10.91%) volatile acidity (10.47%) pH (10.37%)