
Team up GBDTs and DNNs: Advancing Efficient and Effective Tabular Prediction with Tree-hybrid MLPs

Jiahuan Yan, Zhejiang University, Hangzhou, China
Jintai Chen*, University of Illinois at Urbana-Champaign, Urbana, IL, USA
Qianxing Wang, Zhejiang University, Hangzhou, China
Danny Z. Chen, University of Notre Dame, Notre Dame, IN, USA
Jian Wu, Zhejiang University, Hangzhou, China

arXiv:2407.09790v1 [cs.LG] 13 Jul 2024
ABSTRACT

Tabular datasets play a crucial role in various applications. Thus, developing efficient, effective, and widely compatible prediction algorithms for tabular data is important. Currently, two prominent model types, Gradient Boosted Decision Trees (GBDTs) and Deep Neural Networks (DNNs), have demonstrated performance advantages on distinct tabular prediction tasks. However, selecting an effective model for a specific tabular dataset is challenging, often demanding time-consuming hyperparameter tuning. To address this model selection dilemma, this paper proposes a new framework that amalgamates the advantages of both GBDTs and DNNs, resulting in a DNN algorithm that is as efficient as GBDTs and is competitively effective regardless of dataset preferences for GBDTs or DNNs. Our idea is rooted in an observation that deep learning (DL) offers a larger parameter space that can represent a well-performing GBDT model, yet the current back-propagation optimizer struggles to efficiently discover such optimal functionality. On the other hand, during GBDT development, hard tree pruning, entropy-driven feature gate, and model ensemble have proved to be more adaptable to tabular data. By combining these key components, we present a Tree-hybrid simple MLP (T-MLP). In our framework, a tensorized, rapidly trained GBDT feature gate, a DNN architecture pruning approach, as well as a vanilla back-propagation optimizer collaboratively train a randomly initialized MLP model. Comprehensive experiments show that T-MLP is competitive with extensively tuned DNNs and GBDTs in their dominating tabular benchmarks (88 datasets) respectively, all achieved with compact model storage and significantly reduced training duration. The codes and full experiment results are available at https://github.com/jyansir/tmlp.

CCS CONCEPTS

• Computing methodologies → Machine learning; Supervised learning; Neural networks.

KEYWORDS

classification and regression, tabular data, green AI, AutoML

ACM Reference Format:
Jiahuan Yan, Jintai Chen, Qianxing Wang, Danny Z. Chen, and Jian Wu. 2024. Team up GBDTs and DNNs: Advancing Efficient and Effective Tabular Prediction with Tree-hybrid MLPs. In Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD '24), August 25–29, 2024, Barcelona, Spain. ACM, New York, NY, USA, 14 pages. https://doi.org/10.1145/3637528.3671964

* The corresponding author.

Permission to make digital or hard copies of all or part of this work for personal or classroom use is granted without fee provided that copies are not made or distributed for profit or commercial advantage and that copies bear this notice and the full citation on the first page. Copyrights for components of this work owned by others than the author(s) must be honored. Abstracting with credit is permitted. To copy otherwise, or republish, to post on servers or to redistribute to lists, requires prior specific permission and/or a fee. Request permissions from [email protected].
KDD '24, August 25–29, 2024, Barcelona, Spain
© 2024 Copyright held by the owner/author(s). Publication rights licensed to ACM.
ACM ISBN 979-8-4007-0490-1/24/08
https://doi.org/10.1145/3637528.3671964

1 INTRODUCTION

Tabular data are a ubiquitous and dominating data structure in various machine learning applications (e.g., click-through rate (CTR) prediction [17] and financial risk detection [3]). Current prevalent tabular prediction (i.e., classification and regression) models can be generally categorized into two main types: (1) Gradient Boosted Decision Trees (GBDTs) [16, 18, 31, 43], a kind of classical non-deep-learning approach that has been extensively verified as test-of-time solutions [7, 23, 55]; (2) Deep Neural Networks (DNNs), on which continuous endeavors apply deep learning (DL) techniques from computer vision (CV) and natural language processing (NLP) to develop tabular learning methods such as meticulous architecture engineering [2, 12, 22, 42, 62] and pre-training [48, 58, 69]. With recent developments of bespoke tabular DNNs, increasing studies reported their better comparability [13, 62] and even superiority [14, 48] to GBDTs, especially in complex data scenarios [45, 58], while classical thinking believes that GBDTs still completely surpass DNNs in typical tabular tasks [7, 23], both evaluated with different benchmarks and baselines, implying respective tabular data proficiency of these two model types.

For DNNs, their inherent high-dimensional feature spaces and smooth back-propagation optimization gain tremendous success on unstructured data [10, 44] and capability of mining subtle feature interactions [46, 49, 57, 62]. Besides, leveraging DNN's transferability, recent popular tabular Transformers can be further improved by costly pre-training [48, 58, 69], like their counterparts in NLP [10, 32, 68]. However, compared to the simple multi-layer perceptron (MLP) and GBDTs, Transformer architectures are more complicated and are prone to be over-parameterized, data-hungry, and increase processing latency, especially those recent language-model-based architectures [8, 67]. Thus, they typically under-perform on tabular datasets that are potentially small-sized [23].
Regarding GBDTs, they thrive on greedy feature selection, tree pruning, and efficient ensemble, yielding remarkable performances and efficiency on the majority of tabular prediction applications [47, 55]. Yet, they are usually hyperparameter-sensitive [43, 63] and not well-suited in extreme tabular scenarios, such as large-scale tables with intricate feature interactions [45]. Also, their inference latency increases markedly as the data scale grows [7].

Besides, both the GBDT and DNN frameworks achieve their respective state-of-the-art results with expensive training costs, since heavy hyperparameter search is required to achieve considerable performance. But this is carbon-unfriendly and is not compatible with computation-limited or real-time applications, while not enough proactive efforts on economical tabular prediction have been made.

To address the model selection dilemma, we comprehensively combine the advantages of both GBDTs and DNNs, and propose a new Tree-hybrid simple MLP (T-MLP), which is high-performing, efficient, and lightweight. Specifically, a single T-MLP is equipped with a GBDT feature gate to perform sample-specific feature selection in a greedy fashion, and GBDT-inspired pruned MLP architectures to process the selected salient features. The whole framework is optimized using back-propagation with these GBDTs' properties, and all the components make the system compact, overfit-resistant, and generalizable. Furthermore, model ensemble can be efficiently achieved by training multiple sparse MLPs (we uniformly use 3 MLPs here) in parallel with a shared gate and predicting in a bagging manner. Overall, T-MLP has the following appealing features. (1) Generalized data adaptability: Different from existing tabular prediction methods that suffer from the model selection dilemma, T-MLP is flexible enough to handle all datasets regardless of the framework preference (see Sec. 4.2 and Sec. 4.4). (2) Hyperparameter tuning free: T-MLP is able to produce competitive results with all the configurations pre-fixed, which is significantly time-saving, user-friendly, environmentally friendly, and widely practical in broader applications. (3) Lightweight storage: In T-MLP, the DNN part is purely composed of simple and highly sparse MLP architectures, yet is still able to be state-of-the-art competitive even with a one-block MLP. Table 1 presents the economical cost-performance trade-off of T-MLP compared to common DNNs; such cost-effectiveness becomes more profound as the data scale grows.

Table 1: Comparison of model cost-effectiveness on small and large datasets across popular tabular DNNs. F and N denote the amounts of features and samples, P is the parameter number, and T denotes the overhead of total training time against the proposed T-MLP. We reuse performances and parameter sizes of the best model configurations in the FT-Transformer benchmark. T is evaluated on an NVIDIA A100 PCIe 40GB (see Sec. 4.1). Based on the fixed architecture and training configurations, T-MLP achieves stable model size and cheap training duration cost regardless of the data scale.

Dataset: Adult (F=14, N=49K) | Year (F=90, N=515K)
Model: P(M) T ACC ↑ | P(M) T RMSE ↓
MLP 0.77 7.7× 0.852 | 1.16 15.9× 8.853
NODE 20.83 120.4× 0.858 | 7.55 206.0× 8.784
AutoInt 0.01 25.0× 0.859 | 0.08 101.9× 8.882
DCNv2 1.18 8.0× 0.853 | 11.32 29.9× 8.890
FT-T 3.82 19.6× 0.859 | 1.25 116.3× 8.855
T-MLP 0.73 1.0× 0.864 | 0.75 1.0× 8.768

In summary, our main contributions are as follows:

• We propose a new GBDT-DNN hybrid framework, T-MLP, which is a one-stop and economical solution for effective tabular data prediction regardless of framework preferences of specific datasets, offering a novel optimization paradigm for tabular model architectures.
• Multi-facet analysis on feature selection strategy, parameter sparsity, and decision boundary pattern is given for in-depth understanding of the T-MLP efficiency and superiority.
• Comprehensive experiments on 88 datasets from 4 benchmarks, covering DNN- and GBDT-favored ones, show that a single T-MLP is competitive with advanced or pre-trained DNNs, and T-MLP ensemble can even consistently outperform them and is competitive with extensively tuned state-of-the-art GBDTs, all achieved with a compact model size and significantly reduced training duration.
• We develop an open-source Python package with APIs of benchmark loading, uniform baseline invocation (DNNs, GBDTs, T-MLP), DNN pruning, and other advanced functions as a developmental tool for the tabular learning community.

2 RELATED WORK

2.1 Model Frameworks for Tabular Prediction

In the past two decades, classical non-deep-learning methods [25, 35, 65, 66] have been prevalent for tabular prediction applications, especially GBDTs [16, 18, 31, 43] due to their efficiency and robustness in typical tabular tasks [23]. Because of the universal success of DNNs on unstructured data and the development of computation devices, there is an increasing effort in applying DNNs to such tasks. The early tabular DNNs aimed to be comparable with GBDTs by emulating the ensemble tree frameworks (e.g., NODE [42], Net-DNF [30], and TabNet [2]), but they neglected the advantages of DNNs for automatic feature fusion and interaction. Hence, more recent attempts leveraged DNNs' superiority, as they transferred successful neural architectures (e.g., AutoInt [49], FT-Transformer [22]), proposed bespoke designs (e.g., T2G-Former [62]), or adopted pre-training (e.g., SAINT [48], TransTab [58]), reporting competitive or even surpassing results compared to conventionally dominating GBDTs in specific data scenarios [14, 48]. Contemporary surveys [7] demonstrated that GBDTs and DNNs are two prevailing types of frameworks in current tabular learning research.

2.2 Lightweight DNNs

Lightweight DNNs are an evergreen research topic in CV and NLP, which aim to maintain effective performance while promoting DNN compactness and efficiency. A recent trend is to substitute dominating backbones with pure simple MLPs, such as MLP-Mixer [53], gMLP [36], MAXIM [54], and other vision MLPs [11, 24, 51], achieving comparable or even superior results to their CNN or Transformer counterparts with reduced capacity or FLOPs. This pure-MLP trend is also arising in NLP [19] and other real-world applications [15]. Another lightweight scheme is model compression, where pruning is a predominant approach used to trim down large language models [56, 68] at various granularities [38, 50, 61]. In the tabular prediction field, there are a few pure-MLP studies, but all focusing on regularization [29] or numerical embedding [21] rather than the DNN architecture itself. Besides, model compression of tabular DNNs has not yet been explored. We introduce related techniques to make our T-MLP more compact and effective.
Figure 1: Our proposed T-MLP vs. existing tabular prediction approaches: GBDTs and DNNs. (a) GBDTs are classical non-deep-learning models for tabular prediction. (b) DNNs are emerging promising methods especially for large-scale, complex, cross-table scenarios. (c) T-MLP is a hybrid framework that integrates the strengths of both GBDTs and DNNs, accomplished via GBDT feature gate tensorization, MLP framework pruning, simple block ensemble, and end-to-end back-propagation. It yields competitive results on both DNN- and GBDT-favored datasets, with a rapid development process and compact model size.

3 TREE-HYBRID SIMPLE MLP

We first review some preliminaries of typical GBDTs' inference process and feature encoding techniques in current Transformer-based tabular DNNs. Next, we elaborate on the detailed designs of several key components of T-MLP, including the GBDT feature gate for sample-specific feature selection, the pure-MLP basic block, and GBDT-inspired fine-grained pruning for sparse MLPs. Finally, we provide a discussion of the T-MLP workflow.

3.1 Preliminaries

Problem Statement. Given a tabular dataset with input features X ∈ R^{N×F} and targets y ∈ R^N, the tabular prediction task is to find an optimal solution f: R^{N×F} → R^N that minimizes the empirical difference between the predictions ŷ and the targets y. In current practice, the common choice of f is either traditional GBDTs (e.g., XGBoost [16], CatBoost [43], LightGBM [31]) or tabular DNNs (e.g., TabNet [2], FT-Transformer [22], SAINT [48], T2G-Former [62]). A typical difference metric is the accuracy or AUC score for classification tasks, and the root mean squared error (RMSE) for regression.

Definition 3.1: GBDT Feature Frequency. Given a GBDT model with T decision trees (e.g., CART [35]), the GBDT feature frequency of a sample denotes the number of times each feature is accessed by this GBDT on the sample. Specifically, the process of GBDT inference on a sample x ∈ R^F provides T single decision tree predictions ŷ^(k) = CART^(k)(x), k ∈ {1, 2, ..., T}. For each decision tree prediction, there exists a sample-specific decision path from its root to one of the leaf nodes, forming a used feature list that includes features involved in this prediction action. We denote this accessed feature list of the k-th decision tree as a binary vector α^(k) ∈ {0, 1}^F, in which 0 indicates that the corresponding feature of this sample is not used by the k-th decision tree, and 1 indicates that it is accessed. Consequently, we can represent the GBDT feature frequency of the sample with the sum of the T decision trees' binary vectors, as:

\alpha = \sum_{k} \alpha^{(k)},

where α represents the exploitation level of each feature in the GBDT, suggesting the feature preference of the GBDT model on this sample.

Feature Tokenizer. Inspired by the classical language models (e.g., BERT [32]), recent dominating Transformer-based tabular models [22, 48, 62] adopted distributed feature representation [39] by embedding tabular values into vector spaces and treating the values as "unordered" word vectors. Such Transformer models use the feature tokenizer [22] to process tabular features as follows: Each tabular scalar value is mapped to a vector e ∈ R^d with a feature-specific linear projection, where d is the feature hidden dimension. For numerical (continuous) values, the projection weights are multiplied with the value magnitudes. Given F_1 numerical features and F_2 categorical features, the feature tokenizer outputs the feature embedding E ∈ R^{(1+F_1+F_2)×d} by stacking the projected features (and an extra [CLS] token embedding), i.e.,

E = \mathrm{stack}\big[e_{\mathrm{CLS}},\, e^{(1)}_{\mathrm{num}}, \ldots, e^{(F_1)}_{\mathrm{num}},\, e^{(1)}_{\mathrm{cat}}, \ldots, e^{(F_2)}_{\mathrm{cat}}\big].
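To make Definition 3.1 and the normalized frequency used by the GBDT feature gate (Sec. 3.2 below) concrete, here is a minimal sketch that counts, for one sample, which features appear on its root-to-leaf decision paths across the trees of a trained XGBoost model. It is only an illustration of the definition, not the authors' tensorized, Hummingbird-style implementation; the toy data and hyperparameters are assumptions.

```python
# Sketch of Definition 3.1 (per-sample GBDT feature frequency) with XGBoost.
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
F, T = 8, 50                                        # feature count, number of trees
X = rng.normal(size=(256, F))
y = (X[:, 0] + X[:, 3] > 0).astype(int)
booster = xgb.train({"max_depth": 3, "objective": "binary:logistic"},
                    xgb.DMatrix(X, label=y), num_boost_round=T)

# Map every child node to (parent node, split feature) so we can walk leaf -> root.
trees = booster.trees_to_dataframe()
parent = {}
for _, node in trees.iterrows():
    if node["Feature"] != "Leaf":
        parent[node["Yes"]] = (node["ID"], node["Feature"])
        parent[node["No"]] = (node["ID"], node["Feature"])

def gbdt_feature_frequency(sample):
    """alpha in Definition 3.1: how many trees access each feature for this sample."""
    leaves = booster.predict(xgb.DMatrix(sample.reshape(1, -1)),
                             pred_leaf=True).astype(int).ravel()  # one leaf id per tree
    alpha = np.zeros(F)
    for tree_idx, leaf in enumerate(leaves):
        used, node_id = set(), f"{tree_idx}-{leaf}"
        while node_id in parent:                    # climb the decision path to the root
            node_id, feat = parent[node_id]
            used.add(int(feat[1:]))                 # default feature names are 'f0', 'f1', ...
        for j in used:                              # alpha^(k) is binary per tree
            alpha[j] += 1
    return alpha

alpha = gbdt_feature_frequency(X[0])
alpha_hat = alpha / T                               # normalized access probabilities
alpha_bar = rng.binomial(1, alpha_hat)              # hard binary mask sampled for training
```

The per-tree set keeps each α^(k) binary, so the normalized frequency α̂ stays within [0, 1], which is what allows it to act as a sampling probability for the gate in the next subsection.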
3.2 GBDT Feature Gate

Early attempts of tabular DNNs tried to emulate behavioral patterns of GBDTs by ensembling neural networks to build differential tree models, such as the representative models NODE [42] and TabNet [2]. However, even realizing decision-tree-like hard feature selection or resorting to complicated Transformer architectures, they were still rapidly submerged in subsequent DNN studies that mainly focused on promotion from deep learning perspectives [13, 22, 62]. We seek to rethink this line of work and observe that they achieve hard feature selection with learnable continuous feature masks through DNNs' smooth back-propagation, which may be incompatible with the discrete nature of GBDTs, and hence restrict their potential.

To resolve this issue, we propose the GBDT Feature Gate (GFG), a GBDT-based feature selector tensorized with GBDT weights to faithfully replicate its feature selection behavior. Specifically, given a GFG initialized by a T-tree GBDT, the feature selection process on an F-feature sample x (Ê = GFG(x) ∈ R^{F×d}) is formulated as:

E = \mathrm{FeatureTokenizer}(x) \in \mathbb{R}^{F \times d},   (1)
\alpha = \mathrm{GBDTFeatureFrequency}(x) \in \mathbb{R}^{F},   (2)
\hat{\alpha} = \alpha / T \in \mathbb{R}^{F}, \quad \bar{\alpha} = \mathrm{BinarySampler}(\hat{\alpha}) \in \{0, 1\}^{F},   (3)
\hat{E}_{:,i} = \begin{cases} \bar{\alpha} \odot E_{:,i} & \text{if training} \\ \hat{\alpha} \odot E_{:,i} & \text{if inference} \end{cases}, \quad i \in \{1, 2, \ldots, d\}.   (4)

The extra [CLS] embedding is omitted in this subsection for notation brevity; in implementation, it is directly concatenated to the head of the gated Ê. In Eq. (3), α̂ is the normalized GBDT feature frequency that represents the access probabilities of each feature in the T-tree GBDT, and ᾱ is a binary feature mask sampled with the probabilities α̂. To incorporate the GBDT's feature preference into the DNN framework, in Eq. (4), we use sparse feature masks from real GBDT feature access probabilities to perform hard feature selection during training, and use the soft probabilities during inference for deterministic prediction. GFG assists in filtering out unnecessary features according to the GBDT's feature preference, ensuring an oracle selection behavior compared to previous differential tree models that learn feature masks with neural networks.

Since the original GBDT library (we uniformly use XGBoost in this work) has no APIs for efficiently fetching the sample-specific GBDT feature frequency in Eq. (2) and the used backend is incompatible with common DL libraries (e.g., PyTorch), to integrate the GFG module into the parallel DNN framework, we tensorize the behavior of Eq. (2). Technically, we are inspired by the principle of the Microsoft Hummingbird compiling tools¹ and extract routing matrices, a series of parameter matrices that contain information of each decision tree's node adjacency and threshold values, from the XGBoost model. Based on the extracted routing matrices, feature access frequency can be simply acquired through alternating tensor multiplication and comparison on input features x, and the submodule of Eq. (2) is initialized with these parameter matrices. In the actual implementation, we just rapidly train an XGBoost with the uniform default hyperparameters provided in [22] (regardless of its performance) to initialize and freeze the submodule of Eq. (2) during the T-MLP initialization step. Other trainable parameters are randomly initialized. Since there are a large number of decision trees to vote the feature preference in a GBDT model, slight hyperparameter modification will not change the overall feature preference trend, and a lightly-trained default XGBoost is always usable enough to guide greedy feature selection. To further speed up the processes in Eqs. (2)-(3), we cache the normalized feature frequency α̂ for each sample during the first-epoch computation, and reuse the cache in the subsequent model training or inference.

3.3 Pure MLP Basic Block

To explore the capability of a pure-MLP architecture and keep our tabular model compact, we take inspiration from vision MLPs. We observe that a key factor of their success is the attention-like interaction realized by linear projection and soft gating on features [11, 24, 53]. Thus, we employ the spatial gating unit (SGU) proposed in [36], and formulate a simplified pure-MLP block, as:

\hat{E}^{(l+1)} = \mathrm{SGU}(\mathrm{GELU}(\mathrm{LayerNorm}(\hat{E}^{(l)}) W_1)) W_2 + \hat{E}^{(l)},   (5)
\mathrm{SGU}(X) = W_3\, \mathrm{LayerNorm}(X_{:,:d'}) \odot X_{:,d':}.   (6)

The block is similar to a single feed-forward network (FFN) in the Transformer with an extra SGU (Eq. (6)) for feature-level interaction. The main parameters are located in two transformations, i.e., W_1 ∈ R^{d×2d'} and W_2 ∈ R^{d'×d} in Eq. (5), where d' corresponds to the FFN intermediate dimension size. In Eq. (6), X ∈ R^{F×2d'} denotes the input features of the SGU, and W_3 ∈ R^{F×F} is a feature-level transformation to emulate the attention operation. Since d ≈ d' ≫ F in most cases, the model size is determined by W_1 and W_2, and is comparable to the FFN size. All the bias vectors are omitted for notation brevity.

Analogous to vision data, we treat tabular features and feature embeddings as image pixels and channels. But completely different from vision MLPs, T-MLP is a hybrid framework tailored for economical tabular prediction that performs competitively against tabular Transformers and GBDTs with significantly reduced runtime costs. On most uncomplicated datasets, using only one basic block in T-MLP is enough. In comparison, previous vision MLP studies emphasized architecture engineering and often demanded dozens of blocks in order to be comparable to vision Transformers.

3.4 Sparsity with User-controllable Pruning

Inspired by the pre-pruning of GBDTs that controls model complexity and promotes generalization with user-defined hyperparameters (e.g., maximum tree depth, minimum samples per leaf), we design a similar mechanism for T-MLP by leveraging the predominant model compression approach, i.e., DNN pruning [26, 38, 50], which is widely used in NLP research to trim down over-parameterized language models while maintaining the original reliability [61]. Specifically, we introduce two fine-grained mask variables z_h ∈ {0, 1}^d and z_in ∈ {0, 1}^{d'} to mask parameters on the hidden dimension and the intermediate dimension, respectively. As in the previous FFN pruning of language models [59], the T-MLP pruning operation can be attained by simply applying the mask variables to the weight matrices, i.e., substituting W_1 and W_2 with diag(z_h)W_1 and diag(z_in)W_2 in Eq. (5). We use the classical l_0 regularization reparametrized with hard concrete distributions [37], and adopt a Lagrangian multiplier objective to achieve the controllable sparsity as in [61].

Although early attempts of tabular DNNs have considered sparse structures, for example, TabNet [2] and NODE [42] built learnable sparse feature masks, and more recently TabCaps [12] and T2G-Former [62] designed sparse feature interaction, there are two essential differences: (1) existing tabular DNNs only considered sparsity on the feature dimension, while T-MLP introduces sparsity on the input features (Sec. 3.2) and the hidden dimension (this subsection), which was ignored in previous tabular DNN prediction studies and is widely recognized as an over-parameterized facet in NLP practice [59, 61]; (2) learnable sparsity in existing tabular DNNs is completely coupled with and determined by prediction loss functions, while our introduced DNN pruning techniques determine the sparsity based on the user-defined sparsity rate (objective-independent), with the same controllable nature as GBDT pre-pruning.

¹ https://github.com/microsoft/hummingbird
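As a concrete reference for Eqs. (5)-(6) and the mask-based pruning above, here is a minimal PyTorch sketch of one simplified pure-MLP block with the spatial gating unit. The binary masks z_h and z_in are supplied as fixed buffers here, standing in for the hard-concrete l_0 machinery of [37, 61], whose reparameterization is omitted; the dimension choices are illustrative assumptions rather than the paper's exact configuration.

```python
import torch
import torch.nn as nn

class SGU(nn.Module):
    """Spatial gating unit, Eq. (6): gate one half of the channels with a
    feature-level (F x F) transform of the LayerNorm-ed other half."""
    def __init__(self, n_features: int, d_inter: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_inter)
        self.w3 = nn.Linear(n_features, n_features, bias=False)   # W3 in R^{F x F}

    def forward(self, x):                          # x: (batch, F, 2*d_inter)
        u, v = x.chunk(2, dim=-1)                  # each: (batch, F, d_inter)
        u = self.w3(self.norm(u).transpose(1, 2)).transpose(1, 2)  # mix along the F axis
        return u * v

class TMLPBlock(nn.Module):
    """One simplified pure-MLP block, Eq. (5), with pruning masks applied."""
    def __init__(self, n_features: int, d: int = 96, d_inter: int = 96):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.w1 = nn.Linear(d, 2 * d_inter, bias=False)            # W1 in R^{d x 2d'}
        self.w2 = nn.Linear(d_inter, d, bias=False)                # W2 in R^{d' x d}
        self.sgu = SGU(n_features, d_inter)
        # Binary masks over the hidden (z_h) and intermediate (z_in) dimensions;
        # the paper samples them from hard concrete distributions under an
        # l0/Lagrangian objective, here they are just fixed all-ones buffers.
        self.register_buffer("z_h", torch.ones(d))
        self.register_buffer("z_in", torch.ones(d_inter))

    def forward(self, e):                          # e: (batch, F, d)
        h = self.norm(e) * self.z_h                # same effect as diag(z_h) W1
        h = torch.nn.functional.gelu(self.w1(h))
        h = self.sgu(h) * self.z_in                # same effect as diag(z_in) W2
        return self.w2(h) + e                      # residual connection

block = TMLPBlock(n_features=9)                    # e.g., 8 features + [CLS]
out = block(torch.randn(4, 9, 96))                 # (batch=4, F=9, d=96)
```

With a single block of fixed width, the parameter budget is indeed dominated by W_1 and W_2, consistent with the compact model sizes reported in Table 1.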
In the main experiments (Sec. 4.2), we uniformly fix the target sparsity at 0.33 for T-MLP, i.e., only around 33% of the DNN parameters are retained after training. We further explore the relationship between model sparsity and performance in Sec. 4.3, and obtain a performance boost with suitable parameter pruning, even on T-MLP with one basic block. This implies pervasive over-parameterization in previous tabular DNN designs.

3.5 Overall Workflow and Efficient Ensemble

The overall T-MLP workflow is as follows: During the training stage, the input tabular features are embedded with the feature tokenizer and discretely selected by the sampled feature mask ᾱ in Eq. (3); then, they are processed by a single pruned basic block in Eq. (5), and the pruning parameter masks z_h and z_in are sampled with reparameterization on the l_0 regularization; the final prediction is made with the [CLS] token feature using a normal prediction head as in other tabular Transformers, as:

\hat{y} = \mathrm{FC}(\mathrm{ReLU}(\mathrm{LayerNorm}(\hat{E}^{(l)}_{[\mathrm{CLS}],:}))),

where FC denotes a fully connected layer. We use the cross entropy loss for classification and the mean squared error loss for regression as in previous tabular DNNs. The whole framework is optimized with back-propagation. After training, the parameter masks are directly applied to W_1 and W_2 by accordingly dropping the pruned hidden and intermediate dimensions. In the inference stage, the input features are softly selected by the normalized GBDT feature frequency α̂ in Eq. (3), and processed by the simplified basic block. Since the T-MLP architecture is compact and computation-friendly with low runtime cost, we further provide an efficient ensemble version by simultaneously training three branches with the shared GBDT feature gate from the same initialization point with three fixed learning rates. This produces three different sparse MLPs, inspired by the model soups ensemble method [60]. The final ensemble prediction is the average result of the three branches as in a bagging ensemble model. Since the ensemble learning process can be implemented by simultaneous training and inference with multi-processing programming (e.g., RandomForest [9]), the training duration is not tripled but determined by the slowest converging branch.

4 EXPERIMENTS

In this section, we first compare our T-MLP with advanced DNNs and classical GBDTs on their dominating benchmarks (including 88 datasets for different task types) and analyze from the perspective of cost-effectiveness. Next, we conduct ablation and comparison experiments with multi-facet analysis to evaluate the key designs that make T-MLP effective. Besides, we compare the optimized patterns of common DNNs, GBDTs, and T-MLP by visualizing their decision boundaries to further examine the superiority of T-MLP.

4.1 Experimental Setup

Datasets. We use four recent high-quality tabular benchmarks (FT-Transformer² (FT-T, 11 datasets) [22], T2G-Former³ (T2G, 12 datasets) [62], SAINT⁴ (26 datasets) [48], and Tabular Benchmark⁵ (TabBen, 39 datasets) [23]), considering their elaborated results on extensive baselines and datasets. The FT-T and T2G benchmarks are representative of large-scale tabular datasets, whose sizes vary from 10K to 1,000K, and include various DNN baselines. The SAINT benchmark is gathered from the OpenML repository⁶, and is dominated by the pre-trained DNN SAINT, containing balanced task types and diverse GBDTs. TabBen is based on "typical tabular data" settings that constrain dataset properties, e.g., the data scale (a maximum data volume of 10K) and the feature number (not high-dimension) [23], and the datasets are categorized into several types with combinations of task types and feature characteristics. Notably, on TabBen, GBDTs achieve an overwhelming victory, surpassing commonly-used DNNs. Each benchmark represents a specific framework preference. Since several benchmarks have adjusted dataset arrangements in their current repositories (e.g., some datasets were removed and some were added), to faithfully follow and reuse the results, we only retain the datasets reported in the published original papers. We provide detailed benchmark statistical information in Table 2 and discuss benchmark characteristics in Appendix A.

Figure 2: The winning rates of GBDTs and DNNs on three benchmarks, which represent the proportion of each framework achieving the best performance in the benchmarks. It exhibits varying framework preferences among the datasets used in different tabular prediction works.

Implementation Details. We implement our T-MLP model and Python package using PyTorch on Python 3.10. Since the reported baseline training durations on the original benchmarks are estimated under different runtime environments and using different evaluation codes, and do not consider hyperparameter tuning (HPT) budgets, for uniform comparison of training costs, we encapsulate the experimental baselines with the same sklearn-style APIs as T-MLP in our built package, and conduct all the experiments on an NVIDIA A100 PCIe 40GB.

² https://github.com/yandex-research/rtdl-revisiting-models/tree/main
³ https://github.com/jyansir/t2g-former/tree/master
⁴ https://github.com/somepago/saint/tree/main
⁵ https://github.com/LeoGrin/tabular-benchmark/tree/main
⁶ https://www.openml.org


Table 2: Dataset statistics on four experimental benchmarks. “# bin., # mul., and # reg.” are the amounts of binary classification,
multi-class classification, and regression datasets. “# small, # middle, # large, and # ex. large” represent the amounts of small
(𝑁 ≤ 3K), middle (3K < 𝑁 ≤ 10K), large (10K < 𝑁 ≤ 100K), and extremely large (𝑁 > 100K) datasets, where 𝑁 denotes the
training data size. “# wide and # ex. wide” are the amounts of wide (32 < 𝐹 ≤ 64) and extremely wide (𝐹 > 64) datasets, where 𝐹
is the feature amount. “bin. metric, mul. metric, and reg. metric” represent the evaluation metrics used for each task type in
the benchmarks. “R-Squared” score is the coefficient of determination.

# bin. # mul. # reg. # small # middle # large # ex. large # wide # ex. wide bin. metric mul. metric reg. metric
FT-T [22] 3 4 4 0 0 6 5 2 5 ACC ACC RMSE
T2G [62] 3 5 4 0 3 7 2 2 2 ACC ACC RMSE
SAINT [48] 9 7 10 10 3 12 1 6 9 AUC ACC RMSE
TabBen [23] 15 0 24 2 37 0 0 5 2 ACC N/A R-Squared

Table 3: Cost-effectiveness comparison on the FT-T benchmark. Classification datasets and regression datasets are evaluated
using the accuracy and RMSE metrics, respectively. “Rank” denotes the average values (standard deviations) of all the methods
across the datasets. “𝑇 ” represents the average overhead of the used training time against T-MLP, and “𝑇 ∗ ” compares only the
duration before achieving the best validation scores. All the training durations are estimated with the original hyperparameter
search settings. “𝑃” denotes the average parameter number of the best model configuration provided by the FT-T repository.
TabNet is not compared considering its different backend (Tensorflow) in the evaluation. The top performances are marked in
bold, and the second best ones are underlined (similar marks are used in the subsequent tables).

CA ↓ AD ↑ HE ↑ JA ↑ HI ↑ AL ↑ EP ↑ YE ↓ CO ↑ YA ↓ MI ↓ Rank 𝑇 𝑇∗ 𝑃(M)
TabNet 0.510 0.850 0.378 0.723 0.719 0.954 0.8896 8.909 0.957 0.823 0.751 9.0 (1.5) N/A N/A N/A
SNN 0.493 0.854 0.373 0.719 0.722 0.954 0.8975 8.895 0.961 0.761 0.751 7.8 (1.1) ×42.76 ×24.87 1.12
AutoInt 0.474 0.859 0.372 0.721 0.725 0.945 0.8949 8.882 0.934 0.768 0.750 7.4 (2.1) ×121.68 ×112.31 1.14
GrowNet 0.487 0.857 N/A N/A 0.722 N/A 0.8970 8.827 N/A 0.765 0.751 N/A N/A N/A N/A
MLP 0.499 0.852 0.383 0.719 0.723 0.954 0.8977 8.853 0.962 0.757 0.747 6.5 (1.7) ×27.41 ×28.46 0.55
DCNv2 0.484 0.853 0.385 0.716 0.723 0.955 0.8977 8.890 0.965 0.757 0.749 6.4 (1.8) ×31.15 ×40.65 4.17
NODE 0.464 0.858 0.359 0.727 0.726 0.918 0.8958 8.784 0.958 0.753 0.745 5.4 (3.2) ×386.54 ×353.38 16.59
ResNet 0.486 0.854 0.396 0.728 0.727 0.963 0.8969 8.846 0.964 0.757 0.748 4.5 (2.2) ×56.20 ×58.46 6.16
FT-T 0.459 0.859 0.391 0.732 0.720 0.960 0.8982 8.855 0.970 0.756 0.746 3.3 (2.4) ×117.35 ×97.49 2.12
T-MLP 0.447 0.864 0.386 0.728 0.729 0.956 0.8977 8.768 0.968 0.756 0.747 3.1 (0.9) ×1.00 ×1.00 0.79
T-MLP(3) 0.438 0.867 0.386 0.732 0.730 0.960 0.8978 8.732 0.969 0.755 0.745 1.7 (0.8) ×1.05 ×1.08 2.37

Table 4: Cost-effectiveness comparison on the T2G benchmark with similar notations as in Table 3. The baseline performances
and configurations are also reused from the T2G repository. According to the T2G paper, for the extremely large dataset Year,
FT-T and T2G use 50-iteration hyperparameter tuning (HPT), DANet-28 follows its default hyperparameters, and the other
baseline results are acquired with 100-iteration HPT.

GE ↑ CH ↑ EY ↑ CA ↓ HO ↓ AD ↑ OT ↑ HE ↑ JA ↑ HI ↑ FB ↓ YE ↓ Rank 𝑇 𝑇∗ 𝑃(M)
XGBoost 0.684 0.859 0.725 0.436 3.169 0.873 0.825 0.375 0.719 0.724 5.359 8.850 4.3 (3.1) ×32.78 ×42.88 N/A
MLP 0.586 0.858 0.611 0.499 3.173 0.854 0.810 0.384 0.720 0.720 5.943 8.849 8.3 (1.9) ×13.73 ×11.45 0.64
SNN 0.647 0.857 0.616 0.498 3.207 0.854 0.812 0.372 0.719 0.722 5.892 8.901 8.3 (1.5) ×22.74 ×12.54 0.82
TabNet 0.600 0.850 0.621 0.513 3.252 0.848 0.791 0.379 0.723 0.720 6.559 8.916 10.2 (2.4) N/A N/A N/A
DANet-28 0.616 0.851 0.605 0.524 3.236 0.850 0.810 0.355 0.707 0.715 6.167 8.914 10.6 (2.0) N/A N/A N/A
NODE 0.539 0.859 0.655 0.463 3.216 0.858 0.804 0.353 0.728 0.725 5.698 8.777 7.0 (3.0) ×329.79 ×288.21 16.95
AutoInt 0.583 0.855 0.611 0.472 3.147 0.857 0.801 0.373 0.721 0.725 5.852 8.862 8.1 (2.0) ×68.30 ×55.52 0.06
DCNv2 0.557 0.857 0.614 0.489 3.172 0.855 0.802 0.386 0.716 0.722 5.847 8.882 8.4 (2.0) ×24.40 ×21.63 2.30
FT-T 0.613 0.861 0.708 0.460 3.124 0.857 0.813 0.391 0.732 0.731 6.079 8.852 4.7 (2.6) ×64.68 ×50.90 2.22
T2G 0.656 0.863 0.782 0.455 3.138 0.860 0.819 0.391 0.737 0.734 5.701 8.851 3.1 (1.7) ×88.93 ×87.04 1.19
T-MLP 0.706 0.862 0.717 0.449 3.125 0.864 0.814 0.386 0.728 0.729 5.667 8.768 3.3 (0.9) ×1.00 ×1.00 0.72
T-MLP(3) 0.714 0.866 0.747 0.438 3.063 0.867 0.823 0.386 0.732 0.730 5.629 8.732 1.9 (0.8) ×1.09 ×1.11 2.16
All the hyperparameter spaces and iteration numbers of the baselines follow the settings in the original papers to emulate the tuning process of each baseline. For T-MLP, we use fixed hyperparameters as the model is trained only once. The XGBoost used for T-MLP's GBDT Feature Gate is in the default configuration as in [22]. In the experiments, each single T-MLP uses one basic block for most datasets if without special specification. We uniformly use a learning rate of 1e-4 for a single T-MLP and learning rates of 1e-4, 5e-4, and 1e-3 for the three branches in the T-MLP ensemble (group "T-MLP(3)"). We reuse the same data splits as in the original benchmarks. The baseline performances are inherited from the reported benchmark results, and the baseline capacities are calculated based on the best model configurations provided in the corresponding paper repositories. Detailed information of the runtime environment and hyperparameters is given in Appendix C.

Compared Methods. On the four benchmarks, we compare our T-MLP (the single-model and 3-model-ensemble versions) with: (1) known non-pre-trained DNNs: MLP, ResNet, SNN [33], GrowNet [4], TabNet [2], NODE [42], AutoInt [49], DCNv2 [57], TabTransformer [28], DANets [13], FT-Transformer (FT-T) [22], and T2G-Former (T2G) [62]; (2) pre-trained DNN: SAINT [48]; (3) GBDT models: XGBoost [16], CatBoost [43], LightGBM [31], GradientBoostingTree (GBT), HistGradientBoostingTree (HistGBT), and other traditional non-deep machine learning methods like RandomForest (RF) [9]. For other unmentioned baselines, please refer to Appendix B. In the experiment tables below, "T-MLP" denotes a single T-MLP and "T-MLP(3)" denotes the ensemble version with three branches.

Table 5: The average values (standard deviations) of all the method ranks on the SAINT benchmark of three task types. |D| is the dataset number in each group. Notably, all the baseline results are based on HPT, and SAINT variants need further training budgets on pre-training and data augmentation. More detailed results are given in the Appendix.

Task Type: Binclass (|D|=9) Multiclass (|D|=7) Regression (|D|=10)
RF 7.8 (3.3) 7.3 (2.2) 9.1 (4.2)
ExtraTrees 7.8 (3.8) 9.6 (1.9) 8.6 (3.5)
KNeighborsDist 13.7 (0.7) 11.6 (3.5) 12.9 (1.8)
KNeighborsUnif 14.4 (0.5) 12.4 (3.4) 14.0 (1.0)
LightGBM 5.7 (3.3) 3.9 (2.8) 6.5 (3.2)
XGBoost 4.2 (2.8) 6.7 (3.5) 7.3 (2.9)
CatBoost 3.9 (2.8) 7.2 (2.4) 5.6 (2.7)
MLP 10.7 (1.8) 10.1 (3.9) N/A
NeuralNetFastAI N/A N/A 11.9 (2.2)
TabNet 13.2 (2.0) 13.5 (1.1) 10.2 (4.5)
TabTransformer 10.8 (1.4) 10.0 (3.6) 10.0 (2.9)
SAINT-s 7.8 (2.4) 7.9 (6.1) 4.7 (3.8)
SAINT-i 7.2 (2.6) 7.1 (2.7) 5.9 (3.5)
SAINT 4.2 (2.7) 5.2 (2.2) 4.2 (2.3)
T-MLP 4.6 (2.8) 4.6 (3.0) 4.6 (3.3)
T-MLP(3) 3.9 (1.9) 2.9 (2.5) 5.0 (2.9)

4.2 Main Results and Analysis

In Table 3 to Table 6, the baseline results are based on heavy HPT, and are obtained from the respectively reported benchmarks. All the T-MLP results are based on default hyperparameters.

Comparison with Advanced DNNs. Tables 3 and 4 report detailed performances and runtime costs on the FT-T and T2G benchmarks for comparison of our T-MLP versions and bespoke tabular DNNs [22, 62]. The baseline results in these tables are based on 50 (for complicated models on large datasets, e.g., FT-Transformer on the Year dataset) or 100 (the other cases) iterations of HPT, except for special models (default NODE for the datasets with large class numbers and default DANets for all datasets). An overall trend that one may observe is that the single T-MLP is able to achieve competitive results as the state-of-the-art DNNs on each benchmark, and a simple ensemble of three T-MLPs (i.e., "T-MLP(3)") exhibits even better performances with significantly reduced training costs. Specifically, benefiting from fixed hyperparameters and simple structures, the single T-MLP achieves an obvious speedup and reduces training durations by orders of magnitude compared to the powerful DNNs, and is also more training-friendly than XGBoost, a representative GBDT that highly relies on heavy HPT. Besides, we observe only about a 10% training duration increase for the T-MLP ensemble since we adopt multiprocessing programming to simultaneously train the three T-MLPs (see Sec. 3.5), and thus the training time depends on the slowest converging sub-model. In the implementation details (Sec. 4.1), the single T-MLP uses the smallest learning rate of the three sub-models, and hence the convergence time of the T-MLP ensemble often approximates that of the single T-MLP. From the perspective of model storage, as expected, the size of the single T-MLP is comparable to the average level of naive MLPs across the datasets, and its size variation is stable (see Table 1), since the block number, hidden dimension size, and sparsity rate are all fixed. In Sec. 4.3, we will further analyze the impact of model sparsity and the theoretical complexity of the model parameters.

Comparison with Pre-trained DNNs. Table 5 reports the means and standard deviations of model ranks on the SAINT benchmark [48]. Surprisingly, we find that the simple pure MLP-based T-MLP outperforms the Transformer-based SAINT variants (SAINT-s and SAINT-i) and is comparable with SAINT on all the three task types. It is worth noting that SAINT and its variants adopt complicated inter-sample attention and self-supervised pre-training along with HPT on parameters of the training process. Moreover, the T-MLP ensemble even achieves stable results that are competitive to tuned GBDTs (i.e., XGBoost, CatBoost, LightGBM) and surpasses the pre-trained SAINT on classification tasks. Since the detailed HPT conditions (i.e., iteration times, HPT methods, parameter sampling distributions) are not reported, we do not estimate specific training costs.

Comparison with Extensively Tuned GBDTs. Table 6 compares T-MLP on the typically GBDT-dominating benchmark TabBen [23], on which GBDT frameworks completely outperform various DNNs across all types of datasets. Results of each baseline on TabBen are obtained with around 400 iterations of heavy HPT, almost representing the ultimate performances with unlimited computation resources and budgets. As expected, when extensively tuned XGBoost is available, the single T-MLP is eclipsed, but it is still competitive with the other ensemble tree models (i.e., RF, GBT, HistGBT) and superior to the compared DNNs.
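To illustrate the kind of uniform, sklearn-style encapsulation described in the implementation details (used so that training durations of very different baselines can be compared under one interface), here is a small, generic sketch. It is not the released package's actual wrapper; the timing logic is only an assumption about one reasonable way to record the training-time overhead ratios T reported in the tables.

```python
import time
from sklearn.base import BaseEstimator

class TimedEstimator(BaseEstimator):
    """Generic sklearn-style wrapper that records wall-clock training time,
    so heterogeneous baselines (GBDTs, DNNs, T-MLP) can be invoked uniformly."""

    def __init__(self, model):
        self.model = model                  # any object exposing fit/predict
        self.train_seconds_ = None

    def fit(self, X, y, **fit_kwargs):
        start = time.perf_counter()
        self.model.fit(X, y, **fit_kwargs)
        self.train_seconds_ = time.perf_counter() - start
        return self

    def predict(self, X):
        return self.model.predict(X)

# Hypothetical usage: the overhead ratio T of a baseline against a reference model.
# from xgboost import XGBClassifier
# baseline = TimedEstimator(XGBClassifier()).fit(X_train, y_train)
# reference = TimedEstimator(reference_model).fit(X_train, y_train)
# overhead_T = baseline.train_seconds_ / reference.train_seconds_
```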
Table 6: The average values (standard deviations) of all the method ranks on TabBen (four dataset types). "Num." and "Cat." denote numerical datasets (all features are numerical) and categorical datasets (some features are categorical), respectively. "Classif." and "Reg." denote classification and regression tasks. The "Num. Reg." group includes only results of regression on numerical datasets (similar notations are used for the others). |D| is the dataset number in each group. Baseline test results are obtained based on the best validation results during ∼400 iterations of HPT (according to the TabBen paper and repository). Detailed results are given in the Appendix.

Dataset Type: Num. Classif. (|D|=9) Num. Reg. (|D|=14) Cat. Classif. (|D|=6) Cat. Reg. (|D|=10)
MLP 8.4 (0.8) N/A N/A N/A
ResNet 6.9 (1.9) 6.5 (1.9) 7.8 (1.0) 7.7 (0.5)
FT-T 5.7 (1.9) 5.5 (2.3) 5.5 (2.2) 6.7 (1.1)
SAINT 6.9 (1.4) 5.5 (2.2) 8.0 (1.1) N/A
GBT 4.7 (2.0) 4.3 (1.7) 5.2 (2.3) 4.3 (1.1)
HistGBT N/A N/A 5.2 (2.3) 4.3 (1.3)
RF 4.6 (2.1) 4.8 (2.2) 4.0 (3.2) 5.8 (1.9)
XGBoost 2.6 (1.4) 2.4 (1.5) 2.8 (1.5) 2.1 (1.0)
T-MLP 3.2 (1.6) 4.3 (1.9) 3.5 (2.3) 3.6 (1.4)
T-MLP(3) 2.1 (1.4) 2.7 (1.5) 3.0 (1.3) 1.8 (0.7)

Table 7: Main ablation and comparison on classical tables in various task types and data scales. The top 4 rows: ablations on key designs in the T-MLP framework. The bottom 2 rows: results of T-MLP with a neural network feature gate (NN FG).

Dataset: CA (21K) ↓ AD (49K) ↑ HI (98K) ↑ YE (515K) ↓
T-MLP 0.4471 0.864 0.729 8.768
w/o sparsity 0.4503 0.857 0.726 8.887
w/o GBDT FG 0.4539 0.859 0.728 8.799
w/o both 0.4602 0.856 0.724 8.896
T-MLP (NN FG) 0.4559 0.852 0.718 8.925
w/o sparsity 0.4557 0.840 0.713 8.936

Further, we find that the T-MLP ensemble is able to be comparable to the ultimate XGBoost in all the four dataset types with similar rank stability, serving as a candidate alternative to a tuned XGBoost. More significantly, in the experiments, each T-MLP (or T-MLP ensemble) employs a tensorized XGBoost trained in the default configuration (see the implementation details in Sec. 4.1), and all the other hyperparameters are fixed; thus T-MLP and its ensemble have the potential for a higher performance ceiling through HPT or by selecting other GBDTs as the feature gate.

In summary, we empirically show the strong potential of our hybrid framework to achieve flexible and generalized data adaptability across various tabular preferences (tabular data preferring advanced DNNs, pre-training, or GBDTs). Based on the impressive economical performance-cost trade-off and friendly training process, T-MLP can serve as a promising tabular model framework for real-world applications, especially under limited computation budgets.

4.3 What Makes T-MLP Cost-effective?

Table 7 reports ablation and comparison experimental results of T-MLP on several classification and regression datasets (i.e., California Housing (CA) [40], Adult (AD) [34], Higgs (HI) [5], and Year (YE) [6]) in various data scales (given in parentheses).

Main Ablations. The top four rows in Table 7 report the impact of two key designs in a single T-MLP. An overall observation is that both the structure sparsity and the GBDT feature gate (FG) contribute to the performance enhancement of T-MLP. From the perspective of data processing, GBDT FG brings local sparsity through sample-specific feature selection, and the sparse MLP structure offers global sparsity shared by all samples. Interestingly, we find that the impact of GBDT FG is more profound on the CA dataset. A possible explanation is that the feature amount of CA (8 features) is relatively small compared to the others (14, 28, and 90 features in AD, HI, and YE, respectively) and the average feature importance may be relatively large; thus, the CA results are more likely to be affected by feature selection. For the datasets with larger feature amounts, selecting effective features is likely to be more difficult.

Greedy Feature Selection. We notice a recent attempt on sample-specific sparsity for biomedical tables using a gating network; it was originally designed for low-sample-size tabular settings and helped prediction interpretability in the biomedical domain [64]. We use its code and build a T-MLP version by substituting GBDT FG with the neural network feature gate (NN FG) for comparison. The bottom two rows of Table 7 report the results. As expected, on the smallest dataset CA, NN FG can boost performance by learning to select informative features, but such a feature gating strategy consistently hurts the performance as the data scale increases. This may be because (1) large datasets demand more complicated structures to learn the meticulous feature selection, (2) the discrete nature of the selection behavior is incompatible with the smooth optimization patterns of neural networks, and (3) DNNs' confirmation bias [52] may mislead the learning process, i.e., NN FG will be ill-informed once the subsequent neural network captures wrong patterns. In contrast, GBDT FG always selects features greedily as real GBDTs do, which is conservative and generally reasonable. Besides, the complicated sub-tree structures are more complete for the selection action.

Sparsity Promotes Tabular DNNs. Fig. 3 plots performance variations on two classification/regression tasks with respect to T-MLP sparsity. Different from the pruning techniques in NLP that aim to trim down model sizes while maintaining the ability of the original models, we find that suitable model sparsity often promotes tabular prediction, but both excessive and insufficient sparsity cannot achieve the best results. The results empirically indicate that, compared to DNN pruning in large pre-trained models for unstructured data, in the tabular data domain pruning has the capability to promote non-large tabular DNNs, just as GBDTs benefit from the sparse structures achieved by tree pre-pruning, and that the hidden dimension in tabular DNNs is commonly over-parameterized.

4.4 Superiority Interpretability of T-MLP

In Fig. 4, we visualize the decision boundaries of FT-Transformer, XGBoost, and the single T-MLP to inspect the data patterns captured by these three methods. The two most important features are selected by mutual information (estimated with the Scikit-learn package).
Different from common DNNs and GBDTs, T-MLP exhibits a novel intermediate pattern that combines characteristics from both DNNs and GBDTs. Compared to DNNs, T-MLP yields grid-like boundaries whose edges are often orthogonal to feature surfaces, as with GBDTs, and the complexity is essentially simplified with pruned sparse architectures. Besides, T-MLP is able to capture tree-model-like sub-patterns (see T-MLP on Credit-g), while DNNs manage only main patterns. Hence, DNNs are overfit-sensitive due to their relatively irregular boundaries and neglect of fine-grained sub-patterns. Compared to GBDTs with jagged boundaries and excessively split sub-patterns, T-MLP holds very smooth vertices at the intersections of boundaries (see T-MLP on Credit-g). Notably, T-MLP can decide conditional split points like GBDT feature splitting (orthogonal edges at feature surfaces) through a smooth process (see the T-MLP boundary edges on Bioresponse, from top to bottom, in which the split point on the horizontal feature is conditionally changed with respect to the vertical feature in a smooth manner, while it is hard for XGBoost to attain such dynamical split points). Overall, T-MLP possesses both sets of advantages and is overfit-resistant, which helps provide its superiority on both GBDT- and DNN-favored datasets.

Figure 3: Performance variation plots on the Adult and Year datasets with respect to variations of T-MLP sparsity. All the best results are achieved with suitable sparsity.

Figure 4: Decision boundary visualization of FT-Transformer (FT-T), XGBoost, and a single-block T-MLP on the Bioresponse and Credit-g datasets, using the two most important features. Different colors represent distinct categories, while the varying shades of colors indicate the predicted probabilities. (Panel accuracies: Credit-g — FT-T 77%, XGBoost 76%, T-MLP 80%; Bioresponse — FT-T 82%, XGBoost 84%, T-MLP 85%.)

5 CONCLUSIONS

In this paper, we proposed T-MLP, a novel hybrid framework attaining the advantages of both GBDTs and DNNs to address the model selection dilemma in tabular prediction tasks. We combined a tensorized GBDT feature gate, DNN pruning techniques, and a vanilla back-propagation optimizer to develop a simple yet efficient and widely effective MLP model. Experiments on diverse benchmarks showed that, with significantly reduced runtime costs, T-MLP has the generalized adaptability to achieve considerably competitive results regardless of dataset-specific framework preferences. We expect that our T-MLP will serve as a practical method for economical tabular prediction as well as in broad applications, and help advance research on hybrid tabular models.

ACKNOWLEDGMENTS

This research was partially supported by the National Natural Science Foundation of China under grant No. 62176231, and the Zhejiang Key R&D Program of China under grants No. 2023C03053 and No. 2024SSYS0026.

REFERENCES

[1] Naomi S Altman. 1992. An introduction to kernel and nearest-neighbor nonparametric regression. The American Statistician 46, 3 (1992), 175–185.
[2] Sercan Ö Arik and Tomas Pfister. 2021. TabNet: Attentive interpretable tabular learning. In AAAI. 6679–6687.
[3] Saqib Aziz, Michael Dowling, Helmi Hammami, and Anke Piepenbrink. 2022. Machine learning in finance: A topic modeling approach. European Financial Management (2022).
[4] Sarkhan Badirli, Xuanqing Liu, Zhengming Xing, Avradeep Bhowmik, Khoa Doan, and Sathiya S Keerthi. 2020. Gradient boosting neural networks: GrowNet. arXiv preprint arXiv:2002.07971 (2020).
[5] Pierre Baldi, Peter Sadowski, et al. 2014. Searching for exotic particles in high-energy physics with deep learning. Nature Communications 5, 1 (2014), 4308.
[6] Thierry Bertin-Mahieux, Daniel PW Ellis, Brian Whitman, and Paul Lamere. 2011. The million song dataset. In ISMIR.
[7] Vadim Borisov, Tobias Leemann, Kathrin Seßler, Johannes Haug, Martin Pawelczyk, and Gjergji Kasneci. 2022. Deep neural networks and tabular data: A survey. IEEE Transactions on Neural Networks and Learning Systems (2022).
[8] Vadim Borisov, Kathrin Sessler, Tobias Leemann, Martin Pawelczyk, and Gjergji Kasneci. 2022. Language Models are Realistic Tabular Data Generators. In ICLR.
[9] Leo Breiman. 2001. Random forests. Machine Learning 45 (2001), 5–32.
[10] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. 2020. Language models are few-shot learners. In NeurIPS, Vol. 33. 1877–1901.
[11] Guiping Cao, Shengda Luo, Wenjian Huang, Xiangyuan Lan, Dongmei Jiang, Yaowei Wang, and Jianguo Zhang. 2023. Strip-MLP: Efficient Token Interaction for Vision MLP. In ICCV. 1494–1504.
[12] Jintai Chen, KuanLun Liao, Yanwen Fang, Danny Chen, and Jian Wu. 2022. TabCaps: A Capsule Neural Network for Tabular Data Classification with BoW Routing. In ICLR.
[13] Jintai Chen, Kuanlun Liao, Yao Wan, Danny Z Chen, and Jian Wu. 2022. DANets: Deep abstract networks for tabular data classification and regression. In AAAI.
[14] Jintai Chen, Jiahuan Yan, Danny Ziyi Chen, and Jian Wu. 2023. ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data. arXiv preprint arXiv:2301.02819 (2023).
[15] Si-An Chen, Chun-Liang Li, Nate Yoder, Sercan O Arik, and Tomas Pfister. 2023. TSMixer: An All-MLP architecture for time series forecasting. arXiv preprint arXiv:2303.06053 (2023).
[16] Tianqi Chen and Carlos Guestrin. 2016. XGBoost: A scalable tree boosting system. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. 785–794.
[17] Paul Covington, Jay Adams, and Emre Sargin. 2016. Deep neural networks for YouTube recommendations. In Proceedings of the 10th ACM Conference on Recommender Systems.
[10] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. 2020. Language models are few-shot learners. In NeurIPS, Vol. 33. 1877–1901.
[11] Guiping Cao, Shengda Luo, Wenjian Huang, Xiangyuan Lan, Dongmei Jiang, Yaowei Wang, and Jianguo Zhang. 2023. Strip-MLP: Efficient Token Interaction for Vision MLP. In ICCV. 1494–1504.
[12] Jintai Chen, KuanLun Liao, Yanwen Fang, Danny Chen, and Jian Wu. 2022. TabCaps: A Capsule Neural Network for Tabular Data Classification with BoW Routing. In ICLR.
[13] Jintai Chen, Kuanlun Liao, Yao Wan, Danny Z Chen, and Jian Wu. 2022. DANets: Deep abstract networks for tabular data classification and regression. In AAAI.
[14] Jintai Chen, Jiahuan Yan, Danny Ziyi Chen, and Jian Wu. 2023. ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data. arXiv preprint arXiv:2301.02819 (2023).
[15] Si-An Chen, Chun-Liang Li, Nate Yoder, Sercan O Arik, and Tomas Pfister. 2023. TSMixer: An All-MLP architecture for time series forecasting. arXiv preprint arXiv:2303.06053 (2023).
[16] Tianqi Chen and Carlos Guestrin. 2016. XGBoost: A scalable tree boosting system. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. 785–794.
[17] Paul Covington, Jay Adams, and Emre Sargin. 2016. Deep neural networks for YouTube recommendations. In Proceedings of the 10th ACM Conference on Recommender Systems.
[18] Jerome H Friedman. 2001. Greedy function approximation: A gradient boosting machine. Annals of Statistics (2001).
[19] Francesco Fusco, Damian Pascual, Peter Staar, and Diego Antognini. 2023. pNLP-Mixer: An Efficient all-MLP Architecture for Language. In Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics. Association for Computational Linguistics, 53–60.
[20] Pierre Geurts, Damien Ernst, and Louis Wehenkel. 2006. Extremely randomized trees. Machine Learning 63 (2006), 3–42.
[21] Yury Gorishniy, Ivan Rubachev, and Artem Babenko. 2022. On embeddings for numerical features in tabular deep learning. In NeurIPS. 24991–25004.
[22] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, and Artem Babenko. 2021. Revisiting deep learning models for tabular data. In NeurIPS. 18932–18943.
[23] Léo Grinsztajn, Edouard Oyallon, and Gaël Varoquaux. 2022. Why do tree-based models still outperform deep learning on typical tabular data?. In NeurIPS.
[24] Jianyuan Guo, Yehui Tang, et al. 2022. Hire-MLP: Vision MLP via hierarchical rearrangement. In CVPR. 826–836.
[25] Xinran He, Junfeng Pan, et al. 2014. Practical lessons from predicting clicks on ads at Facebook. In Proceedings of the International Workshop on Data Mining for Online Advertising.
[26] Lu Hou, Zhiqi Huang, Lifeng Shang, Xin Jiang, Xiao Chen, and Qun Liu. 2020. DynaBERT: Dynamic BERT with adaptive width and depth. In NeurIPS, Vol. 33. 9782–9793.
[27] Jeremy Howard and Sylvain Gugger. 2020. Fastai: A layered API for deep learning. Information 11, 2 (2020), 108.
[28] Xin Huang, Ashish Khetan, Milan Cvitkovic, and Zohar Karnin. 2020. TabTransformer: Tabular data modeling using contextual embeddings. arXiv preprint arXiv:2012.06678 (2020).
[29] Arlind Kadra, Marius Lindauer, Frank Hutter, and Josif Grabocka. 2021. Well-tuned simple nets excel on tabular datasets. In NeurIPS. 23928–23941.
[30] Liran Katzir, Gal Elidan, and Ran El-Yaniv. 2020. Net-DNF: Effective deep modeling of tabular data. In ICLR.
[31] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, and Tie-Yan Liu. 2017. LightGBM: A highly efficient gradient boosting decision tree. In NeurIPS.
[32] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In NAACL-HLT. 4171–4186.
[33] Günter Klambauer, Thomas Unterthiner, Andreas Mayr, and Sepp Hochreiter. 2017. Self-normalizing neural networks. In NeurIPS, Vol. 30.
[34] Ron Kohavi et al. 1996. Scaling up the accuracy of Naive-Bayes classifiers: A decision-tree hybrid. In KDD, Vol. 96. 202–207.
[35] Bin Li, J Friedman, R Olshen, and C Stone. 1984. Classification and regression trees (CART). Biometrics (1984).
[36] Hanxiao Liu, Zihang Dai, David So, and Quoc V Le. 2021. Pay attention to MLPs. In NeurIPS. 9204–9215.
[37] Christos Louizos, Max Welling, and Diederik P Kingma. 2018. Learning Sparse Neural Networks through L_0 Regularization. In ICLR.
[38] Xinyin Ma, Gongfan Fang, and Xinchao Wang. 2023. LLM-Pruner: On the Structural Pruning of Large Language Models. In NeurIPS.
[39] Tomas Mikolov, Kai Chen, et al. 2013. Efficient estimation of word representations in vector space. arXiv preprint arXiv:1301.3781 (2013).
[40] R Kelley Pace and Ronald Barry. 1997. Sparse spatial autoregressions. Statistics & Probability Letters 33, 3 (1997), 291–297.
[41] F. Pedregosa, G. Varoquaux, et al. 2011. Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research 12 (2011), 2825–2830.
[42] Sergei Popov, Stanislav Morozov, and Artem Babenko. 2019. Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data. In ICLR.
[43] Liudmila Prokhorenkova, Gleb Gusev, Aleksandr Vorobev, Anna Veronika Dorogush, and Andrey Gulin. 2018. CatBoost: Unbiased boosting with categorical features. In NeurIPS.
[44] Alec Radford, Jong Wook Kim, et al. 2021. Learning transferable visual models from natural language supervision. In ICML. 8748–8763.
[45] Camilo Ruiz, Hongyu Ren, Kexin Huang, and Jure Leskovec. 2023. Enabling tabular deep learning when 𝑑 ≫ 𝑛 with an auxiliary knowledge graph. arXiv preprint arXiv:2306.04766 (2023).
[46] Sungyong Seo, Jing Huang, Hao Yang, and Yan Liu. 2017. Interpretable convolutional neural networks with dual local and global attention for review rating prediction. In Proceedings of the 11th ACM Conference on Recommender Systems. 297–305.
[47] Ravid Shwartz-Ziv and Amitai Armon. 2022. Tabular data: Deep learning is not all you need. Information Fusion 81 (2022), 84–90.
[48] Gowthami Somepalli, Avi Schwarzschild, Micah Goldblum, C Bayan Bruss, and Tom Goldstein. 2022. SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training. In NeurIPS 2022 First Table Representation Workshop.
[49] Weiping Song, Chence Shi, Zhiping Xiao, Zhijian Duan, Yewen Xu, Ming Zhang, and Jian Tang. 2019. AutoInt: Automatic feature interaction learning via self-attentive neural networks. In CIKM. 1161–1170.
[50] Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. 2023. A Simple and Effective Pruning Approach for Large Language Models. arXiv preprint arXiv:2306.11695 (2023).
[51] Chuanxin Tang, Yucheng Zhao, et al. 2022. Sparse MLP for image recognition: Is self-attention really necessary?. In AAAI. 2344–2351.
[52] Antti Tarvainen and Harri Valpola. 2017. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. In NeurIPS, Vol. 30.
[53] Ilya O Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, Daniel Keysers, Jakob Uszkoreit, et al. 2021. MLP-Mixer: An all-MLP architecture for vision. In NeurIPS. 24261–24272.
[54] Zhengzhong Tu, Hossein Talebi, Han Zhang, Feng Yang, Peyman Milanfar, Alan Bovik, and Yinxiao Li. 2022. MAXIM: Multi-Axis MLP for image processing. In CVPR. 5769–5780.
[55] Shahadat Uddin, Arif Khan, Md Ekramul Hossain, and Mohammad Ali Moni. 2019. Comparing different supervised machine learning algorithms for disease prediction. BMC Medical Informatics and Decision Making (2019), 1–16.
[56] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In NeurIPS.
[57] Ruoxi Wang, Rakesh Shivanna, Derek Cheng, Sagar Jain, Dong Lin, Lichan Hong, and Ed Chi. 2021. DCN V2: Improved deep & cross network and practical lessons for web-scale learning to rank systems. In WWW. 1785–1797.
[58] Zifeng Wang and Jimeng Sun. 2022. TransTab: Learning transferable tabular Transformers across tables. In NeurIPS, Vol. 35. 2902–2915.
[59] Ziheng Wang, Jeremy Wohlwend, and Tao Lei. 2020. Structured Pruning of Large Language Models. In EMNLP. 6151–6162.
[60] Mitchell Wortsman, Gabriel Ilharco, Samir Ya Gadre, Rebecca Roelofs, Raphael Gontijo-Lopes, Ari S Morcos, Hongseok Namkoong, Ali Farhadi, Yair Carmon, Simon Kornblith, et al. 2022. Model soups: Averaging weights of multiple fine-tuned models improves accuracy without increasing inference time. In ICML. 23965–23998.
[61] Mengzhou Xia, Zexuan Zhong, and Danqi Chen. 2022. Structured Pruning Learns Compact and Accurate Models. In ACL.
[62] Jiahuan Yan, Jintai Chen, Yixuan Wu, Danny Z Chen, and Jian Wu. 2023. T2G-Former: Organizing tabular features into relation graphs promotes heterogeneous feature interaction. In AAAI.
[63] Jiahuan Yan, Bo Zheng, Hongxia Xu, Yiheng Zhu, Danny Chen, Jimeng Sun, Jian Wu, and Jintai Chen. 2024. Making Pre-trained Language Models Great on Tabular Prediction. In ICLR.
[64] Junchen Yang, Ofir Lindenbaum, and Yuval Kluger. 2022. Locally sparse neural networks for tabular biomedical data. In ICML. PMLR, 25123–25153.
[65] Jun Zhang and Vasant Honavar. 2003. Learning from attribute value taxonomies and partially specified instances. In ICML.
[66] Jun Zhang, D-K Kang, et al. 2006. Learning accurate and concise Naïve Bayes classifiers from attribute value taxonomies and data. Knowledge and Information Systems (2006).
[67] Tianping Zhang, Shaowen Wang, Shuicheng Yan, Jian Li, and Qian Liu. 2023. Generative Table Pre-training Empowers Models for Tabular Prediction. arXiv preprint arXiv:2305.09696 (2023).
[68] Wayne Xin Zhao, Kun Zhou, Junyi Li, Tianyi Tang, Xiaolei Wang, Yupeng Hou, Yingqian Min, Beichen Zhang, Junjie Zhang, Zican Dong, et al. 2023. A survey of large language models. arXiv preprint arXiv:2303.18223 (2023).
[69] Bingzhao Zhu, Xingjian Shi, Nick Erickson, Mu Li, George Karypis, and Mahsa Shoaran. 2023. XTab: Cross-table Pretraining for Tabular Transformers. In ICML.

A BENCHMARK CHARACTERISTICS
We provide detailed dataset statistics for each benchmark in Table 2. These benchmarks exhibit broad diversity in data scales and task types. From the FT-T benchmark to TabBen, the overall data volume gradually decreases. We additionally visualize the respective winning rates of the GBDT and DNN frameworks in Fig. 2, indicating varying framework preferences among the dataset collections used in different tabular prediction tasks. FT-T does not include GBDT baselines in its main benchmark, but it contains the most extremely large datasets. Overall, the FT-T benchmark is an extremely large-scale data collection (in both data volume and feature width), the T2G benchmark is a large-scale one, the SAINT benchmark contains diverse data scales, and TabBen focuses on middle-sized typical tables.
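For readers who wish to tally such winning rates themselves, the following is an illustrative sketch only: it assumes a hypothetical per-dataset results table with benchmark, dataset, model, model-family, and score columns (higher is better), and computes, per benchmark, the fraction of datasets on which the best model belongs to each family. The paper's exact counting protocol may differ.

```python
# Illustrative sketch: per-benchmark winning rates of GBDT vs. DNN families.
# Assumption: `results` is a hypothetical DataFrame, not the paper's evaluation output.
import pandas as pd

results = pd.DataFrame({
    "benchmark": ["FT-T", "FT-T", "T2G", "T2G", "SAINT", "SAINT"],
    "dataset":   ["california", "california", "adult", "adult", "credit-g", "credit-g"],
    "model":     ["XGBoost", "FT-Transformer", "CatBoost", "T2G-Former", "LightGBM", "SAINT"],
    "family":    ["GBDT", "DNN", "GBDT", "DNN", "GBDT", "DNN"],
    "score":     [0.90, 0.91, 0.87, 0.88, 0.78, 0.79],
})

def winning_rates(df: pd.DataFrame) -> pd.DataFrame:
    """For each benchmark, the fraction of datasets won by each model family."""
    # Pick the best-scoring row per (benchmark, dataset) pair.
    best = df.loc[df.groupby(["benchmark", "dataset"])["score"].idxmax()]
    counts = best.groupby(["benchmark", "family"]).size().unstack(fill_value=0)
    return counts.div(counts.sum(axis=1), axis=0)  # normalize counts into rates

print(winning_rates(results))
```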
B BASELINE INFORMATION
We list all the compared baselines in this section.
• MLP: Vanilla multi-layer perceptron with no feature interaction.
• ResNet: A popular DNN backbone in vision applications.
• SNN [33]: An MLP-like architecture with SELU activation.
• GrowNet [4]: MLPs built in a gradient boosted manner.
• NODE [42]: Generalized oblivious decision tree ensembles.
• TabNet [2]: A Transformer-based recurrent architecture emulating the tree-based learning process.
• AutoInt [49]: Attention-based feature embeddings.
• DCNv2 [57]: An MLP-based architecture with a feature-crossing module.
• TabTransformer [28]: A Transformer model concatenating numerical features and encoded categorical features.
• DANets [13]: An MLP-based architecture with neural-guided feature selection and abstraction in each block.
• FT-Transformer [22]: A popular tabular Transformer encoding both numerical and categorical features.
• T2G-Former [62]: A tabular Transformer with automatic relation graph estimation for selective feature interaction.
• SAINT [48]: A Transformer-like architecture performing row-level and column-level attention, with contrastive pre-training that minimizes the differences between data points and their augmented views.
• XGBoost [16]: A predominant GBDT implementation.
• CatBoost [43]: A GBDT approach with oblivious decision trees.
• LightGBM [31]: An efficient GBDT implementation.
• RandomForest [9]: A popular bagging ensemble algorithm of decision trees.
• ExtraTrees [20]: A classical tree bagging implementation.
• k-NN [1]: Traditional supervised machine learning algorithms; two KNeighbors models are used (KNeighborsDist, KNeighborsUnif).
• NeuralNetFastAI [27]: FastAI neural network models that operate on tabular data.
• sklearn-GBDT [41]: Two traditional GBDT implementations (GradientBoostingTree and HistGradientBoostingTrees) provided in the Scikit Learn package.

C RUNTIME ENVIRONMENT AND HYPERPARAMETERS
C.1 Runtime Environment
All the experiments are conducted with PyTorch version 1.11.0, CUDA version 11.3, and Scikit Learn version 1.1.0, with each trial using an NVIDIA A100 PCIe 40GB GPU and an Intel Xeon Processor 40C.

C.2 Hyperparameters of T-MLP
In the main experiments, we uniformly set the hidden size 𝑑 to 1024, the intermediate size 𝑑′ to 676 (2/3 of the hidden size), the sparsity rate to 0.33, and the residual dropout rate to 0.1, with three basic blocks for multi-class classification or extremely large binary classification datasets, and one block for the others. The learning rate of the single T-MLP is 1e-4, and the learning rates of the three branches in the T-MLP ensemble are 1e-4, 5e-4, and 1e-3, respectively. (This fixed configuration is summarized in the illustrative sketch following C.3.)

C.3 Hyperparameters of Baselines
For all the baselines on the FT-T and T2G benchmarks, we follow the given hyperparameter spaces and iteration counts from the original benchmark papers to estimate the training costs.
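For convenience, the fixed, tuning-free defaults described in C.2 can be written down as a single configuration object. The sketch below is illustrative only: the dataclass and field names are ours and do not correspond to an official T-MLP API; the values mirror the numbers stated in C.2.

```python
# Illustrative summary of the tuning-free defaults from Appendix C.2.
# The class and field names are hypothetical, not an official T-MLP interface.
from dataclasses import dataclass, field
from typing import List

@dataclass
class TMLPConfig:
    hidden_size: int = 1024          # uniform hidden size d
    intermediate_size: int = 676     # d', roughly 2/3 of the hidden size
    sparsity_rate: float = 0.33      # sparsity rate of the pruned architecture
    residual_dropout: float = 0.1
    num_blocks: int = 1              # 3 for multi-class or very large binary datasets
    lr_single: float = 1e-4          # learning rate of the single T-MLP
    lr_ensemble: List[float] = field(default_factory=lambda: [1e-4, 5e-4, 1e-3])

# Example: the variant used for multi-class classification datasets.
multiclass_cfg = TMLPConfig(num_blocks=3)
print(multiclass_cfg)
```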
Table 8: AUC scores (the higher the better) of the baselines on the binary classification datasets in the SAINT benchmark.

OpenML ID: 31 44 1017 1111 1487 1494 1590 4134 42178


RF 0.778 0.986 0.798 0.774 0.910 0.928 0.908 0.868 0.840
ExtraTrees 0.764 0.986 0.811 0.748 0.921 0.935 0.903 0.856 0.831
KNeighborsDist 0.501 0.873 0.722 0.517 0.741 0.868 0.684 0.808 0.755
KNeighborsUnif 0.489 0.847 0.712 0.516 0.734 0.865 0.669 0.790 0.764
LightGBM 0.752 0.989 0.829 0.815 0.919 0.923 0.930 0.860 0.854
XGBoost 0.778 0.989 0.821 0.818 0.919 0.926 0.931 0.864 0.856
CatBoost 0.788 0.988 0.838 0.818 0.917 0.937 0.930 0.862 0.841
MLP 0.705 0.980 0.745 0.709 0.913 0.932 0.910 0.818 0.841
TabNet 0.736 0.979 0.422 0.718 0.625 0.677 0.917 0.701 0.830
TabTransformer 0.771 0.982 0.729 0.763 0.884 0.913 0.907 0.809 0.841
SAINT-s 0.774 0.982 0.781 0.804 0.906 0.933 0.922 0.819 0.858
SAINT-i 0.774 0.981 0.759 0.816 0.920 0.934 0.919 0.845 0.854
SAINT 0.790 0.991 0.843 0.808 0.919 0.937 0.921 0.853 0.857
T-MLP 0.805 0.983 0.818 0.814 0.924 0.933 0.924 0.853 0.862
T-MLP(3) 0.802 0.983 0.821 0.816 0.924 0.935 0.925 0.855 0.861

Table 9: Accuracy scores (the higher the better) of the baselines on the multi-class classification datasets in the SAINT benchmark.

OpenML ID: 188 1596 4541 40685 41166 41169 42734


RF 0.653 0.953 0.607 0.999 0.671 0.358 0.743
ExtraTrees 0.653 0.946 0.595 0.999 0.648 0.341 0.736
KNeighborsDist 0.442 0.965 0.491 0.997 0.620 0.205 0.685
KNeighborsUnif 0.422 0.963 0.489 0.997 0.605 0.189 0.693
LightGBM 0.667 0.969 0.611 0.999 0.721 0.356 0.754
XGBoost 0.612 0.928 0.611 0.999 0.707 0.356 0.752
CatBoost 0.667 0.871 0.604 0.999 0.692 0.376 0.747
MLP 0.388 0.915 0.597 0.997 0.707 0.378 0.733
TabNet 0.259 0.744 0.517 0.997 0.599 0.243 0.630
TabTransformer 0.660 0.715 0.601 0.999 0.531 0.352 0.744
SAINT-s 0.680 0.735 0.607 0.999 0.582 0.194 0.755
SAINT-i 0.646 0.937 0.598 0.999 0.713 0.380 0.747
SAINT 0.680 0.946 0.606 0.999 0.701 0.377 0.752
T-MLP 0.660 0.968 0.598 1.000 0.718 0.382 0.747
T-MLP(3) 0.674 0.970 0.601 1.000 0.728 0.384 0.750
Table 10: RMSE scores (the lower the better) of the baselines on the regression datasets in the SAINT benchmark.

OpenML ID: 422 541 42563 42571 42705 42724 42726 42727 42728 42729
RF 0.027 17.814 37085.577 1999.442 16.729 12375.312 2.476 0.149 13.700 1.767
ExtraTrees 0.027 19.269 35049.267 1961.928 15.349 12505.090 2.522 0.147 13.578 1.849
KNeighborsDist 0.029 25.054 46331.144 2617.202 14.496 13046.090 2.501 0.167 13.692 2.100
KNeighborsUnif 0.029 24.698 47201.343 2629.277 18.397 12857.449 2.592 0.169 13.703 2.109
LightGBM 0.027 19.871 32870.697 1898.032 13.018 11639.594 2.451 0.144 13.468 1.958
XGBoost 0.028 13.791 36375.583 1903.027 12.311 11931.233 2.452 0.145 13.480 1.849
CatBoost 0.027 14.060 35187.381 1886.593 11.890 11614.567 2.405 0.142 13.441 1.883
NeuralNetFastAI 0.028 22.756 42751.432 1991.774 15.892 11618.684 2.500 0.162 13.781 3.351
TabNet 0.028 22.731 200802.769 1943.091 11.084 11613.275 2.175 0.183 16.665 2.310
TabTransformer 0.028 21.600 37057.686 1980.696 15.693 11618.356 2.494 0.162 12.982 3.259
SAINT-s 0.027 9.613 193430.703 1937.189 10.034 11580.835 2.145 0.158 12.603 1.833
SAINT-i 0.028 12.564 33992.508 1997.111 11.513 11612.084 2.104 0.153 12.534 1.867
SAINT 0.027 11.661 33112.387 1953.391 10.282 11577.678 2.113 0.145 12.578 1.882
T-MLP 0.027 11.643 21773.233 1946.203 9.027 11828.872 2.041 0.161 13.271 1.843
T-MLP(3) 0.027 13.790 22185.024 1939.557 8.972 11762.376 2.049 0.161 13.016 1.852

Table 11: Accuracy scores (the higher the better) of the baselines for the binary classification tasks on the TabBen numerical
datasets.

eye MiniBooNE Higgs bank-market covertype MagicTele. electricity credit jannis


Resnet 0.574 0.937 0.694 0.794 0.803 0.858 0.809 0.761 0.746
FT-T 0.586 0.937 0.706 0.804 0.813 0.851 0.820 0.765 0.766
SAINT 0.589 0.935 0.707 0.791 0.803 0.851 0.818 0.760 0.773
GBT 0.637 0.932 0.711 0.803 0.819 0.851 0.862 0.772 0.770
XGBoost 0.655 0.936 0.714 0.804 0.819 0.859 0.868 0.774 0.778
RF 0.650 0.927 0.708 0.798 0.827 0.853 0.861 0.772 0.773
MLP 0.569 0.935 0.689 0.792 0.789 0.847 0.810 0.760 0.746
T-MLP 0.610 0.946 0.731 0.802 0.909 0.859 0.842 0.772 0.800
T-MLP(3) 0.613 0.947 0.733 0.803 0.915 0.861 0.848 0.775 0.799

Table 12: R-Squared scores (the higher the better) of the baselines for the regression tasks on the TabBen numerical datasets.

elevators Bike houses nyc-taxi pol sulfur Ailerons wine supercon. house sales Brazilian Miami cpu act diamonds
Resnet 0.910 0.669 0.821 0.468 0.948 0.819 0.835 0.363 0.895 0.868 0.998 0.914 0.982 0.942
FT-T 0.914 0.671 0.832 0.476 0.995 0.838 0.844 0.359 0.885 0.875 0.998 0.919 0.978 0.944
SAINT 0.923 0.684 0.820 0.496 0.996 0.788 0.784 0.374 0.894 0.879 0.994 0.921 0.984 0.944
GBT 0.863 0.690 0.840 0.554 0.979 0.806 0.843 0.458 0.905 0.884 0.995 0.924 0.985 0.945
XGBoost 0.908 0.695 0.852 0.553 0.990 0.865 0.844 0.498 0.911 0.887 0.998 0.936 0.986 0.946
RF 0.841 0.687 0.829 0.563 0.989 0.859 0.839 0.504 0.909 0.871 0.993 0.924 0.983 0.945
T-MLP 0.875 0.694 0.834 0.560 0.995 0.853 0.840 0.410 0.894 0.886 0.993 0.939 0.982 0.949
T-MLP(3) 0.908 0.698 0.838 0.566 0.996 0.860 0.843 0.416 0.899 0.888 0.995 0.939 0.983 0.950
Table 13: Accuracy scores (the higher the better) of the baselines for the binary classification tasks on the TabBen categorical
datasets.

eye road-safety electricity covertype rl compass


FT-T 0.598 0.767 0.842 0.867 0.703 0.753
Resnet 0.579 0.761 0.826 0.853 0.706 0.745
SAINT 0.585 0.764 0.834 0.850 0.682 0.719
GBT 0.639 0.762 0.880 0.856 0.776 0.741
XGBoost 0.668 0.767 0.887 0.864 0.770 0.769
HistGBT 0.636 0.765 0.882 0.845 0.761 0.751
RF 0.657 0.759 0.878 0.859 0.798 0.793
T-MLP 0.605 0.786 0.880 0.882 0.757 0.785
T-MLP(3) 0.609 0.786 0.881 0.880 0.762 0.790

Table 14: R-Squared scores (the higher the better) of the baselines for the regression tasks on the TabBen categorical datasets.

Bike particulate Brazilian diamonds black nyc-taxi analcatdata OnlineNews Mercedes house sales
FT-T 0.937 0.673 0.883 0.990 0.379 0.511 0.977 0.143 0.548 0.891
Resnet 0.936 0.658 0.878 0.989 0.360 0.451 0.978 0.130 0.545 0.881
GBT 0.942 0.683 0.995 0.990 0.615 0.573 0.981 0.153 0.578 0.891
XGBoost 0.946 0.691 0.998 0.991 0.619 0.578 0.983 0.162 0.578 0.896
HistGBT 0.942 0.690 0.993 0.991 0.616 0.539 0.982 0.156 0.576 0.890
RF 0.938 0.674 0.993 0.988 0.609 0.585 0.981 0.149 0.575 0.875
T-MLP 0.938 0.692 0.996 0.990 0.620 0.571 0.990 0.154 0.576 0.893
T-MLP(3) 0.942 0.698 0.996 0.993 0.622 0.580 0.990 0.158 0.578 0.894
