Deep Learning For Causal Inference

Bernard J. Koch*1, Tim Sainburg2, Pablo Geraldo Bastías1, Song Jiang3, Yizhou Sun3, and Jacob Foster1

1 UCLA Department of Sociology
2 Harvard Medical School
3 UCLA Department of Computer Science

April 2023
Abstract
This primer systematizes the emerging literature on causal inference using deep neural net-
works under the potential outcomes framework. It provides an intuitive introduction to building
and optimizing custom deep learning models and shows how to adapt them to estimate/predict
heterogeneous treatment effects. It also discusses ongoing work to extend causal inference to set-
tings where confounding is non-linear, time-varying, or encoded in text, networks, and images.
To maximize accessibility, we also introduce prerequisite concepts from causal inference and deep
learning. The primer differs from other treatments of deep learning and causal inference in its
sharp focus on observational causal estimation, its extended exposition of key algorithms, and its
detailed tutorials for implementing, training, and selecting among deep estimators in Tensorflow
2 and PyTorch.
Contents

1 Introduction

6 Beyond Traditional Data: Text, Networks, Images, and Treatment over Time
6.1 Causal Inference from Text
6.2 Causal Inference from Networks
6.3 Causal Inference from Images
6.4 Causal Inference from Time-varying Data

A Balancing Using Integral Probability Metrics
A.1 Wasserstein Distance
A.2 Extending Representation Balancing with IPMs
A.2.1 Extending Representation Balancing with Matching

Boxes

1 Box 1: Example Scenarios for Causal Inference with Non-Traditional Data
2 Box 2: Basic Introduction to Supervised Learning
3 Box 3: Reading Machine Learning Papers: Computational Graphs and Loss Functions
4 Box 4: Basic Introduction to Causal Inference
5 Box 5: Applied Causal Inference Example: The Infant Health and Development Study
6 Box 6: Notation for Causal Inference and Estimation
7 Box 7: TARNet in Code
8 Box 8: Graph Neural Networks and Transformers
9 Box 9: Generative Adversarial Networks (GAN)
1 Introduction
This primer aims to introduce social science readers to an exciting literature exploring how deep neural
networks can be used to estimate causal effects. In recent years, both causal inference frameworks and
deep learning have seen rapid adoption across science, industry, and medicine. Causal inference has a
long tradition in the social sciences, and social scientists are increasingly exploring the use of machine
learning for causal inference (Athey and Imbens, 2016; Wager and Athey, 2018; Chernozhukov et al.,
2018). Nevertheless, deep learning remains conspicuously underutilized by social scientists compared to other machine learning approaches.
The deep learning revolution has been spurred by the flexibility and expressiveness of these models.
Neural networks are nearly non-parametric and can theoretically approximate any continuous function
(Cybenko, 1989), making them well suited for both classification and regression tasks. Furthermore,
they can be configured with different architectures and objectives to learn from a variety of quantitative
data as well as text, images, video, networks, and speech. These advantages allow them to learn
vector “representations” of complex data with emergent properties. Simple examples of representation
learning include the Word2Vec algorithm, which discovers semantic relationships between words in texts
(Mikolov et al., 2013), or face classification models that learn vectors describing facial features. More
recently, generative models like DALL-E, Stable Diffusion, and ChatGPT have shown how coherent
text passages and life-like images can be reconstructed from learned representations.
Here we explore the potential for leveraging these advantages to estimate causal effects. Causal
inference frameworks are non-parametric, but the linear models traditionally used to estimate causal
effects require strong parametric assumptions. In contrast, the nearly non-parametric nature of neural
networks allows us to estimate smooth response surfaces that capture heterogeneous treatment effects
for individual units with low bias.1 The ability of these models to learn from complex data means
we can extend causal inference to new settings where confounding is complicated, time-varying, or
even encoded in texts, graphs, or images (see Box 1 for hypothetical examples). Lastly, given the right
objectives, neural networks promise to learn deconfounded representations of data, presenting a new
strategy for treatment modeling.1

1 Neural networks can have hundreds to billions of parameters, making them effectively non-parametric. The risks of
overfitting such heavily parameterized models are discussed in Section 2.2.
This primer synthesizes existing literature on deep causal estimators, but it is not a review; its
goals are fundamentally pedagogical and prospective rather than retrospective. In Section 2, we
introduce social scientists to the fundamental concepts of deep learning, and the basic workflow for
building and training their own deep neural networks within a supervised learning framework. For
readers unfamiliar with causal inference, Section 3 introduces the assumptions of causal identification
and three fundamental estimation strategies within the selection on observables design: matching,
outcome modeling, and inverse propensity score weighting. Machine learning models often perform
poorly in both theory and practice when only one of these strategies is employed, so we also introduce the concept of doubly robust estimation.
Section 4 is the main body of the article. Here we introduce four related deep learning models for
the estimation of heterogeneous treatment effects: the S-learner, T-learner, TARNet and Dragonnet
(Shalit et al., 2017; Shi et al., 2019). Although this literature is rapidly evolving, these four models
are sufficient to illustrate how traditional estimation strategies can be used in creative ways that
leverage the key strengths of neural networks (i.e., deconfounding through representation learning,
semi-parametric inference). Section 5 deals with the practical considerations of building confidence
intervals and interpreting neural networks. These guidelines are concretized in the companion online
tutorials, which show readers how to implement and interpret the models described in Section 4 in Tensorflow 2 and PyTorch.
In Section 6, we focus on the future of deep causal inference: estimators that can disentangle
confounding relationships embedded within texts, images, graphs, or time-varying data. In the
interest of clarity, we give hypothetical examples of the types of questions social scientists might
answer with these models, and briefly describe ongoing research on each of these modalities. For fuller
treatments of some of these models, see the appendix. We conclude with a discussion of how neural
networks fit into the broader literature on machine learning for causal inference (Section 7).
The primer makes multiple contributions. First, it is one of the first pieces in the sociological
literature to introduce the fundamentals of deep learning not only at a conceptual level (e.g., back-
propagation, representation learning), but at a practical one (e.g., validation, hyperparameter tuning).
Our recommendations for training and interpreting neural networks are supported by heavily anno-
tated tutorials that teach readers without prior familiarity with deep learning how to build their own
custom models in Tensorflow 2 and PyTorch. Second, we use this foundation and select examples
to build intuition on how the core strengths of deep learning can be leveraged for causal inference.
Finally, we highlight future directions for this literature and argue why the future of causal estimation lies with deep learning.

Box 1: Example Scenarios for Causal Inference with Non-Traditional Data

Text. As a motivating example, Veitch et al. (2020) consider the effect of the author's
reported gender (T) on the number of upvotes a Reddit post receives (Y). However, gender
may also “affect the text of the post, e.g., through tone, style, or topic choices, which also
affects its score [(X)].” Controlling for a representation of the text would allow the analyst to
more accurately estimate the direct effect of gender.
Images. Todorov et al. (2005) showed that split-second judgments of a politician's compe-
tence (T) from pictures (X) of their face are predictive of their electability (Y). When attempting
to replicate this study using machine learning classifiers rather than human classifiers, Joo et al.
(2015) suggest that the age of the face (Z) is a not-so-obvious confounder: while older individ-
uals are more likely to appear competent, they are also more likely to be incumbents. Even
if age is unknown, using neural networks to control for confounders implicitly encoded in the
image (like age) could reduce bias.
Networks. Nagpal et al. (2020) explore the question of which types of prescription opioids
(e.g., natural, semi-synthetic, synthetic) (T ) are most likely to cause long term addiction (Y ).
Because of predisposition to different injuries, type of employment (X) could be a common cause
of both treatment and outcome. Suppose job type is unobserved, but we know that patients
are likely to associate with coworkers through homophily. To capture some of the effects of
this latent unobserved confounder, analysts might choose to control for a representation of the
patient’s position in their social network when estimating the causal effect.
Box 2: Basic Introduction to Supervised Learning

Deep learning algorithms have most commonly been adapted for causal inference using
supervised machine learning, the most popular learning framework within the field.a The goal
of supervised learning is to teach a model a non-linear function that transforms covariates/features
X into predicted outcomes Ŷ in unseen data. The model learns this function from labeled
examples of Xtr and Ytr in a training dataset.
As in traditional statistical analyses, the function is learned by optimizing the model’s
parameters such that they minimize the error between its predictions Ŷtr and the true values
Ytr using a loss function (e.g., a likelihood). In a traditional social science analysis focused on
inference, we would stop here and interpret these parameters. In supervised machine learning
where the focus is on generalization to unseen data, the model is ultimately used to predict
outcomes Yte in a test dataset of previously unseen covariates/features Xte . This framework
can be generically applied to cases where Y is categorical (called classification problems), and
where Y is continuous (called regression problems).
Statistical learning theory articulates the central challenge of supervised learning as a bal-
ance between overfitting and underfitting the training dataset. This is also called the
"bias-variance" tradeoff. In a regression context, bias error is the difference between the
expected value of Y and the expected value of the mapping function learned by the model.b
High bias typically results from an algorithm that has not sufficiently learned the relationships
in the training dataset (i.e., underfit the data). In contrast, an algorithm that has learned the
training dataset so closely that it is fitting noise in the sample (i.e., overfitting) is likely to
generalize poorly, producing out-of-sample predictions with high variance. Underfitting can be
easily diagnosed and addressed by increasing the complexity of the model. In the case of deep
learning, model complexity can be increased by adding additional layers or parameters/neurons.
Diagnosing and addressing overfitting is a more challenging problem. In supervised learning,
overfitting is diagnosed after training (but before testing) by assessing predictive performance
in a reserved portion of the training set called the validation set. If the model fits the training
dataset well but performs poorly in the validation set, it is likely to generalize poorly to the
test set as well. To prevent overfitting, regularization techniques can be used to simplify the
complexity of the model. Training and regularization of neural networks is discussed in detail
in Section 2.2. For a full treatment of supervised learning and statistical learning theory, see
Hastie et al. (2009).
a The other two prominent paradigms are unsupervised learning and reinforcement learning.
b Note that bias in statistical learning theory is not equivalent to bias of a statistical estimator.
Artificial neural networks (ANN) are statistical models inspired by the human brain (Brand et al.,
2020; Goodfellow et al., 2016). In an ANN, each “neuron” in the network takes the weighted sum
of its inputs (the outputs of other neurons) and transforms it using a differentiable, non-linear
activation function (e.g., a sigmoid, which squashes the sum to a value between 0 and 1, or a rectified
linear unit, which passes the sum through if it is positive and outputs zero otherwise). Neurons are arrayed in layers where an input layer takes the raw data,
and neurons in subsequent layers take the weighted sum of outputs in previous layers as input. An
“output” layer contains a neuron for each of the predicted outcomes with transformation functions
appropriate to those outcomes. For example, a regression network that predicts one outcome will
have a single output neuron without a transformation function so that it produces a real number. A
regression network without any hidden layers corresponds exactly to a generalized linear model (Fig.
1A). When additional “hidden” layers are added between the input and output layers, the architecture
is called a feed-forward network or multi-layer perceptron (Fig. 1B). A neural network with
multiple hidden layers is called a “deep” network, hence the name “deep learning” (LeCun et al.,
2015). A neural network with a single, large enough hidden layer can theoretically approximate any continuous function (Cybenko, 1989).
Box 3: Reading Machine Learning Papers: Computational Graphs and Loss Func-
tions
Within the machine learning literature, novel algorithms are often presented in terms of
their computational graph and loss function. A computational graph (not to be confused with
a causal graph) uses arrows to depict the flow of data from the inputs of a neural network,
through parameters, to the outputs. Layers of neurons or specialized sub-architectures are
often generically abstracted as shapes. In our diagrams, we use purple to represent observables,
orange for representation layers of the network, white for produced outputs, and red and blue
for outcome modeling layers. Operations that are computed after prediction (i.e., for which an
error gradient is not calculated) are shown with dashed lines (e.g., plug-in estimation of causal
estimands).
Along with the architecture, the loss function of a neural network is the primary means for
the analyst to dictate what types of representations a neural network learns and what types of
outputs it produces. In multi-task learning settings, we denote joint loss functions for an entire
network as a weighted sum of the losses for constituent tasks and modules. These specific losses
are weighted by hyperparameters. For example, we might weight the joint loss for a network
that predicts outcomes and propensity scores as:

L(Y, T) = MSE(Y, h(X, T)) + λ · BCE(T, π(X, T))

where h(X, T) is the predicted potential outcome, π(X, T) is the predicted propensity score, λ
is a hyperparameter and MSE and BCE stand for mean squared error and binary cross entropy
(i.e., log loss), common losses for regression and binary classification respectively (Box 6).
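As a minimal illustration of such a weighted joint loss (this snippet is generic and not tied to any particular model in this primer; the function and argument names are ours):

import tensorflow as tf

def joint_loss(y_true, y_pred, t_true, t_pred, lam=1.0):
    #Weighted sum of an outcome loss and a propensity loss; lam is a hyperparameter
    mse = tf.reduce_mean(tf.square(y_true - y_pred))            # MSE(Y, h(X, T))
    bce = tf.keras.losses.binary_crossentropy(t_true, t_pred)   # BCE(T, pi(X, T))
    return mse + lam * tf.reduce_mean(bce)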
Neural networks are trained to predict their outcomes by optimizing a loss function (also called
an objective or cost function). During training, the backpropagation algorithm uses calculus’s
chain rule to assign portions of the total error in the loss function to each neuron in the network.
An optimizer, such as the stochastic gradient descent algorithm or the currently popular ADAM
algorithm (Kingma and Ba, 2015), then moves each parameter in the opposite direction of this error gradient.
Figure 1: A: Generalized linear model represented as a computational graph. Observable
covariates X1 , X2 , X3 and treatment status T depicted in purple. Each of the lines between the purple
inputs and the orange box represents a parameter (i.e., a β in a generalized linear model equation).
The orange box is an "output neuron" that sums its weighted inputs, performs a transformation g
(the link function in GLM; in this case the identity function), and predicts the conditional outcome
Ŷ (T ). Instead of theoretically interpreting these parameters from an inferential statistics perspective,
machine learning approaches typically use the predicted observed and unobserved potential outcomes
for plug-in estimation of causal estimands (e.g., the conditional average treatment effect, ĈATE).
After training, setting T to 1 − T for each observation can predict the unobserved potential outcome
Ŷ (1 − T ). Because this operation occurs after prediction and does not feed a gradient back to the
network to optimize the parameters, it is depicted here with a dotted line. Plug-in calculation of
ĈATE is similarly shown with a dotted line.
B: Feed-forward neural network (S-learner). In a feed-forward neural network, additional fully
connected (parameterized) layers of neurons are added between the inputs and output neuron. The size
of the input covariates and hidden layers are generically abstracted as boxes. The final hidden layer
before the output neuron is denoted Φ because the hidden layers collectively encode a representation
function (see section 2.3). In causal inference settings, this architecture is sometimes called a S(ingle)-
learner because one feed-forward network learns to predict both potential outcomes.
Neural networks first rose to popularity in the 1980s but fell out of favor compared to other
machine learning model families (e.g., support vector machines) due to the expense of training them. By
the late 2000s, improvements to backpropagation, advances in computing power (i.e., graphic cards),
and access to larger datasets collectively enabled a deep learning revolution where ANNs began to
significantly outperform other model families. Today, deep learning is the hegemonic machine learning paradigm.
This section focuses on the practice of training neural networks within a supervised learning framework.
While the principles behind supervised machine learning are universal, the workflow for neural networks
differs substantially from other ML approaches (e.g., random forests, support vector machines) in
practice. Figure 2 presents this workflow in four different parts: Set Up, Training, Model Evaluation,
and Interpretation. We delve into each of these topics in more detail below. Box 2 contains a basic introduction to supervised learning.
The first step in training a neural network, as in other types of supervised machine learning, is to split
your dataset into training, validation, and testing datasets (Fig. 2A). If the network is being used for
statistical inference, as here, the testing dataset is optional, and inference may be conducted on just the training and validation sets.
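For instance, a minimal sketch of this splitting step might look like the following (the array names, sizes, and split proportions are purely illustrative):

import numpy as np
from sklearn.model_selection import train_test_split

# X: covariates, T: treatment indicator, Y: observed outcome (illustrative arrays)
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 25))
T = rng.integers(0, 2, size=1000)
Y = rng.normal(size=1000)

# Hold out 20% as a test set (optional when the goal is inference on the full sample)
X_train, X_test, T_train, T_test, Y_train, Y_test = train_test_split(
    X, T, Y, test_size=0.2, random_state=0)

# Reserve 20% of the remaining data as a validation set for early stopping and model selection
X_train, X_val, T_train, T_val, Y_train, Y_val = train_test_split(
    X_train, T_train, Y_train, test_size=0.2, random_state=0)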
While the computational graph and loss function define a deep learning architecture (Box 3), actual
implementations can vary significantly due to the choice of hyperparameters. In supervised machine
learning, hyperparameters are parameters that are not learned automatically when training the
model, but must be specified by the analyst. In deep learning, architectural hyperparameters include
the number of layers to use for each section of the computational graph, the number of neurons to use
in each layer, and the activation functions to be used by neurons. While some basic rules of thumb
apply (e.g., use fewer layers than neurons), these choices remain poorly understood theoretically.2
2 For some interesting work on understanding neural networks theoretically from a statistical physics perspective, see Roberts et al. (2022).
Figure 2: Supervised Deep Learning Workflow. 1) Set Up: The first step in training a deep
learning model is splitting the data into a training set, validation set, and optionally a test set. Initial
hyperparameters are then selected from a set of choices specified by the user. 2) Training: In
each iteration of the training process (called an epoch), the training set is randomly divided into
small minibatches. For each minibatch, the network makes predictions for all units, and calculates
the error gradients to be assigned to each neuron in the network based on those predictions. An
optimizer then moves the network's parameters in the opposite direction of the error gradient. After all
minibatches have been trained (one epoch), error is calculated on the entire validation set. This whole
process is repeated until the validation error stops decreasing (to avoid overfitting). 3) Model
Evaluation: A criterion (typically the validation error) is used to evaluate the performance of this
hyperparameterization. New hyperparameters are then selected using a hyperparameter optimization
algorithm (e.g., grid search, Bayesian hyperparameter optimization, genetic algorithms) and steps 1
and 2 are repeated. Once the hyperparameter optimization algorithm has completed its search, the
“best” model is selected for inference. 4) Inference and interpretation: With a model selected,
the analyst is now ready to apply it to their test data (or in the case of statistical inference, potentially
the full dataset). Predictions of the outcomes and/or propensity score can then be used to compute
the CATE and calculate confidence intervals. Feature importance algorithms like SHAP or Integrated
Gradients can also be used to interpret the CATE estimates.
Decisions are generally made by comparing empirical performance on the validation set, a practice discussed in more detail below.
Neural networks are trained by repeatedly making predictions from the training set, calculating error
gradients for each parameter, and backpropagating small fractions of those error gradients. (Fig. 2 B).
A full pass through examples in the training set is called a training loop or epoch. At the beginning of
each epoch, the training set is divided into mini-batches of 2 to 1000 units, randomly sampled without
replacement. This practice not only aids in memory management, it also improves optimization. Using
small random samples reduces the risk of large “exploding” error gradients, particularly early in the
training, that could cause the model to overshoot optimal solutions and instead get stuck in local
minima.
Because each mini-batch is only a sample of a sample (the training dataset), the optimizer only adjusts weight parameters by a
fraction of the error gradient (the learning rate) to avoid overfitting. The learning rate is also a hyperparameter.
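As a rough sketch of this training loop (not taken from the tutorials; the model, data, batch size, and learning rate are placeholders), a minimal PyTorch version with mini-batching and a per-epoch validation check might look like:

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data and model: 25 covariates, a single continuous outcome
X = torch.randn(1000, 25); Y = torch.randn(1000, 1)
X_val = torch.randn(200, 25); Y_val = torch.randn(200, 1)
model = nn.Sequential(nn.Linear(25, 200), nn.ELU(), nn.Linear(200, 1))

loader = DataLoader(TensorDataset(X, Y), batch_size=64, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)   # the learning rate is a hyperparameter
loss_fn = nn.MSELoss()

for epoch in range(100):                       # each pass over the training set is one epoch
    model.train()
    for xb, yb in loader:                      # mini-batches sampled without replacement
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)          # prediction error on the mini-batch
        loss.backward()                        # backpropagate error gradients
        optimizer.step()                       # move parameters a small step against the gradient
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(X_val), Y_val)   # validation error checked once per epoch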
The non-convex nature of most loss functions4 means that optimization often requires hundreds
to potentially millions of epochs of training. Moreover, neural networks are highly susceptible to
overfitting because it is easy to overparameterize them with excessive neurons/layers. To guard against
overfitting, error metrics on the complete validation set are computed at the end of every epoch.
In a regularization practice called “early stopping,” analysts usually stop training once validation
metrics stop improving. Other common regularization techniques include weight decay (i.e., ℓ2
norm, ridge, or Tikhonov) penalties on the parameters, dropout of neurons during training, and batch
normalization.
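Early stopping, for example, is available as a built-in callback in Keras; a minimal sketch (the model and placeholder data are illustrative, not the tutorials' code) might be:

import numpy as np
import tensorflow as tf

# Placeholder data and a simple compiled model
X_train, Y_train = np.random.randn(800, 25), np.random.randn(800, 1)
X_val, Y_val = np.random.randn(200, 25), np.random.randn(200, 1)
model = tf.keras.Sequential([tf.keras.layers.Dense(200, activation='elu'),
                             tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')

# Stop training once validation loss has not improved for 10 epochs,
# and restore the parameters from the best epoch
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,
                                              restore_best_weights=True)

history = model.fit(X_train, Y_train,
                    validation_data=(X_val, Y_val),
                    epochs=500, batch_size=64,
                    callbacks=[early_stop])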
3 In the specific context of causal inference, we recommend not having mini-batches that are too small such that the
model can learn from both treated and control units with sufficient overlap.
4 In convex functions (e.g. the OLS loss), there is a single minimum, so optimizing the function means that you will
always converge at the same parameter weights. This is not the case for non-convex functions which may have many
local minima.
Dropout is a regularization technique in deep learning where certain nodes are randomly silenced
from training during a given epoch (Srivastava et al., 2014). The general idea of dropout is to force
two neurons in the same layer to learn different aspects of the covariate/feature space and reduce
overfitting. Batch normalization is another widely used technique
(Ioffe and Szegedy, 2015). By standardizing (i.e., z-scoring) the inputs to a layer on a per-batch basis
and then rescaling them using trainable parameters, batch normalization smooths the optimization of
the loss function. The addition and extent of each of these regularization techniques can be treated
as hyperparameters.
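As an illustrative sketch (the layer sizes and dropout rate are arbitrary hyperparameters, not recommendations), dropout and batch normalization can simply be added as layers:

import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    layers.Dense(200, activation='elu'),
    layers.BatchNormalization(),   # standardize and rescale layer inputs per batch
    layers.Dropout(0.2),           # randomly silence 20% of neurons at each update
    layers.Dense(200, activation='elu'),
    layers.Dropout(0.2),
    layers.Dense(1)                # single regression output
])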
(Tutorial 2 )
After the model has been trained, the analyst compares models assembled with different hyper-
parameterizations or initial parameter values (Fig. 2C). Hyperparameterizations can be chosen using
random search, an exhaustive grid search of all possible combinations, or strategic search algorithms
like Bayesian hyperparameter optimization or evolutionary optimization (Snoek et al., 2012). Valida-
tion loss metrics on the final epoch are commonly used for these comparisons.
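A schematic random search over hyperparameterizations (purely illustrative; the helper build_model, the search space, and the placeholder data are our assumptions, not part of any tutorial) could look like:

import random
import numpy as np
import tensorflow as tf

# Placeholder data
X_train, Y_train = np.random.randn(800, 25), np.random.randn(800, 1)
X_val, Y_val = np.random.randn(200, 25), np.random.randn(200, 1)

def build_model(n_layers, n_units, learning_rate):
    # Assumed helper: returns a compiled feed-forward network for one hyperparameterization
    hidden = [tf.keras.layers.Dense(n_units, activation='elu') for _ in range(n_layers)]
    model = tf.keras.Sequential(hidden + [tf.keras.layers.Dense(1)])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate), loss='mse')
    return model

best_val, best_config = float('inf'), None
for trial in range(20):                                  # 20 random hyperparameterizations
    config = {'n_layers': random.choice([1, 2, 3]),
              'n_units': random.choice([50, 100, 200]),
              'learning_rate': random.choice([1e-2, 1e-3, 1e-4])}
    history = build_model(**config).fit(
        X_train, Y_train, validation_data=(X_val, Y_val),
        epochs=100, batch_size=64, verbose=0,
        callbacks=[tf.keras.callbacks.EarlyStopping(patience=10)])
    val_loss = min(history.history['val_loss'])          # best validation loss for this trial
    if val_loss < best_val:
        best_val, best_config = val_loss, config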
Model selection for causal estimators is complicated by the fundamental problem of causal inference:
we are not actually interested in the observed “factual” outcomes and propensity scores, but the CATE
and ATE. In the case of algorithms like Dragonnet (Section 4.3) where the validation loss explicitly targets a
causal quantity, we use that as the model selection criterion. In cases where the algorithm is only
trained for outcome modeling or propensity modeling, other solutions are needed. In the Appendix,
we describe Johansson et al. (2020)’s proposal to use matching on a nearest neighbor approximation of
the Precision in Estimated Heterogeneous Effects (PEHE), a measure of CATE bias, as an alternative model selection criterion.
The development of more sophisticated methods for model selection of causal estimators through
data simulation is an active area of research within this literature.5 For example, Parikh et al.
5 We note that crossfitting (Zivich and Breskin, 2021), another approach that has emerged for model selection of other
types of machine learning causal estimators, may work for the models discussed here, but is likely data-inefficient.
(2022) use deep generative models to approximate the data generating distribution under weak, non-
parametric assumptions. Alaa and Van Der Schaar (2019) independently model each outcome and the
One comparative advantage of deep learning over other machine learning approaches has been the
ability of ANNs to encode and automatically compress informative features from complex data into
lower-dimensional vector representations that make downstream prediction tasks easier (Goodfellow et al., 2016; Bengio, 2013). While other machine learning approaches may
also encode representations, they often require extensive pre-processing to create useful features for
the algorithm (i.e., feature engineering). Through the lens of representation learning, a geometric
interpretation of the role of each layer in a supervised neural network is to transform its inputs (either
raw data or output of previous layers) into a typically lower (but possibly higher) dimensional vector
space. As a means to share statistical power, encoded representations can also be jointly learned for multiple tasks simultaneously (multi-task learning).
The simplest example of a representation might be the final layer in a feed-forward network,
where the early layers of the network can be understood as non-linearly encoding the inputs into an
array of latent linear features for the output neuron (Goodfellow et al., 2016) (Fig. 1B). A famous
example of representation learning is the use of neural networks for face detection. Examining the
representations produced by each layer of these networks shows that each subsequent layer seems to
capture increasingly abstract features of a face (first edges, then noses and eyes, and finally whole
faces) (LeCun et al., 2015). A more familiar example of representation learning to social scientists
might be word vector models like Word2Vec (Mikolov et al., 2013). Word2Vec is a neural network with
one hidden layer and one output layer where words that are semantically similar are closer together in the learned vector space.
The novel contribution of deep learning to causal estimation is the proposal that a neural network
can learn a function Φ that produces representations of the covariates decorrelated from the treatment.
Figure 3: Balancing through representation learning. The promise of deep learning for causal
inference is that a neural network encoding function Φ can transform the treated and control covariate
distributions into a representation space such that they are indistinguishable. Used with permission
from Johansson and Shen (2018).
Fundamentally, the idea is that Φ can transform the treated and control covariate distributions into a
representation space such that they are indistinguishable (Fig. 3). To ensure that these representations
are also still predictive of the outcome (multi-task learning), multiple loss functions are generally
applied simultaneously to balance these objectives. This approach is applied in a majority of the models discussed in this primer.
The papers described in this primer are primarily framed within the Potential Outcomes causal frame-
work (Neyman-Rubin causal model) (Rubin, 1974; Imbens and Rubin, 2015). This framework is con-
cerned with identifying the “potential outcomes” of each unit i in the sample, had it received treatment
(Y (1)) or not received treatment (Y (0)). However, because each unit can only receive one treatment
regime in reality (being treated or remaining untreated), it is not possible to observe both potential
outcomes for each individual (often termed “the fundamental problem of causal inference” (Holland,
1986)). While we cannot thus identify individual treatment effects τi = Yi (1) − Yi (0) for each unit,
causal inference frameworks allow us to probabilistically estimate average treatment effects (ATE)
and average treatment effects conditional on select covariates (CATE) across samples of treated and
control units. Within this literature, the motivation of many papers is to present algorithms that
can both infer CATEs from observational data and predict them for out-of-sample units where
treatment status is unknown. For readers unfamiliar with causal inference, a short introduction is provided in Box 4.
the treatment and outcome. Oftentimes, the confounder is a cause of both the treatment and
outcome. As an example of selection bias, estimating the causal effect of attending college
(treatment) on adult income (outcome) requires controlling for the fact that parental income
may be a common cause of both college attendance and adult income.
The treatment effect for a unit i is defined as τi = Yi(1) − Yi(0), where Y(1) and Y(0) are the potential outcomes had the unit i received or not received the treatment, respectively.
Box 5: Applied Causal Inference Example: The Infant Health and Development
Study
To make this problem setting more concrete for readers unfamiliar with causal inference,
consider simulations based on the 1985-1988 Infant Health and Development Study that are
widely used as benchmarks within this literature. In this experiment, premature children
were randomly assigned to intensive, high-quality childcare (T ), and their cognitive test scores
were measured later (Y ). The authors also measured numerous other covariates X including
“pregnancy complications, child’s birth weight and gestation age, birth order, child’s gender,
household composition, day care arrangements, source of health care, quality of the home
environment, parents’ race and ethnicity, and maternal age, education, IQ, and employment”
(Ramey et al., 1992). The AT E would be the effect of intensive child care on cognitive scores
across all children, while various CAT Es might be formulated to better understand how the
effects of child care vary for female children, children born to teenage mothers, or children with
unemployed parents.
Hill (2011) turns this experimental data into an observational benchmark by re-simulating
the outcome such that the covariates X induce confounding bias between the treatment and
outcome. While the simulations don’t preserve the names of the covariates, we can imagine
some confounding relationships that might be present in an observational study. For example,
suppose that affluent (X1 ) parents are more likely able to afford high-quality child care (T ), but
there is actually a weak association between childcare and premature babies’ cognitive ability
(Y ). We also know affluent parents are more likely to engage in breastfeeding (X2 ), which is
positively associated with higher cognitive ability (Heck et al., 2006; Kramer et al., 2008). If
we do not account for the correlation between income and childcare (X1 → T ), or income and
cognitive ability (X1 → X2 → Y ), we may have bias in our AT E/CAT E estimates, or worse,
erroneously interpret the correlation between childcare and cognitive ability as causal. This
example is depicted in a causal graph below.
The hypothetical confounding bias presented here can be adjusted for either through treat-
ment modeling (e.g., inverse propensity score weighting, non-parametric, deep representation
learning) to block the path X1 → T , outcome modeling (e.g., generalized linear models, deep
regression) to block the path X1 → X2 → Y , or both (see Section 3.2). For coded examples
using many of these approaches in Tensorflow and Pytorch on the IHDP benchmark, please see
the tutorials.
Within the machine learning literature on causal inference treated here, the primary strategy for
identification is selection on observables: identifying causal effects in the
presence of confounding relationships between covariates associated with both the treatment and the
outcome.
The key assumption allowing the identification of causal effects in the presence of confounders is:
1. Conditional Ignorability/Exchangeability. The potential outcomes Y(0), Y(1) and the treat-
ment T are conditionally independent given X,
Y (0), Y (1) ⊥⊥ T |X
Conditional Ignorability specifies that there are no unmeasured confounders that affect both treat-
ment and outcome outside of those in the observed covariates/features X. Additionally X may contain
predictors of the outcome, but should not contain instrumental variables or colliders within the con-
ditioning set.6
2. Consistency/Stable Unit Treatment Value Assumption (SUTVA). This assumption speci-
fies that when a unit receives treatment, their observed outcome is exactly the corresponding potential
outcome (and the same goes for the outcomes under the control condition). Moreover, the response
of any unit does not vary with the treatment assignment to other units (i.e., no network or spillover
effects), and the form/level of treatment is homogeneous and consistent across units (no multiple
versions of treatment),

T = t → Y = Y(t)
3. Overlap. For all x ∈ X (i.e., any observed covariate value), all treatments t ∈ {0, 1} have a
non-zero probability of being observed in the data, within the "strata" defined by such covariates,

0 < P(T = 1 | X = x) < 1
4. Invertibility. An additional assumption often invoked at the interface of identification and estimation using
representation learning is that the representation function is invertible,

Φ⁻¹(Φ(X)) = X
6 A variable is a collider if it is caused by two other variables. Controlling for colliding variables, or descendants of them, can induce spurious associations between treatment and outcome.
In words, there must exist an inverse function of the representation function Φ encoded by a neural
network that can reproduce X from representation space. This is required for the Conditional Ignor-
ability assumption to hold when using representation learning. From a practical perspective, it also
means that the representation we created is rich enough to capture the causal relationships we are
interested in.
For reference, we describe the full notation used within the review in Box 6.
We use uppercase to denote general quantities (e.g., random variables) and lowercase to
denote specific quantities for individual units (e.g., observed variable values).
Causal identification
• Observed covariates/features: X
• Potential outcomes: Y (0) and Y (1)
• Treatment: T
• Unobservable Individual Treatment Effect: τi = Yi(1) − Yi(0)
• Average Treatment Effect: ATE = E[Yi(1) − Yi(0)] = E[τi]
• Conditional Average Treatment Effect: CATE(x) = E[Yi(1) − Yi(0) | Xi = x] = E[τi | Xi = x]
Deep learning estimation
• Predicted potential outcomes: Ŷ (0) and Ŷ (1)
• Outcome modeling functions: Ŷ (T ) = h(X, T )
• Propensity score function: π(X, T ) = P (T |X) (where π(X, 0) = 1 − π(X, 1))
• Representation functions: Φ(X) (producing representations ϕ)
• Loss functions: L(true, predicted)
• Loss abbreviations: M SE (mean squared error), BCE (binary cross-entropy), CCE (categorical
cross-entropy)
• Loss hyperparameters: λ, α, β
• Estimated CATE*: ĈATEi = τ̂i = Ŷi(1) − Ŷi(0) = h(Xi, 1) − h(Xi, 0)
• Estimated ATE: ÂTE = (1/N) Σi τ̂i
Beyond the ATE and CATE, there is an additional metric commonly used in the machine learning
literature, first introduced by Hill (2011), called the Precision in Estimated Heterogeneous
Effects (PEHE). PEHE is the average error across the predicted CATEs.

• Precision in Estimated Heterogeneous Effects: PEHE = (1/N) Σi (τi − τ̂i)²
Beyond being a metric for simulations with known counterfactuals, the PEHE has theoretical
significance in the formulation of generalization bounds within this literature (Shalit et al.,
2017; Johansson et al., 2018, 2020; Zhang et al., 2020).
*Note that we use τ̂ to refer to the estimated CATE because truly individual treatment
effects cannot be described only by the observed covariates X.
3.2 Estimation of Causal Effects
Once a strategy for identifying causal effects from available data has been developed (arguably the
harder and more important part of causal inference), statistical methods can be used to estimate causal
effects by controlling for confounding bias. There are two fundamental approaches to estimation:
treatment modeling to control for correlations between the covariates X and the treatment T, and
outcome modeling to control for correlations between the covariates X and the outcome Y (Fig.
4). Below we briefly review three traditional techniques for removing confounding bias to motivate
our systematization of deep learning models. First, we discuss outcome modeling through regression.
Next, we consider treatment modeling through non-parametric matching. Finally, we discuss treatment
modeling through inverse propensity score weighting (IPW) and introduce the concept of double
robustness.
Assuming the treatment effect is constant across covariates/features or the probability of treatment
is constant across all covariates/features (both improbable assumptions), the simplest consistent ap-
proach to estimating the ATE is to regress the outcome on the treatment indicator and covariates
using a linear model.7 The ATE is then the coefficient of the treatment indicator. Without loss of
generality, we denote the outcome modeling function as h,

Ŷi(T) = h(Xi, T)

A more flexible approach, foundational to the application of machine learning to causal inference, is to use h(X, T) to impute Ŷ(1) and Ŷ(0), and estimate the CATE via plug-in:

ĈATEi = τ̂i = Ŷi(1) − Ŷi(0) = h(Xi, 1) − h(Xi, 0)

7 Another outcome modeling approach that could be used to estimate the outcome, not discussed here, is g-computation.
Figure 4: Two fundamental approaches to deconfounding. Blunted arrows indicate blocked
causal paths. Treatment modeling approaches like inverse propensity weighting, balancing, and repre-
sentation learning adjust for the association between the covariates X and the treatment T . Outcome
modeling approaches like generalized linear models or machine learning regressors adjust for the asso-
ciation between X and the outcome Y .
and the ATE as:

ÂTE = (1/N) Σi τ̂i
A common treatment-modeling strategy is balancing the treated and control covariate distributions
through matching. Matching requires the analyst to select a distance measure that captures the
difference in observed covariate distributions between a treated and untreated unit (Austin, 2011).
Units with treatment status T can then be matched with one or more counterparts with treatment
status 1−T using a variety of algorithms (Stuart, 2010). In a one-to-one matching scenario where each
treated unit has an otherwise identical untreated counterpart, the covariate distributions of treated and control units are identical by construction.
Another common approach is inverse propensity score weighting (IPW). In IPW, units are
weighted on their inverse propensity to receive treatment. Without loss of generality, we call the
propensity function π. The propensity score is calculated as the probability of receiving treatment
conditional on covariates:
π(X, T ) = P (T |X)
ÂTE = (1/N) Σi [ Ti Yi / π̂(Xi, 1) + (1 − Ti) Yi / π̂(Xi, 0) ]    (1)
Note that only one of the two terms is active for any given unit. Furthermore, this presentation
looks different than how the IPW is generally presented because we use π as a function with different
outputs depending on the value of T rather than a scalar (Box 6).8
IPW is attractive because if the propensity score π is specified correctly, it is an unbiased
estimator of the ATE. Moreover, the IPW estimator is consistent if π is estimated consistently (Rosenbaum and Rubin, 1983).
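As an illustrative sketch of equation 1 (the propensity model, data, and variable names are placeholders, not the estimators discussed below):

import numpy as np
from sklearn.linear_model import LogisticRegression

# Placeholder data: covariates X, binary treatment T, outcome Y
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
T = rng.integers(0, 2, size=1000)
Y = X[:, 0] + 2 * T + rng.normal(size=1000)

# Estimate the propensity score pi(X, 1) = P(T = 1 | X) with a simple model
pi_hat = LogisticRegression().fit(X, T).predict_proba(X)[:, 1]

# Equation 1: inverse propensity weighted outcomes; only one term is active per unit
ate_ipw = np.mean(T * Y / pi_hat + (1 - T) * Y / (1 - pi_hat))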
Because different models make different assumptions, it is not uncommon to combine outcome modeling
with propensity modeling or matching estimators to create doubly-robust estimators. For example,
one of the most widely used doubly-robust estimators is the Augmented-IPW (AIPW) estimator.
ÂTE = (1/N) Σi { [ Ti/π(Φ(Xi), 1) − (1 − Ti)/π(Φ(Xi), 0) ] × [ Yi − h(Φ(Xi), Ti) ] + [ h(Φ(Xi), 1) − h(Φ(Xi), 0) ] }    (2)
The last term is the difference in predictions from the outcome model for the treated and control
conditions, while the first term is a "corrected" IPW estimator in which the raw outcome is replaced by the
residual from the outcome model. As expected, this estimator is unbiased if the IPW and regression
estimators are consistently estimated. However, the model is attractive because it will be consistent
if either the propensity score π(X, T) is correctly specified or the regression model h is consistently
specified (Glynn and Quinn, 2010). The model also provides efficiency gains with respect to using either approach alone.
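A rough sketch of equation 2 in code (the outcome and propensity models here are simple stand-ins, not the neural networks discussed below):

import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

# Placeholder data
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
T = rng.integers(0, 2, size=1000)
Y = X[:, 0] + 2 * T + rng.normal(size=1000)

# Nuisance models: outcome regression h(X, T) and propensity score pi(X, 1)
h = LinearRegression().fit(np.column_stack([X, T]), Y)
pi_hat = LogisticRegression().fit(X, T).predict_proba(X)[:, 1]

h_obs = h.predict(np.column_stack([X, T]))              # h(X, T): prediction for observed outcome
h1 = h.predict(np.column_stack([X, np.ones_like(T)]))   # h(X, 1)
h0 = h.predict(np.column_stack([X, np.zeros_like(T)]))  # h(X, 0)

# Equation 2: outcome modeling term plus IPW-weighted residual correction
ate_aipw = np.mean((T / pi_hat - (1 - T) / (1 - pi_hat)) * (Y - h_obs) + (h1 - h0))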
Doubly robust estimation is especially important for causal estimation using machine learning.
When using simple outcome plug-in estimators, bias is directly dependent on estimation error, which
may be different for each potential outcome depending on the modeling strategy (Kennedy, 2020).
Machine learning estimation of the propensity score can also rely heavily on non-confounding predic-
tors, giving rise to extreme weights (Schnitzer et al., 2016). More generally, there are no asymptotic
8 To de-emphasize the contribution of units with extreme weights due to sparse data, sometimes a "stabilized" IPW estimator is used instead.
linearity guarantees for machine learning estimators which may converge at a slow rate, leading to
misleading confidence intervals (Naimi and Balzer, 2018; Zivich and Breskin, 2021). For these reasons,
plug-in machine learning estimation often has poor empirical performance when not using double
robust estimators (Benkeser et al., 2017; Bodnar et al., 2022; Kennedy, 2020; Zivich and Breskin,
2021).
The growth of the machine learning for causal inference literature has thus been largely driven by
using machine learning only to estimate the nuisance parameters (i.e., potential outcomes and
propensity score) of influence functions for causal parameters like the ATE and CATE (Chernozhukov
et al., 2018; Kennedy, 2016; Van der Laan and Rose, 2011). In these approaches, the estimation of
causal parameters is only second-order dependent on machine learning error, there is double-robustness
against inconsistent estimation, and guarantees of fast convergence and asymptotically-valid confidence
intervals even if the machine learning models converge slowly (Benkeser et al., 2017; Kennedy, 2020;
Naimi and Balzer, 2018; Zivich and Breskin, 2021). We use the final algorithm introduced below, Dragonnet, to illustrate this approach.
The architectures proposed in the deep learning literature for causal estimation build upon the core ideas
discussed above. First, we introduce "S-Learners" and "T-Learners" to show how neural networks can
be used to estimate non-linearities in potential outcomes. Second, given the right objectives, a neural
network can learn representations of the treated and control distributions that are deconfounded (Fig.
3). This approach, which can be related theoretically to non-parametric matching, is illustrated by the
foundational TARNet algorithm in section 4.2 (Shalit et al., 2017). Finally, the machine learning for
causal inference literature has been largely driven by the introduction of semi-parametric frameworks
that allow predictive machine learning models to be plugged-in to doubly robust estimation equations
(Van der Laan and Rose, 2011; Chernozhukov et al., 2018, 2021). In section 4.3, we introduce the
concept of influence functions and the targeted maximum likelihood estimator to explain the Dragonnet
algorithm. For clarity, the algorithms presented here all share a familial resemblance to the TARNet
algorithm. However, we note that there are many other approaches to using deep learning for causal inference.
Because at most one potential outcome is observed for each unit, it is not possible to apply supervised models
to directly learn treatment effects. Across econometrics, biostatistics, and machine learning, a common
approach to this challenge has been to instead use machine learning to model each potential outcome
separately and use plug-in estimators for treatment effects (Chernozhukov et al., 2018; Van der Laan
and Rose, 2011; Wager and Athey, 2018). As with linear models, a single neural model can be trained
to learn both potential outcomes (“S[ingle]-learner”) (Fig. 1B), or two independent models can be
trained to learn each potential outcome (a “T-learner”) (Johansson et al., 2020) (Fig. 5A). In both
cases, the neural network estimators would be feed-forward networks tasked with minimizing the MSE
in the prediction of observed outcomes. In a slight abuse of notation, the joint loss function for a T-learner can be written as:

L(Y, h(X, T)) = MSE( Ti (Yi, h1(Xi, 1)) + (1 − Ti)(Yi, h0(Xi, 0)) )    (3)
After training, inputting the same unit into both networks of a T-learner will produce predictions
for both potential outcomes: Ŷ(T) and Ŷ(1 − T). We can plug in these predictions to estimate the CATE,

τ̂i = (2Ti − 1) × ( Ŷi(Ti) − Ŷi(1 − Ti) )

where the first term is a switch to make sure the treated potential outcome comes first.
Figure 5: A. T-learner. In a T-learner, separate feed-forward networks are used to model each
outcome. We denote the function encoded by these outcome modelers h. B. TARNet. TARNet
extends the T-learner with shared representation layers. The motivation behind TARNet (and further
elaborations of this model) is that the multi-task objective of accurately predicting both the treated
and control potential outcomes forces the representation layers to learn a balancing function Φ such
that the Φ(X|T = 0) and Φ(X|T = 1) are overlapping distributions in representation space. For a code
implementation, see Box 7. C. Dragonnet. Dragonnet adds a propensity score head to TARNet
and a free "nudge" parameter ϵ. In an adaptation of Targeted Maximum Likelihood Estimation, π̂
and ϵ are used to re-weight the outcomes to provide less biased estimates of the ATE.
The average treatment effect is then estimated as

ÂTE = (1/N) Σi τ̂i
Nearly all of the models described below combine this plug-in outcome modeling approach with treatment modeling strategies such as balancing or inverse propensity weighting.
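As a sketch of this plug-in logic (the networks below are untrained placeholders standing in for any fitted T-learner; in practice h1 would be fit on treated units and h0 on control units):

import numpy as np
import tensorflow as tf

# Placeholder T-learner heads (untrained, for illustration only)
def make_net():
    return tf.keras.Sequential([tf.keras.layers.Dense(100, activation='elu'),
                                tf.keras.layers.Dense(1)])
h1, h0 = make_net(), make_net()

X = np.random.randn(500, 25).astype('float32')

# Plug-in estimation: predict both potential outcomes for every unit
y1_hat = h1.predict(X).ravel()    # Ŷ(1) = h1(X, 1)
y0_hat = h0.predict(X).ravel()    # Ŷ(0) = h0(X, 0)

tau_hat = y1_hat - y0_hat         # ĈATEi = Ŷi(1) − Ŷi(0)
ate_hat = tau_hat.mean()          # ÂTE = (1/N) Σi τ̂i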
TARNet (Tutorial 1 )
Balancing is a treatment adjustment strategy that aims to deconfound the treatment from outcome
by forcing the treated and control covariate distributions closer together (Johansson et al., 2016). The
novel contribution of deep learning to the selection on observables literature is the proposal that a
neural network can transform the covariates into a representation space Φ such that the treated and control covariate distributions are indistinguishable (Fig. 3).
To encourage a neural network to learn balanced representations, the seminal paper in this liter-
ature, Shalit et al. (2017), proposes a simple two-headed neural network called Treatment Agnostic
Regression Network (TARNet) that extends the outcome modeling T-learner with shared representa-
tion layers (Fig. 5B). Each head models a separate potential outcome: one head learns the function
Ŷ (1) = h1 (Φ(X), 1), and the other head learns the function Ŷ (0) = h0 (Φ(X), 0). During training,
only one head will receive error gradients at a time (the one predicting the observed outcome). How-
ever, both heads backpropagate their gradients to shared representation layers that learn Φ(X). The
idea is that these representation layers must learn to balance the data because they are tasked with
predicting both outcomes. The authors of this algorithm have subsequently extended TARNet with
additional losses in an algorithm called CFRNET that explicitly encourage balancing by minimizing
a statistical distance between the two covariate distributions in representation space (see Appendix A).
The complete objective for the network is to fit the parameters of h and Φ for all n units in the
training sample such that,

arg min over h, Φ of (1/N) Σi ( Yi − [ Ti · h1(Φ(Xi), 1) + (1 − Ti) · h0(Φ(Xi), 0) ] )² + λ R(h)    (4)
or more compactly,

L = MSE(Y, h(Φ(X), T)) + λ R(h)

where R(h) is a model complexity term (e.g., for L2 regularization) and λ is a hyperparameter chosen
through model selection. For coded versions of TARNet in Tensorflow and Pytorch, see Box 7.
Box 7: TARNet in Code

Tensorflow 2

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras import Model

#Function name is illustrative; head sizes mirror the Pytorch version below
def make_tarnet(input_dim):
    #In TF fxnl API, stack layers by feeding output of prev layer to next
    x = Input(shape=(input_dim,))
    #Make 2 representation layers
    #units is the output dim of layer
    #elu is "exponential linear unit" activation fxn
    phi = Dense(units=200, activation='elu')(x)
    phi = Dense(units=200, activation='elu')(phi)
    #Hidden layers for each potential outcome head
    y0_hidden = Dense(units=100, activation='elu')(phi)
    y0_hidden = Dense(units=100, activation='elu')(y0_hidden)
    y1_hidden = Dense(units=100, activation='elu')(phi)
    y1_hidden = Dense(units=100, activation='elu')(y1_hidden)
    # Output predictions
    y0_pred = Dense(units=1, activation=None)(y0_hidden)
    y1_pred = Dense(units=1, activation=None)(y1_hidden)
    #Bundle outputs
    concat_pred = Concatenate(1)([y0_pred, y1_pred])
    #instantiate model
    model = Model(inputs=x, outputs=concat_pred)
    return model
Pytorch

import torch
import torch.nn as nn

class TARNet(nn.Module):
    def __init__(self, input_dim):
        super(TARNet, self).__init__()
        self.phi = nn.Sequential(
            #both input and output dims are specified in torch
            nn.Linear(input_dim, 200),
            nn.ELU(), #activations are discrete from layers
            nn.Linear(200, 200),
            nn.ELU())
        self.y0_hidden = nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU())
        self.y1_hidden = nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU())
        self.y0_pred = nn.Linear(100, 1)
        self.y1_pred = nn.Linear(100, 1)

    def forward(self, x):
        #shared representation Phi(X)
        rep = self.phi(x)
        #separate head for each potential outcome
        y0_rep = self.y0_hidden(rep)
        y0_hat = self.y0_pred(y0_rep)
        y1_rep = self.y1_hidden(rep)
        y1_hat = self.y1_pred(y1_rep)
        return torch.cat([y0_hat, y1_hat], dim=1)
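To connect Box 7 back to equation 4, here is a sketch of how the factual loss might be computed from the concatenated outputs of the Pytorch class above (an illustration only, omitting the λR(h) regularization term; it is not the tutorials' exact code):

import torch

def tarnet_loss(concat_pred, y_true, t_true):
    #Factual MSE: each unit only contributes error through the head
    #that predicts its observed outcome
    y0_hat, y1_hat = concat_pred[:, 0], concat_pred[:, 1]
    y_hat = t_true * y1_hat + (1 - t_true) * y0_hat
    return torch.mean((y_true - y_hat) ** 2)

# Example usage with the TARNet module defined above
model = TARNet(input_dim=25)
x = torch.randn(64, 25); y = torch.randn(64); t = torch.randint(0, 2, (64,)).float()
loss = tarnet_loss(model(x), y, t)
loss.backward()  # gradients flow to the active head and the shared representation layers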
Rather than applying losses directly to the representation function, IPW methods estimate propensity
scores from representations using the function π(Φ(X), T ) = P (T |Φ(X)). As in traditional IPW
estimators, these methods exploit the sufficiency of correctly-specified propensity scores to reweight
the plugged-in outcome predictions and provide unbiased estimates of the ATE (Rosenbaum and
Rubin, 1983). Because these models combine outcome modeling with IPW, they retain the attractive
statistical properties of doubly robust estimators discussed in section 3.2.2 (Atan et al., 2018). In this
section we focus on Shi et al. (2019)’s Dragonnet model, which adapts semi-parametric estimation
theory for batch-wise neural network training in a procedure they call “Targeted Regularization”
(TarReg) (Kennedy, 2016). Given the increasing importance of semi-parametric theory and “double
machine learning” across the causal estimation literature, we include a brief introduction to semi-
parametric theory and targeted maximum likelihood estimation (TMLE) before diving into the details
of the Dragonnet algorithm (Van der Laan and Rose, 2011; Chernozhukov et al., 2018).
Dragonnet (Tutorial 3 )
A trivial extension to TARNet is to add a third head to predict the propensity score. This third
head could use multiple neural network layers or just a single neuron, as proposed in Dragonnet (Fig.
5C) (Shi et al., 2019). Dragonnet uses this additional head to develop a training procedure called Targeted Regularization.
With three heads, the basic loss function for this network looks like:

L = MSE(Y, h(Φ(X), T)) + α · BCE(T, π(Φ(X), T))

with α being a hyperparameter to balance the two objectives. The mean squared error and binary cross-
entropy are standard objective functions in machine learning for regression and binary classification,
respectively. Note that the first term is simply a compact form of the outcome modeling objective in equation 4.
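A sketch of this three-headed loss in PyTorch (assuming a model that returns ŷ0, ŷ1, and a propensity logit; the function and argument names are illustrative):

import torch
import torch.nn.functional as F

def dragonnet_basic_loss(y0_hat, y1_hat, t_logit, y_true, t_true, alpha=1.0):
    #Outcome MSE on the factual head plus a weighted propensity loss
    y_hat = t_true * y1_hat + (1 - t_true) * y0_hat                         # h(Phi(X), T)
    outcome_loss = torch.mean((y_true - y_hat) ** 2)                        # MSE(Y, h(Phi(X), T))
    propensity_loss = F.binary_cross_entropy_with_logits(t_logit, t_true)   # BCE(T, pi(Phi(X), T))
    return outcome_loss + alpha * propensity_loss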
Below, we explore how the authors add a second loss on top of this one to allow for semi-parametric
estimation.
In recent years, semi-parametric theory has emerged as a dominant theoretical framework for applying
machine learning algorithms, including neural networks, to causal estimation (Chernozhukov et al.,
2018, 2021, 2022; Farrell et al., 2021; Kennedy, 2016; Nie and Wager, 2021; Van der Laan and Rose,
2011; Wager and Athey, 2018). The great appeal of these frameworks is that they allow for ma-
chine learning algorithms to be plugged-in for non-linear estimates of outcomes and propensity score,
while still providing attractive statistical guarantees (e.g., consistency, efficiency, asymptotically-valid
confidence intervals).
At a very intuitive level, semi-parametric causal estimation is focused on estimating a target pa-
rameter of a distribution P (the AT E) of treatment effects T (P ) (Fisher and Kennedy, 2021). While
we do not know the true distribution of treatment effects because we lack counterfactuals, we do know
some parameters of this distribution (e.g., the treatment assignment mechanism). We can encode
these constraints in the form of a likelihood that parametrically defines a set of possible approximate
distributions 𝒫 from our existing data P. Within this set there is a sample-inferred distribution P̃ ∈ 𝒫, which we use to approximate P.
Regardless of the P̃ chosen, if P̃ ≠ P then T(P̃) ≠ T(P). We do not know how to pick P̃ with finite data
to get the best estimate T (P̃ ). We can maximize a likelihood function to pick P̃ , but there may be
“nuisance” parameters in the likelihood that are not the target and we do not care about estimating
accurately. Maximum likelihood optimization may provide lower-biased estimates of these nuisance parameters at the expense of the target parameter T(P).
To sharpen the likelihood’s focus on T (P ), we define a “nudge” parameter ϵ that moves P̃ closer
to P (thus moving T (P̃ ) closer to T (P )). An influence curve of T (P ) tells us how changes in ϵ will
induce changes in T (P + ϵ(P̃ − P )). We’ll use this influence curve to fit ϵ to get a better approximation
of T (P ) within the likelihood framework. In particular, there is a specific efficient influence curve
(EIC) that provides us with the lowest variance estimates of T (P ). In causal estimation, solving
the EIC for the ATE yields estimates that are asymptotically unbiased, efficient, and have confidence intervals with valid coverage:

EIC_ATE = (1/N) Σi { [ Ti/π(Xi, 1) − (1 − Ti)/π(Xi, 0) ] × ( Yi − h(Xi, Ti) ) + [ h(Xi, 1) − h(Xi, 0) ] } − ATE    (7)
Setting EIC_ATE to its mean of 0 and solving for the ATE,

ATE = (1/N) Σi { [ Ti/π(Xi, 1) − (1 − Ti)/π(Xi, 0) ] × ( Yi − h(Xi, Ti) ) + [ h(Xi, 1) − h(Xi, 0) ] }    (8)
The structure of equation 8 shows how EIC_ATE resembles the doubly robust estimator in equation 2. When the EIC is
minimized (set to 0) as in equation 8, the ATE is equal to the outcome modeling estimate plus a treatment modeling adjustment for residual confounding.
Targeted Regularization (TarReg) is closely modeled after "Targeted Maximum Likelihood Estima-
tion" (TMLE) (Van der Laan and Rose, 2011). TMLE is an iterative procedure where a free "nudge"
parameter ϵ is used to nudge the outcome models towards sharper estimates of the ATE when minimizing the likelihood. The procedure is as follows:
1. Fit h by predicting outcomes (e.g., using TARNet) and minimizing MSE(Y, h(Φ(X), T))
2. Fit π by predicting treatment (e.g., using logistic regression) and minimizing BCE(T, π(Φ(X), T))
3. Plug in the h and π functions to fit ϵ and estimate h*(X, T), where

h*(Xi, Ti) = h(Φ(Xi), Ti) + [ Ti/π(Φ(Xi), 1) − (1 − Ti)/π(Φ(Xi), 0) ] × ϵ

by minimizing MSE(Y, h*(Φ(X), T)). This is equivalent to minimizing the "Adjustment" part
in equation 8.
4. Plug in h*(X, T) to estimate the ÂTE:

ÂTE_TMLE = (1/N) Σi [ h*(Xi, 1) − h*(Xi, 0) ]
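A minimal numpy sketch of these four steps with simple stand-in nuisance models (everything here is illustrative; the closed-form fit for ϵ follows from minimizing the squared-error loss in step 3):

import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5)); T = rng.integers(0, 2, size=1000)
Y = X[:, 0] + 2 * T + rng.normal(size=1000)

# Steps 1 and 2: fit outcome model h and propensity model pi (stand-ins for neural networks)
h = LinearRegression().fit(np.column_stack([X, T]), Y)
pi = LogisticRegression().fit(X, T).predict_proba(X)[:, 1]

h_obs = h.predict(np.column_stack([X, T]))
h1 = h.predict(np.column_stack([X, np.ones(len(T))]))
h0 = h.predict(np.column_stack([X, np.zeros(len(T))]))

# Step 3: the weight multiplying epsilon, and the least-squares fit of epsilon
H = T / pi - (1 - T) / (1 - pi)
eps = np.sum(H * (Y - h_obs)) / np.sum(H ** 2)   # minimizes MSE(Y, h + H * eps)

# Step 4: plug in the nudged outcome models h*(X, t) = h(X, t) + weight(t) * eps
h1_star = h1 + (1 / pi) * eps
h0_star = h0 - (1 / (1 - pi)) * eps
ate_tmle = np.mean(h1_star - h0_star)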
Targeted Regularization takes TMLE and adapts it for a neural network loss function. The main
difference is that steps 1 and 2 above are done concurrently by Dragonnet, and that the loss functions for the first three steps are combined into a single loss applied to the whole network at the end of each batch. It requires adding a single free parameter ϵ to the Dragonnet network.9

9 For a deeper dive on targeted learning, we recommend Benkeser and Chambaz (????).
At a very intuitive level, Targeted Regularization is appealing because it introduces a loss function
to TARNet that explicitly encourages the network to learn the mean of the treatment effect distri-
bution, and not just the outcome distribution. The Targeted Regularization procedure proceeds as
follows:
In each epoch, the nudged prediction

\underbrace{h^*(\Phi(X_i), T_i)}_{Y^*} = \underbrace{h(\Phi(X_i), T_i)}_{\hat{Y}_i} + \underbrace{\left(\frac{T_i}{\pi(\Phi(X_i),1)} - \frac{1-T_i}{\pi(\Phi(X_i),0)}\right)}_{\text{Treatment Modeling Adjustment}} \times \underbrace{\epsilon}_{\text{“nudge”}}

is computed, and MSE(Y, h^∗(Φ(X), T )) is added to Dragonnet’s outcome and treatment losses.
At the end of training, we can thus calculate the targeted regularization estimate of the ATE, \widehat{ATE}_{TR}, as in TMLE:

\widehat{ATE}_{TR} = \frac{1}{N}\sum_{i=1}^{N} \underbrace{h^*(\Phi(X_i), 1)}_{Y_i^*(1)} - \underbrace{h^*(\Phi(X_i), 0)}_{Y_i^*(0)}
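Concretely, targeted regularization can be implemented as one extra loss term and one extra trainable parameter. The PyTorch sketch below assumes a Dragonnet-style model has already produced y0_hat, y1_hat, and propensity predictions pi1 for a batch; the module and variable names are hypothetical, and the weight beta on the extra loss is a hyperparameter.

```python
import torch
import torch.nn as nn

class TargetedRegularization(nn.Module):
    """The single free 'nudge' parameter epsilon and its loss (a sketch)."""

    def __init__(self):
        super().__init__()
        self.eps = nn.Parameter(torch.zeros(1))

    def forward(self, y, t, y0_hat, y1_hat, pi1):
        pi1 = pi1.clamp(0.01, 0.99)                  # keep the weights stable
        y_hat = torch.where(t == 1, y1_hat, y0_hat)  # h(Phi(X_i), T_i)
        H = t / pi1 - (1 - t) / (1 - pi1)            # treatment modeling adjustment
        y_star = y_hat + H * self.eps                # nudged prediction h*
        return ((y - y_star) ** 2).mean()

# Hypothetical use inside a training loop:
# loss = outcome_loss + treatment_loss + beta * targ_reg(y, t, y0_hat, y1_hat, pi1)
```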
Compared to S-learners, T-learners, and TARNet, the Dragonnet algorithm is particularly attractive because of the statistical guarantees afforded by its semiparametric framework. It is doubly robust, unbiased, converges at a rate of 1/√n, and its sampling distribution is asymptotically normal. Below, we discuss how to build confidence intervals and interpret estimates from these models in practice.
5 Confidence and Interpretation
In this section, we move from theory to practice, and treat best practices for building confidence
intervals and interpreting heterogeneous treatment effects. Both of these topics are active areas of
development, not only within the causal inference literature, but across machine learning research.
5.1 Confidence Intervals

(Tutorial 4 )

In this paper, we feature Dragonnet over other approaches because of its attractive statistical guarantees. As in TMLE, an asymptotically valid standard error can be calculated as the sample-corrected variance of the EIC:
\sigma_{\widehat{ATE}_{TR}} = \sqrt{\frac{Var(EIC_{\widehat{ATE}_{TR}})}{N}} \qquad (9)

and,

Var(EIC_{\widehat{ATE}_{TR}}) = Var\left[\left(\frac{T}{\pi(X,1)} - \frac{1-T}{\pi(X,0)}\right)\big(Y - h^*(X,T)\big) + \big(h^*(X,1) - h^*(X,0)\big) - \widehat{ATE}_{TR}\right] \qquad (10)
In Tutorial 5, we show how \sigma_{\widehat{ATE}_{TR}} can be used to calculate a Wald confidence interval for Dragonnet.
While not featured in this review, asymptotically valid confidence intervals can also be calculated using RieszNet, a variant of Dragonnet introduced in Chernozhukov et al. (2022) that connects neural network estimation to the automatically debiased machine learning literature currently popular in econometrics.
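The following numpy sketch shows how equations 9 and 10 translate into a Wald interval, assuming nudged predictions h*(X, 0), h*(X, 1) and propensity scores have already been extracted from a trained network; names and arguments are hypothetical.

```python
import numpy as np
from scipy import stats

def wald_ci(y, t, h0_star, h1_star, pi1, alpha=0.05):
    """Wald confidence interval for the ATE from the EIC variance (a sketch)."""
    pi0 = 1.0 - pi1
    h_obs = np.where(t == 1, h1_star, h0_star)
    ate = np.mean(h1_star - h0_star)
    # Estimated efficient influence curve for each unit (equation 10)
    eic = (t / pi1 - (1 - t) / pi0) * (y - h_obs) + (h1_star - h0_star) - ate
    se = np.sqrt(np.var(eic, ddof=1) / len(y))        # equation 9
    z = stats.norm.ppf(1 - alpha / 2)
    return ate - z * se, ate + z * se
```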
5.2 Interpretation
(Tutorial 4 )
A lack of interpretability has been a barrier to the adoption of machine learning methods like
neural networks and random forests in social science settings. However, the literature on post-hoc
interpretability techniques has matured considerably over the past five years, and several techniques
for identifying important features/covariates such as permutation importance, LIME scores, SHAP
scores, and Individual Conditional Expectation plots are in widespread use today (Altmann et al., 2010; Goldstein et al., 2015; Lundberg and Lee, 2017; Ribeiro et al., 2016). For a broad and accessible introduction to this literature, see Molnar (2022).
Building on criteria used to evaluate other explainable AI methods, Crabbé et al. (2022) note four
desirable properties of a feature importance technique for the interpretation of deep causal estima-
tors: sensitivity, completeness, linearity, and implementation invariance (Sundararajan et al., 2017). A
method that is ’sensitive’ can distinguish between features that are simply predictive of the outcome,
and those that actually influence CATE heterogeneity. A method that is ’complete’ identifies all fea-
tures that, together, explain all effect heterogeneity compared to a baseline. A ’linear’ method is one
where the feature importance scores additively describe the prediction. Lastly, the approach should be
agnostic to both the model architecture (e.g., TARNet, Dragonnet) and different architectural hyperparameters (implementation invariance). Among existing methods, they identify two that manifest all four of these qualities: SHAP scores and integrated gradients.
SHAP (SHapley Additive exPlanations) scores have emerged as one of the most popular methods
for evaluating machine learning models in recent years (Lundberg and Lee, 2017). SHAP is what is
called a “local” interpretability method: it provides feature importance estimates for each individual unit. SHAP frames prediction as a cooperative game between covariates to predict a specific outcome. Under the hood, the algorithm exhaustively
compares all possible “coalitions” of covariates and their ability to predict the outcome (win the game).
Predictions from this powerset of coalitions are used to calculate the additive marginal contributions
of each feature in prediction using Shapley values. The disadvantage of SHAP is that, even with
computational tricks, calculating scores for every unit can become computationally intractable in high
dimensional datasets. SHAP scores are interpreted in comparison to a causal baseline of the ATE.
Because of the computational expense of SHAP scores, Crabbé et al. (2022) also recommend another approach: integrated gradients (Sundararajan et al., 2017). To score a feature, this algorithm draws a straight-line, linear path in feature space between the target input (individual
unit) and a baseline (i.e., a hypothetical unit who is exactly average on all covariates). A feature
importance score can then be constructed by calculating the gradient in prediction error along this
path with respect to the feature of interest. Note that SHAP scores can also be understood theoretically
within the path framework. From this perspective, coalitions are paths in which each feature is turned
on sequentially, and the SHAP score is the expectation across these paths. This interpretation leads
to a gradient-based algorithm for calculating SHAP scores specifically for neural networks, which is
also in the SHAP package. In practice, we recommend that analysts experiment with both integrated gradients and SHAP scores.
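As an illustration, the sketch below approximates integrated gradients for a CATE function with a simple Riemann sum over the straight-line path. In practice one might prefer a maintained library such as the SHAP package or Captum, but the core computation looks like this; the callable cate_fn is a hypothetical function returning h(Φ(x), 1) − h(Φ(x), 0) for a batch of covariates.

```python
import torch

def integrated_gradients(cate_fn, x, baseline, steps=50):
    """Approximate integrated gradients for one unit's CATE (a sketch).

    cate_fn  : callable mapping (batch, d) covariates to CATE predictions
    x        : covariates of the target unit, shape (1, d)
    baseline : reference unit, e.g. the covariate means, shape (1, d)
    """
    alphas = torch.linspace(0.0, 1.0, steps).view(-1, 1)
    # Points along the straight-line path from the baseline to x
    path = (baseline + alphas * (x - baseline)).detach().requires_grad_(True)
    grads = torch.autograd.grad(cate_fn(path).sum(), path)[0]
    # Average gradient along the path, scaled by the displacement
    return (x - baseline) * grads.mean(dim=0)
```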
To move from theory to empirics, the online tutorials show how to implement many of the ideas
presented throughout this primer. The tutorials are hosted in notebooks in the Google Colaboratory
environment. When users open a Colab notebook, Google immediately provides a free virtual machine
with standard Python machine learning packages available. This means that readers need not install
anything on their own computers to experiment with these models. The tutorials are written in the
Python programming language and provide examples in both Tensorflow 2 and PyTorch, the two most popular deep learning frameworks. We note that both Tensorflow 2 and PyTorch have implementations
in R. However, we strongly recommend that readers interested in getting into deep learning work in
Python, which has a much richer ecosystem of third-party packages for machine learning.
• Tutorial 1 introduces S-learners and T-learners before TARNet as a way to get familiar with building and training deep estimators.
• Tutorial 2 focuses on causal inference metrics and hyperparameter optimization. Be-
cause we do not observe counterfactual outcomes, it’s not obvious how to optimize supervised
learning models for causal inference. This tutorial introduces some metrics for evaluating model
performance. In the first part, you learn how to assess performance on these metrics in Tensor-
board. In the second part, we hack Keras Tuner to do hyperparameter optimization for TARNet,
and discuss considerations for training models as estimators rather than predictors.
• Tutorial 3 builds Dragonnet following Shi et al. (2019). We add treatment modeling to our TARNet model, and build an augmented inverse
propensity score estimator. We then briefly describe the algorithm for Targeted Maximum Like-
lihood Estimation to introduce and build a Dragonnet with Shi et al.’s Targeted Regularization.
• Tutorial 4 covers confidence and interpretation: we construct asymptotically valid confidence intervals for the average treatment effect. We also interpret the features contributing to different heterogeneous CATEs using Integrated Gradients and SHAP scores. This is a good tutorial if you just want to learn how to interpret SHAP scores, independent of causal estimation.
• Tutorial 5 implements the weighted CFRNet of Shalit et al. (2017); Johansson et al. (2018, 2020) (Appendix A.2). This
approach relies on integral probability metrics to bound the counterfactual prediction loss and
force the treated and control distributions closer together. The weighted variant adds adaptive
propensity-based weights that provide a consistency guarantee, relax overlap assumptions, and reduce noise in the batch-level IPM calculation.
6 Beyond Traditional Data: Text, Networks, Images, and Treatment over Time

As exciting as neural networks are for heterogeneous treatment effect estimation from quantitative
data, a great promise of deep causal estimation is inference when treatments, confounders, and mediators are encoded in high-dimensional data (e.g., text, images, social networks, speech, and video) or are
time-varying. This is a strong advantage of neural networks over other machine learning approaches,
which do not generalize competitively to non-quantitative data. In these scenarios, multi-task ob-
jectives and tailored architectures can be used to learn representations that are simultaneously rich,
capture information about causal quantities, and disentangle their relationships. Moreover, the inher-
ent flexibility of neural networks means that, in many cases, the TARNet-style models presented above
can serve as the foundations for inference on text and graphs with some architectural modifications.
This literature is rapidly evolving, so readers should treat this section of the primer as funda-
mentally prospective. To maintain accessibility, our primary goal here is to introduce readers to
hypothetical scenarios where they might perform causal inference on text, network, or image data.
We then briefly survey emerging approaches to estimation in these settings. The identification assumptions for different data types differ substantially, so we
generally leave those to the interested reader. Finally, we briefly discuss approaches for dealing with
time-varying confounding. We also take this section as an opportunity to introduce the Transformer and the Graph Neural Network (Box 8), architectures now used in most contemporary deep learning models to learn representations of high-dimensional data.
6.1 Causal Inference from Text

In recent years, an interdisciplinary community across both social science and computer science has coalesced around causal inference from text (see Keith et al. (2020) and Feder et al. (2021) for exhaustive
reviews). Broadly speaking, texts may capture information about any causal quantity (treatments,
outcomes, confounders, mediators) we might be interested in. For example, in an exit-polling exper-
iment, analysts might want to measure toxicity (Y ) in text responses to political prompts. In an
observational study of e-mail response times (Y ), analysts might want to measure the effects of the
tone of the email (T ). In this scenario, the analyst might also want to control for confounders like sub-
ject matter (X). Each of these scenarios presents distinct identification challenges (Feder et al., 2021).
But in all cases, we can use low-dimensional representations of the high dimensional text to extract,
quantify, and disentangle relationships between nuanced qualities like tone and subject matter.
The ability of neural networks to automatically extract features makes them particularly suited for
the last scenario when both treatment information and confounding covariates are encoded in text.
In many cases, we may not have explicitly identified, quantified, or labeled all of the confounders in
text (e.g., subject matter and tone of emails), but we would still like to control for them. Pryzant
et al. (2021), Veitch et al. (2020), and Gui and Veitch (2022) address this problem by prepending
Transformer-layers (Box 8) for reading text to the beginning of TARNet or Dragonnet. Veitch et al.
(2020) demonstrate the viability of this approach on a Science of Science question testing the causal
effect of equations on getting papers accepted to computer science conferences. Pryzant et al. (2021);
Gui and Veitch (2022) explore the more complicated scenario in which the treatment is not explicitly known (e.g., equations in papers, gender of authors), but is instead externally perceived upon reading
(e.g., politeness/rudeness of an email or toxicity of a social media post). In these models, an additional
loss function is also added for learning text representations concurrently with the causal inference losses
discussed above.
Box 8: Graph Neural Networks and Transformers

Graph neural networks (GNNs) are the current state-of-the-art approach for creating
representations for nodes in graphs. Compared to previous approaches that relied on “shallow”
embeddings based only on a node’s local context (e.g., random walks to nearby nodes), GNNs
are attractive because their node representations are aggregated from the structural position
and covariates of all nodes n degrees away from the target node, where n is the number of
graph neural network layers.
The most intuitive understanding of how graph neural networks work is as a message passing
system (Gilmer et al., 2017). We use one of the first GNN papers, the Graph Convolutional
Network as an example (Kipf and Welling, 2017). In this interpretation, each node has a
message that it passes to its neighbors through a graph convolution operation. In the first layer
of a GNN this message would consist of the node’s covariates/features. In consecutive layers of
the network, these messages are actually representations of the node produced by the previous
layer. During graph convolution, each node multiplies incoming messages by its own set of
weights and combines these weighted inputs using an aggregation function (e.g., summation).
By the n-th GNN layer, these messages will contain structure and covariate information from
all nodes n degrees away. For interested readers, there is also a spectral interpretation of this
process. Typically, GNNs are trained to produce node representations by predicting the probability that two nodes are linked in the network; the learned representations are then reused for downstream tasks. One
variant of the GNN uses an “attention” mechanism to vary the extent that nodes value messages
from different neighbors (the graph attention network or GAT) (Veličković et al., 2018).
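The message-passing step itself is only a few lines of linear algebra. The numpy sketch below implements a single graph convolution with the symmetric normalization of Kipf and Welling (2017); the adjacency matrix A, features H, and weight matrix W are assumed to be small, dense arrays for illustration.

```python
import numpy as np

def graph_convolution(A, H, W):
    """One Kipf-and-Welling-style graph convolution layer (a sketch).

    A : (n, n) binary adjacency matrix
    H : (n, d_in) node covariates or representations from the previous layer
    W : (d_in, d_out) learnable weights
    """
    A_hat = A + np.eye(A.shape[0])                    # add self-loops
    d_inv_sqrt = np.diag(1.0 / np.sqrt(A_hat.sum(axis=1)))
    messages = d_inv_sqrt @ A_hat @ d_inv_sqrt @ H    # aggregate neighbors' messages
    return np.maximum(0.0, messages @ W)              # weight and apply ReLU
```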
As of 2023, Transformers are the hegemonic architecture used in natural language process-
ing. After their introduction in 2017, these models improved performance on many high-profile
NLP tasks across the board. Several enterprise-scale transformers have been featured in the
media for their impressive performance in text generation and question answering (e.g. Ope-
nAI’s GPT-3 and ChatGPT, Google’s Bard). Smaller models in broad use are based on the
BERT architecture (Devlin et al., 2019).
The connection to GNNs, and specifically to GATs, is the shared focus on attention mechanisms.
From this perspective, words in sentences are akin to nodes in networks, with their relative
positions to each other being analogous to their structural positions in the graph. Transformers
improved on previous sequential approaches to text analysis (i.e. recurrent neural networks) by
having each word (or representation of a word) receive messages from not just adjacent words,
but from all words, weighted heterogeneously. Attention mechanisms throughout the architecture allow each layer of a transformer to attend to words or aggregated representations of words heterogeneously. Architectures such as BERT or GPT stack transformer layers to create models with
hundreds of millions of parameters. These models are expensive to train, both computationally
and with respect to data, so they are often pretrained on large datasets and then “fine-tuned”
(lightly re-trained) with smaller datasets for specific tasks or to align with certain goals.
6.2 Causal Inference from Networks

A smaller literature has leveraged relational data for causal inference in two distinct scenarios. In the
first, traditional selection-on-observables setting, we wish to control for information about unobserved
confounding inferable from homophilous ties. For example, age or gender might be unmeasured in our
data, but we might expect people to develop friendship ties with those of the same gender identity or
age cohort.
This scenario suggests estimation strategies similar to those when confounders are encoded in text.
Much like Transformer layers can be prepended to TARNet-style estimators to learn from text, graph
neural networks (an analog of the Transformer) can be prepended to learn from graphs. Guo et al.
(2020) provide a first pass at this problem by adding GNN layers to CFRNet (Shalit et al., 2017).
Veitch et al. (2019) instead adapt Dragonnet in a semi-parametric framework to allow for consistent
estimates of the treatment and outcome, assuming the network representation encodes significant information about the unobserved confounders.
The second, more challenging scenario is estimating the causal effect of social influence on outcomes
from observational data. For example, Cristali and Veitch (2022) introduce the problem of measuring
the effects of vaccination (T) on peer vaccination choice (Y). This is a hard problem because a) peer influence violates SUTVA, a fundamental assumption of most causal inference frameworks, and b) it is hard to disentangle whether
changes in the outcome result from the treatment via peer effects (e.g, person A pressuring person B to
vaccinate), or from homophily (e.g., person A and person B having similar political leanings). In other
words, contagion and homophily are generically confounded (Shalizi and Thomas, 2011). McFowland
and Shalizi (2021) are the first to tackle this problem by making strong parametric assumptions about
the generation of network ties and the outcome model. Cristali and Veitch (2022) instead propose an embedding-based approach.
6.3 Causal Inference from Images

While ideas from causal inference have been leveraged extensively to improve image classification,
to our knowledge there are no papers that explore causal inference where treatments, confounders,
mediators, or predictors are encoded in images.10 That being said, some scenarios proposed for
causal text analysis should apply here as well. For example, consider the conjoint experiment by
Todorov et al. (2005) where both the treatment (e.g. incumbency of a politician) and potential latent
confounders (e.g., party, age, gender, race) are encoded in an image. In this setting, a TARNet-
like model adapted to learn and condition on image representations could improve treatment effect
estimation by controlling for confounders such as the politician’s age. Causal inference on images is
an area ripe for exploration, and we hope to see more work here in the future.
6.4 Causal Inference from Time-varying Data

One natural extension of deep causal estimation is to scenarios where treatments are administered
over time and confounding may be time-varying. While “g-methods” developed by Robins et al.
for estimating effects with time-varying treatments and confounding have existed for decades, the
statistical assumptions encoded in these models are quite strong (Robins, 1994; Robins et al., 2000,
2009). Due to their reliance on generalized linear models to define the “structural” component, they
10 Jesson et al. (2021) introduce a simulation where the MNIST digit dataset serves as covariates X as a toy example of
high-dimensional confounding, but not as a possible application.
assume that the outcome is a linear function of all covariates and treatment. Second, for identification,
they make strong assumptions about which previous timesteps confound the current one. Third, they
require different coefficients to be estimated at each time step. Transformers (Box 8) and recurrent
neural networks, a simpler model for sequential data (Appendix A.C), should be able to capture long-
term dependencies and non-linearities in ways that marginal structural models and g-computation
cannot.
Several papers have begun to explore these possibilities in the context of personalized medicine. Lim
et al. (2018) build a marginal structural model using a recurrent neural network, and Bica et al. (2020a)
extend this framework with an additional loss to more explicitly deal with time varying confounding by
forcing the model to “unlearn” information about the previous time steps. Melnychuk et al. (2022) go
one step further by adapting Bica et al. (2020a)’s approach with a transformer. Inspired by longitudinal
targeted maximum likelihood, Frauen et al. (2022) add a semi-parametric targeting layer to their RNN
to create a g-computation algorithm that is doubly robust and asymptotically efficient. Li et al. (2021)
instead propose an RNN framework for g-computation that allows for dynamic treatment regimes. All
of these papers use simulations of tumor growth dynamics, naturalistic simulations based on vital signs
from intensive care unit visits, or factual datasets exploring treatment response to physical therapy to evaluate their models.
7 Discussion

In this primer we introduce social scientists to the emerging machine learning literature on deep
learning for causal inference. To set the stage, we first provide both an intuitive introduction to
fundamental deep learning concepts like representation and multi-task learning, as well as practical
guidelines for training neural networks. In the main body of the article, we show how ML researchers
have adapted core treatment and outcome modeling strategies to leverage the particular strengths
of neural networks for heterogeneous treatment effect estimation. We follow with a discussion of inference in practice (e.g., model selection, confidence intervals, interpretation), and close with a prospective look at algorithms for inference from text, social networks, images, and time-varying data.
Deep learning is not the only potential tool for heterogeneous treatment effect inference, and there
are robust literatures exploring the usage of other methods in both the econometrics and biostatistics
communities (Van der Laan and Rose, 2011; Chernozhukov et al., 2018; Wager and Athey, 2018).
While these literatures are certainly more mature, below we discuss reasons why we think the use gap between neural networks and other machine learning methods will continue to narrow, a change we expect to accelerate over the next few years.
First, neural networks are better at modeling non-linear heterogeneity (e.g., in treatment responses)
than other machine learning methods. In extensive simulations, Curth et al. (2021) found that when
the data-generating process for treatment heterogeneity includes exponential relationships, neural
networks outperformed random forests, while tree-based methods were more robust when the data-generating process was built on linear functions. Neural networks were also consistently better at predicting outlier
treatment effects than forests. These differences result from how the two methods model functions.
While neural networks can approximate any continuous function with enough neurons, random forests
must build non-linear or non-orthogonal decision boundaries using piecewise functions and average
predictions. Consistent with these differences, Curth et al. (2021) also find that neural networks do better when variables are constructed as continuous covariates, while tree-based methods do better when the variables are dichotomized.
From a statistical perspective, the rise of semi-parametric and double machine learning frameworks
has also narrowed the gap between neural networks and other types of machine learning in terms of the-
oretical guarantees. For example, the TMLE-inspired Dragonnet algorithm featured here is unbiased,
plausibly consistent, and converges to the target estimand at a fast rate of 1/√n. The closely-related RieszNet double machine learning model (not featured) boasts similar guarantees (Chernozhukov et al.,
2022). Beyond these algorithms, there is a growing adjacent literature of model-agnostic plug-in learn-
ers (e.g., X-learner, R-learner) that can leverage the strengths of neural networks (Nie and Wager, 2021; Künzel et al., 2019).
Third, folk beliefs about the data-hungriness and uninterpretability of neural networks are over-
stated. Neural networks are data-hungry when over-parameterized or learning from high-dimensional
data like images, but we show in the tutorials that modest-sized, well-regularized neural networks
can successfully infer heterogeneous treatment effects in a naturalistic simulation of quantitative data
with less than 800 units. In Section 5, we also highlight the considerable progress in machine learning
interpretability over the past five years, much of which has focused on model-agnostic approaches that apply to neural networks as readily as to other model families.
In our opinion, the most pressing limitation of current deep learning approaches is the difficulty
of optimizing neural networks. Theoretically, this stems from a) the complexity of the loss functions
which are often non-convex, and b) the ease of over-parameterizing these models to fit these functions.
If neural networks are to be used as statistical estimators, statistical guarantees must be backed by
optimization guarantees and/or more rigorous methods for model selection. Outside of statistical
estimation, this limitation has largely been addressed through empirical testing on test data and
strategic model selection. Within the statistical estimation context, this gap will likely need to be
addressed by simulation-based sensitivity analyses and, in the short term, comparisons to other model
families.
Moreover, there has been a lack of mature tools and empirical applications of these models. A
major goal of this primer, and the tutorials in particular, is to synthesize the theoretical literature,
practical training and interpretation guidelines, and annotated code in one place so that social scientists can start using these models. Deep learning frameworks like Tensorflow and PyTorch are becoming
more accessible every year, but we note that canned Python packages like Uber’s causalML exist for
interested readers who just want to experiment with a few of these models (Chen et al., 2020).
Despite current limitations, we believe the future of causal estimation runs through deep learning.
As causal inference ventures into new settings, the flexibility of neural networks will become essential
for learning from text, graph, image, video, and speech data. For time-varying settings, we believe
the ability of neural networks to model non-linearities and long-range temporal dependencies will
ultimately lead to solutions with net weaker assumptions than current approaches. Overall, we are
optimistic and excited to see where deep causal estimation heads over the next few years.
11 Critics often point to out-of-bag feature importances as a particular strength of random forests, but this approach
has been shown to be less accurate than model-agnostic permutation importances anyway (Altmann et al., 2010).
References
Alaa, Ahmed and Mihaela Van Der Schaar. 2019. “Validating Causal Inference Models via Influence
Functions.” In International Conference on Machine Learning, volume 36, pp. 191–201. Association
Altmann, André, Laura Toloşi, Oliver Sander, and Thomas Lengauer. 2010. “Permutation importance:
Arjovsky, Martin, Soumith Chintala, and Léon Bottou. 2017. “Wasserstein Generative Adversarial
Networks.” In International Conference on Machine Learning, volume 34, pp. 214–223. Association
Atan, Onur, James Jordon, and Mihaela Van Der Schaar. 2018. “Deep-Treat: Learning Optimal
Personalized Treatments from Observational Data Using Neural Networks.” In Association for the
Athey, Susan and Guido Imbens. 2016. “Recursive partitioning for heterogeneous causal effects.”
Austin, Peter C. 2011. “An Introduction to Propensity Score Methods for Reducing the Effects of
Conference on Statistical Language and Speech Processing, pp. 1–37. Association for Computing
Machinery.
Benkeser, D, M Carone, M J Van Der Laan, and P B Gilbert. 2017. “Doubly robust nonparametric
Benkeser, David and Antoine Chambaz. ???? A Ride in Targeted Learning Territory.
Bica, Ioana, Ahmed M Alaa, James Jordon, and Mihaela van der Schaar. 2020a. “Estimating Coun-
International Conference on Learning Representations, volume 37. Association for Computing Ma-
chinery.
Bica, Ioana, James Jordon, and Mihaela van der Schaar. 2020b. “Estimating the Effects of Continuous-
Bodnar, Lisa M, Abigail R Cartus, Edward H Kennedy, Sharon I Kirkpatrick, Sara M Parisi, Kather-
ine P Himes, Corette B Parker, William A Grobman, Hyagriv N Simhan, Robert M Silver, Deb-
orah A Wing, Samuel Perry, and Ashley I Naimi. 2022. “Use of a Doubly Robust Machine-
Learning–Based Approach to Evaluate Body Mass Index as a Modifier of the Association Between
Fruit and Vegetable Intake and Preeclampsia.” American Journal of Epidemiology 191:1396–1406.
Brand, Jennie E, Bernard Koch, and Jiahui Xu. 2020. “Machine Learning.” In Sage Research Methods
Foundations. SAGE.
Chen, Huigang, Totte Harinen, Jeong-Yoon Lee, Mike Yung, and Zhenyu Zhao. 2020. “CausalML:
Chernozhukov, Victor, Denis Chetverikov, Mert Demirer, Esther Duflo, Christian Hansen, Whitney
Newey, and James Robins. 2018. “Double/debiased machine learning for treatment and structural
Chernozhukov, Victor, Whitney Newey, Victor M Quintas-Martinez, and Vasilis Syrgkanis. 2022.
“Riesznet and forestriesz: Automatic debiased machine learning with neural nets and random
Chernozhukov, Victor, Whitney K Newey, Victor Quintas-Martinez, and Vasilis Syrgkanis. 2021. “Au-
tomatic debiased machine learning via neural nets for generalized linear regression.” arXiv preprint
arXiv:2104.14737 .
Cho, Kyunghyun, Bart van Merriënboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger
Schwenk, and Yoshua Bengio. 2014. “Learning Phrase Representations using RNN Encoder–Decoder
for Statistical Machine Translation.” In Conference on Empirical Methods in Natural Language
Crabbé, Jonathan, Alicia Curth, Ioana Bica, and Mihaela van der Schaar. 2022. “Benchmark-
ing heterogeneous treatment effect models through the lens of interpretability.” arXiv preprint
arXiv:2206.08363 .
Cristali, Irina and Victor Veitch. 2022. “Using Embeddings for Causal Estimation of Peer Influence
Curth, Alicia, David Svensson, Jim Weatherall, and Mihaela van der Schaar. 2021. “Really Doing
Estimation.” In Proceedings of the Neural Information Processing Systems Track on Datasets and
Cuturi, Marco. 2013. “Sinkhorn Distances: Lightspeed Computation of Optimal Transport.” In Neural
Information Processing Systems, volume 27, pp. 2292–2300. Curran Associates, Inc.
Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. “BERT: Pre-training
American Chapter of the Association for Computational Linguistics: Human Language Technologies,
Du, Xin, Lei Sun, Wouter Duivesteijn, Alexander Nikolaev, and Mykola Pechenizkiy. 2021. “Adversar-
ial Balancing-based Representation Learning for Causal Effect Inference with Observational Data.”
Farrell, Max H, Tengyuan Liang, and Sanjog Misra. 2021. “Deep Neural Networks for Estimation and
Feder, Amir, Katherine A Keith, Emaad Manzoor, Reid Pryzant, Dhanya Sridhar, Zach Wood-
Doughty, Jacob Eisenstein, Justin Grimmer, Roi Reichart, Margaret E Roberts, et al. 2021. “Causal
Fisher, Aaron and Edward H Kennedy. 2021. “Visually Communicating and Teaching Intuition for
Frauen, Dennis, Tobias Hatt, Valentyn Melnychuk, and Stefan Feuerriegel. 2022. “Estimating average
Gilmer, Justin, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. 2017.
“Neural message passing for quantum chemistry.” In International Conference on Machine Learning,
Glynn, Adam N and Kevin M Quinn. 2010. “An introduction to the Augmented Inverse Propensity
Goldstein, Alex, Adam Kapelner, Justin Bleich, and Emil Pitkin. 2015. “Peeking Inside the Black
Box: Visualizing Statistical Learning With Plots of Individual Conditional Expectation.” Journal
Goldszmidt, Moisés and Judea Pearl. 1996. “Qualitative probabilities for default reasoning, belief
Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. 2016. Deep Learning. MIT Press. http:
//www.deeplearningbook.org.
Goodfellow, Ian, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair,
Aaron Courville, and Yoshua Bengio. 2014. “Generative Adversarial Nets.” In Neural Information
Processing Systems, volume 27, pp. 2672–2680. Association for Computing Machinery.
Gretton, Arthur, Karsten M Borgwardt, Malte J Rasch, Bernhard Schölkopf, and Alexander Smola.
Gui, Lin and Victor Veitch. 2022. “Causal Estimation for Text Data with (Apparent) Overlap Viola-
Gulrajani, Ishaan, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron Courville. 2017.
Guo, Ruocheng, Jundong Li, and Huan Liu. 2020. “Counterfactual Evaluation of Treatment Assign-
ment Functions with Networked Observational Data.” In SIAM International Conference on Data
Hastie, T., R. Tibshirani, and J.H. Friedman. 2009. The Elements of Statistical Learning: Data
Heck, Katherine E, Paula Braveman, Catherine Cubbin, Gilberto F Chávez, and John L Kiely. 2006.
“Socioeconomic status and breastfeeding initiation among California mothers.” Public health reports
121:51–59.
Hernán, Miguel A. and James M. Robins. 2020. Causal Inference: What If. 2020 . Chapman & Hall.
Hill, Jennifer L. 2011. “Bayesian Nonparametric Modeling for Causal Inference.” Journal of Compu-
Hochreiter, Sepp and Jürgen Schmidhuber. 1997. “Long Short-Term Memory.” Neural Computation
9:1735–1780.
Holland, Paul W. 1986. “Statistics and Causal Inference.” Journal of the American statistical Asso-
ciation 81:945–960.
Huszar, Ferenc. 2015. “Another Favourite Machine Learning Paper: Adversarial Networks vs Kernel
Imbens, Guido W and Donald B Rubin. 2015. Causal Inference in Statistics, Social, and Biomedical
Ioffe, Sergey and Christian Szegedy. 2015. “Batch Normalization: Accelerating Deep Network Training
by Reducing Internal Covariate Shift.” In International Conference on Machine Learning, volume 32,
Jesson, Andrew, Sören Mindermann, Yarin Gal, and Uri Shalit. 2021. “Quantifying Igno-
arXiv:2103.04850 .
Johansson, Fredrik and Max Shen. 2018. “Causal Inference & Deep Learning.”
Johansson, Fredrik D, Nathan Kallus, Uri Shalit, and David Sontag. 2018. “Learning Weighted Rep-
Johansson, Fredrik D., Uri Shalit, Nathan Kallus, and David A. Sontag. 2020. “Generalization Bounds
and Representation Learning for Estimation of Potential Outcomes and Causal Effects.” arXiv
abs/2001.07426.
Johansson, Fredrik D, Uri Shalit, and David Sontag. 2016. “Learning Representations for Counter-
factual Inference.” In International Conference on Machine Learning, volume 48. Association for
Computing Machinery.
Joo, Jungseock, Francis F Steen, and Song-Chun Zhu. 2015. “Automated Facial Trait Judgment and
Kallus, Nathan. 2020. “Generalized Optimal Matching Methods for Causal Inference.” Journal of
Keith, Katherine, David Jensen, and Brendan O’Connor. 2020. “Text and Causal Inference: A Review
of Using Text to Remove Confounding from Causal Estimates.” In Annual Meeting of the Association
for Computational Linguistics, volume 58, pp. 5332–5344, Online. Association for Computational
Linguistics.
Kennedy, Edward H. 2016. “Semiparametric Theory and Empirical Processes in Causal Inference.” In
Statistical Causal Inferences and their Applications in Public Health Research, pp. 141–167. Springer.
Kennedy, Edward H. 2020. “Towards optimal doubly robust estimation of heterogeneous causal ef-
fects.”
Kingma, Diederik P. and Jimmy Ba. 2015. “Adam: A Method for Stochastic Optimization.” In
Kipf, Thomas N and Max Welling. 2017. “Semi-supervised classification with graph convolutional
Kramer, Michael S, Frances Aboud, Elena Mironova, Irina Vanilovich, Robert W Platt, Lidia Matush,
Sergei Igumnov, Eric Fombonne, Natalia Bogdanovich, Thierry Ducruet, et al. 2008. “Breastfeeding
and child cognitive development: new evidence from a large randomized trial.” Archives of general
psychiatry 65:578–584.
Künzel, Sören R, Jasjeet S Sekhon, Peter J Bickel, and Bin Yu. 2019. “Metalearners for estimating
heterogeneous treatment effects using machine learning.” Proceedings of the national academy of
sciences 116:4156–4165.
LeCun, Yann, Yoshua Bengio, and Geoffrey Hinton. 2015. “Deep learning.” Nature 521:436.
Li, Rui, Stephanie Hu, Mingyu Lu, Yuria Utsumi, Prithwish Chakraborty, Daby M. Sow, Piyush
Madan, Jun Li, Mohamed Ghalwash, Zach Shahn, and Li-wei Lehman. 2021. “G-Net: a Recurrent
Regime.” In Proceedings of Machine Learning for Health, edited by Subhrajit Roy, Stephen Pfohl,
Emma Rocheteau, Girmaw Abebe Tadesse, Luis Oala, Fabian Falck, Yuyin Zhou, Liyue Shen,
Ghada Zamzmi, Purity Mugambi, Ayah Zirikly, Matthew B. A. McDermott, and Emily Alsentzer,
Lim, Bryan, Ahmed M Alaa, and Mihaela van der Schaar. 2018. “Forecasting Treatment Responses
Over Time Using Recurrent Marginal Structural Networks.” In Neural Information Processing
Lundberg, Scott M and Su-In Lee. 2017. “A Unified Approach to Interpreting Model Predictions.” In
Advances in Neural Information Processing Systems, edited by I. Guyon, U. Von Luxburg, S. Bengio,
H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, volume 30. Curran Associates, Inc.
McFowland, Edward and Cosma Rohilla Shalizi. 2021. “Estimating Causal Peer Influence in Ho-
mophilous Social Networks by Inferring Latent Locations.” Journal of the American Statistical
Association 0:1–12.
Melnychuk, Valentyn, Dennis Frauen, and Stefan Feuerriegel. 2022. “Causal Transformer for Estimat-
Learning, edited by Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu,
and Sivan Sabato, volume 162 of Proceedings of Machine Learning Research, pp. 15293–15329.
PMLR.
Mikolov, Tomas, Ilya Sutskever, Kai Chen, Greg Corrado, and Jeffrey Dean. 2013. “Distributed Rep-
resentations of Words and Phrases and Their Compositionality.” In Neural Information Processing
Molnar, C. 2022. Interpretable Machine Learning: A Guide for Making Black Box Models Explainable.
Christoph Molnar.
Müller, Alfred. 1997. “Integral Probability Metrics and Their Generating Classes of Functions.”
Nagpal, Chirag, Dennis Wei, Bhanukiran Vinzamuri, Monica Shekhar, Sara E. Berger, Subhro Das, and
Kush R. Varshney. 2020. “Interpretable Subgroup Discovery in Treatment Effect Estimation with
Naimi, Ashley I and Laura B Balzer. 2018. “Stacked generalization: an introduction to super learning.”
Nie, Xinkun and Stefan Wager. 2021. “Quasi-oracle Estimation of Heterogeneous Treatment Effects.”
Biometrika 108:299–319.
Parikh, Harsh, Carlos Varjao, Louise Xu, and Eric Tchetgen Tchetgen. 2022. “Validating Causal
edited by Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan
Sabato, volume 162 of Proceedings of Machine Learning Research, pp. 17346–17358. PMLR.
Pryzant, Reid, Dallas Card, Dan Jurafsky, Victor Veitch, and Dhanya Sridhar. 2021. “Causal Effects of
Linguistic Properties.” In Conference of the North American Chapter of the Association for Compu-
tational Linguistics: Human Language Technologies, pp. 4095–4109. Association for Computational
Linguistics.
Ramey, Craig T, Donna M Bryant, Barbara H Wasik, Joseph J Sparling, Kaye H Fendt, and Lisa M
La Vange. 1992. “Infant Health and Development Program for low birth weight, premature infants:
Ribeiro, Marco Tulio, Sameer Singh, and Carlos Guestrin. 2016. “” Why should i trust you?” Ex-
plaining the predictions of any classifier.” In Proceedings of the 22nd ACM SIGKDD international
Roberts, Daniel A., Sho Yaida, and Boris Hanin. 2022. The Principles of Deep Learning Theory.
Robins, James. 1986. “A New Approach to Causal Inference in Mortality Studies with a Sustained
Modelling 7:1393–1512.
Robins, James. 1987. “A Graphical Approach to the Identification and Estimation of Causal Pa-
rameters in Mortality Studies with Sustained Exposure Periods.” Journal of Chronic Diseases
40:139S–161S.
Robins, James M. 1994. “Correcting for non-compliance in randomized trials using structural nested
Robins, James M, Miguel Angel Hernan, and Babette Brumback. 2000. “Marginal Structural Models
2009. “Longitudinal Data Analysis.” Handbooks of Modern Statistical Methods pp. 553–599.
Rosenbaum, Paul R and Donald B Rubin. 1983. “The Central Role of the Propensity Score in Obser-
Rubin, Donald B. 1974. “Estimating Causal Effects of Treatments in Randomized and Non-randomized
Schnitzer, Mireille E, Judith J Lok, and Susan Gruber. 2016. “Variable selection for confounder control,
flexible modeling and collaborative targeted minimum loss-based estimation in causal inference.”
Schwab, Patrick, Lorenz Linhardt, and Walter Karlen. 2018. “Perfect Match: A Simple Method for
Shalit, Uri, Fredrik D Johansson, and David Sontag. 2017. “Estimating Individual Treatment Effect
Shalizi, Cosma Rohilla and Andrew C. Thomas. 2011. “Homophily and Contagion Are Generically
Confounded in Observational Social Network Studies.” Sociological Methods & Research 40:211–239.
Shi, Claudia, David Blei, and Victor Veitch. 2019. “Adapting Neural Networks for the Estimation of
Snoek, Jasper, Hugo Larochelle, and Ryan P Adams. 2012. “Practical Bayesian Optimization of
F. Pereira, C.J. Burges, L. Bottou, and K.Q. Weinberger, volume 25. Curran Associates, Inc.
Srivastava, Nitish, Geoffrey E. Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov.
2014. “Dropout: a simple way to prevent neural networks from overfitting.” Journal of Machine
Stuart, Elizabeth A. 2010. “Matching Methods for Causal Inference: A Review and a Look Forward.”
Sundararajan, Mukund, Ankur Taly, and Qiqi Yan. 2017. “Axiomatic attribution for deep networks.”
Todorov, Alexander, Anesu N Mandisodza, Amir Goren, and Crystal C Hall. 2005. “Inferences of
Van der Laan, Mark J and Sherri Rose. 2011. Targeted Learning: Causal Inference for Observational
Veitch, Victor, Dhanya Sridhar, and David Blei. 2020. “Adapting Text Embeddings for Causal Infer-
ence.” In Conference on Uncertainty in Artificial Intelligence, pp. 919–928. Association for Uncer-
Veitch, Victor, Yixin Wang, and David Blei. 2019. “Using Embeddings to Correct for Unobserved
Veličković, Petar, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Ben-
gio. 2018. “Graph Attention Networks.” In International Conference for Learning Representations,
volume 6. OpenReview.
Wager, Stefan and Susan Athey. 2018. “Estimation and Inference of Heterogeneous Treatment Effects
Yao, Liuyi, Sheng Li, Yaliang Li, Mengdi Huai, Jing Gao, and Aidong Zhang. 2018. “Representa-
tion Learning for Treatment Effect Estimation from Observational Data.” In Neural Information
Yoon, Jinsung, James Jordon, and Mihaela Van Der Schaar. 2018. “GANITE: Estimation of Indi-
Zhang, Yao, Alexis Bellot, and Mihaela Schaar. 2020. “Learning Overlapping Representations for
Intelligence and Statistics, pp. 1005–1014. Association for Artificial Intelligence and Statistics.
Zivich, Paul N and Alexander Breskin. 2021. “Machine learning for causal inference: on the use of
Appendix A Balancing Using Integral Probability Metrics

A.1 Wasserstein Distance

Following (Stock, 2017; Daza, 2019), suppose we have two discrete distributions (treated and control)
with marginal densities p(x) and q(x) captured as vectors t and c, with dimensions n and m respectively.
To compute the Wasserstein distance, we must define a ”mapping matrix” P that defines the mapping
of “earth” in p(x) to corresponding piles in q(x). Let U(t, c) be the set of positive, n × m mapping
matrices where the sum of the rows is t and the sum of the columns is c.
U(t, c) = \left\{ P \in \mathbb{R}^{n \times m}_{>0} \;\middle|\; P \cdot \mathbf{1}_m = t,\; P^{T} \cdot \mathbf{1}_n = c \right\} \qquad (11)
In words, this matrix maps the probability mass from points in the support of p(x) (i.e, the elements
of t) to points in the support of q(x) (the elements of c) (note that the mapping need not be one-to-
one). We also have a “cost” matrix C that describes the cost of applying P (i.e. the cost of shoveling
dirt according to the map described in P ). The cost matrix can be computed using a norm ℓ (most
commonly ℓ2) between the points in t being mapped to c in the mapping matrix P. Finally, the ℓ-norm Wasserstein distance is:

d_{W_\ell} = \min_{P \in U(t,c)} \sum_{i,j} P_{ij} C_{ij} \qquad (12)
In other words, the Wasserstein distance is the smallest Frobenius inner product of a mapping matrix
P that fits the above constraints, and its associated cost matrix C. Although this problem can be
solved via linear programming, the Wasserstein distance is often implemented in a different form that
works with continuous distributions and can be optimized by gradient descent (Arjovsky et al., 2017;
Gulrajani et al., 2017). There is also a variant of the Wasserstein distance that imposes an entropy-
based regularization on the coupling matrix to make it smoother or sparser called the Sinkhorn distance
(Cuturi, 2013).
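For small discrete distributions, the Sinkhorn variant can be computed in a few lines. The numpy sketch below assumes mass vectors t and c (each summing to one) and a precomputed cost matrix C; the regularization strength and iteration count are arbitrary illustrative choices.

```python
import numpy as np

def sinkhorn_distance(t, c, C, reg=0.1, n_iters=200):
    """Entropy-regularized approximation of the Wasserstein distance (a sketch).

    t, c : mass vectors of lengths n and m (each sums to 1)
    C    : (n, m) cost matrix, e.g. pairwise squared Euclidean distances
    """
    K = np.exp(-C / reg)                  # Gibbs kernel
    u = np.ones_like(t)
    for _ in range(n_iters):              # alternating scaling updates
        v = c / (K.T @ u)
        u = t / (K @ v)
    P = np.diag(u) @ K @ np.diag(v)       # approximate mapping matrix
    return float(np.sum(P * C))
```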
A.2 Extending Representation Balancing with IPMs

Deep Dive: CFRNet (Shalit et al. (2017); Johansson et al. (2018, 2020))
Beyond receiving outcome modeling gradients for both potential outcomes, the authors have sub-
sequently extended TARNet with additional losses that explicitly encourage balancing by minimizing
a statistical distance between the two covariate distributions in representation space. These distances
are called integral probability metrics (Müller, 1997).12 Johansson et al. (2016); Shalit et al. (2017);
12 Zhang et al. (2020) criticize the usage of IPMs because they make no restrictions on the moments of the transformed
distributions. Thus while the covariate distributions may have a high percentage of overlap in representation space, this need not translate into balance of their moments.
Johansson et al. (2018) propose two possible IPMs, the Wasserstein distance and the maximum mean discrepancy (MMD).
The Wasserstein or “Earth Mover’s” distance fits an interpretable “map” (i.e. a matrix) showing
how to efficiently convert from one probability mass distribution to another. The Wasserstein distance
is most easily understood as an optimal transport problem (i.e., a scenario where we want to transport
one distribution to another at minimum cost). The nickname “Earth mover’s distance” comes from
the metaphor of shoveling dirt to terraform one landscape into another. In the idealized case in which
one distribution can be perfectly transformed into another, the Wasserstein map corresponds exactly to a matching between the units of the two distributions.
The MMD is the normed distance between the means of two distributions, after a kernel function ϕ has transformed them into a high-dimensional space called a reproducing kernel Hilbert space (RKHS) (Gretton et al., 2012). The MMD with an L2 norm in RKHS H can be specified as:

MMD(P, Q) = \left\| \mathbb{E}_{x \sim P}[\phi(x)] - \mathbb{E}_{y \sim Q}[\phi(y)] \right\|_{\mathcal{H}} \qquad (13)
The metric is built on the idea that no function would have differing expected values under P and Q in this high-dimensional space if P and Q are the same distribution (Huszar, 2015). The
MMD is inexpensive to calculate using the “kernel trick” where the inner product between two points
can be calculated in the RKHS without first transforming each point into the RKHS.13
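The kernel-trick computation is short in practice. The numpy sketch below computes a (biased) estimate of the squared MMD between represented treated and control units with an RBF kernel; the bandwidth is a hypothetical choice that would normally be tuned or set by a heuristic.

```python
import numpy as np

def rbf_kernel(X, Y, bandwidth=1.0):
    """k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))."""
    sq = np.sum(X**2, 1)[:, None] + np.sum(Y**2, 1)[None, :] - 2 * X @ Y.T
    return np.exp(-sq / (2 * bandwidth**2))

def mmd_squared(phi_treated, phi_control, bandwidth=1.0):
    """Squared MMD between represented treated and control units (a sketch)."""
    k_tt = rbf_kernel(phi_treated, phi_treated, bandwidth)
    k_cc = rbf_kernel(phi_control, phi_control, bandwidth)
    k_tc = rbf_kernel(phi_treated, phi_control, bandwidth)
    return k_tt.mean() + k_cc.mean() - 2 * k_tc.mean()
```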
When an IPM loss is applied to the representation layers in TARNet, the authors call the resulting
network “CounterFactual Regression Network” (CFRNet) (Fig. 6A) (Shalit et al., 2017). The loss function is:

\min_{h,\Phi} \frac{1}{N}\sum_{i=1}^{N} \underbrace{MSE(Y_i, h(\Phi(X_i), T_i))}_{\text{Outcome Loss}} + \lambda\, \underbrace{IPM\big(\Phi(X|T=1), \Phi(X|T=0)\big)}_{\text{Dist. b/w T \& C covar. distributions}} + \alpha \underbrace{R(h)}_{L_2} \qquad (14)
Figure 6: A. CFRNet architecture originally introduced in Shalit et al. (2017). CFRNet adds an
additional integral probability metric (IPM) loss to TARNet to explicitly force representations of the
treated and control covariates closer in representation space.
B. Weighted CFRNet adds a propensity score head to CFRNet to predict IPW-weighted outcomes.
During training, the propensity score is used to reweight both the predicted outcomes Ŷ (0) and Ŷ (1),
as well as the represented covariate distributions in calculation of the IPM loss. This allows the authors
to provide consistency guarantees and relax the overlap assumption. Figures adapted from Johansson
et al. (2020).
These two papers also make important theoretical contributions by providing bounds on the gener-
alization error for the PEHE (Hill, 2011). In Shalit et al. (2017), they show that the PEHE is bounded
by the sum of the factual loss, counterfactual loss, and the variance of the conditional outcome.
In Johansson et al. (2020), the authors introduce estimated IPW weights π(Φ(X), T ) to CFRNet
that are used within the IPM calculation to provide consistency guarantees (Fig. 6B). Theoretically,
they also use these weights to relax the overlap assumption as long as the weights themselves obey the
positivity assumption. From a practical standpoint, adding weights that are optimized smoothly across
the whole dataset each epoch reduces noise created by calculating the IPM score in small batches.
The weighted loss is:

\arg\min_{h,\Phi,\pi} \frac{1}{N}\sum_{i=1}^{N} \underbrace{\frac{\hat{P}(T_i)}{\pi(\Phi(X_i), T_i)}}_{IPW} \cdot \underbrace{MSE(Y_i, h(\Phi(X_i), T_i))}_{\text{Outcome Loss}} + \lambda_h \underbrace{R(h)}_{L_2\text{ Outcome}} + \alpha\, IPM_w\big(\Phi(X|T=1), \Phi(X|T=0)\big) + \lambda_\pi\, BCE\big(T, \pi(\Phi(X), T)\big) \qquad (15)
where R(h) is a model complexity term and λ_h, λ_π, and α are hyperparameters. The final term is a binary cross-entropy loss for learning the propensity scores π that define the weights.
A.2.1 Extending Representation Balancing with Matching

Beyond IPMs, other approaches have directly embraced matching as a balancing strategy. Yao et al.
(2018) train their TARNet on six point mini-batches of propensity score-matched units with additional
reconstruction losses designed to preserve the relative distances between these points when projecting
them into representation space. Schwab et al. (2018) take an even simpler approach by feeding random mini-batches augmented with propensity score-matched units to the network.
Appendix B Model Selection Using the PEHE
In order to select hyperparameters in real data, Johansson et al. (2020) propose to use a matching variant of the PEHE that uses the nearest Euclidean neighbor of each unit i from the other treatment assignment group, y_i^{nn}, as a counterfactual. If we identify the nearest neighbor of each unit i in representation space, then,
PEHE_{nn} = \frac{1}{N}\sum_{i=1}^{N}\Big(\underbrace{(1-2t_i)\big(y_i(t_i) - y_i^{nn}(1-t_i)\big)}_{CATE_{nn}} - \underbrace{\big(h(\Phi(x_i),1) - h(\Phi(x_i),0)\big)}_{\widehat{CATE}}\Big)^2
If we take the square root of PEHE_{nn}, then we get an approximation of the unit-level error.
The intuition behind \sqrt{PEHE_{nn}} is solid: if our representation function Φ is truly learning to balance the treated and control distributions, CATE_{nn} should coarsely measure it.
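A direct, if slow, implementation of the nearest-neighbor PEHE only requires the represented covariates, factual outcomes, treatments, and the model's CATE predictions. The numpy sketch below is a naive O(N²) version; the imputed effect is computed directly as treated-minus-control, which matches the equation above up to its sign convention.

```python
import numpy as np

def pehe_nn(phi, y, t, cate_hat):
    """Nearest-neighbor approximation of the PEHE (a naive sketch)."""
    cate_nn = np.empty(len(y))
    for i in range(len(y)):
        # nearest Euclidean neighbor of unit i in the other treatment group
        others = np.where(t != t[i])[0]
        j = others[np.argmin(np.linalg.norm(phi[others] - phi[i], axis=1))]
        cate_nn[i] = y[i] - y[j] if t[i] == 1 else y[j] - y[i]
    return float(np.mean((cate_nn - cate_hat) ** 2))
```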
Appendix C Recurrent Neural Networks

Recurrent neural networks are a specialized architecture created for learning outcomes from sequential
data (e.g. time series, biological sequences, text) (Fig. 7). In a classic RNN, each “unit” u in
the network takes as input its own covariates X (or possibly a representation) and a representation
produced by the previous unit, encoding cumulative information about earlier states in the sequence.
These units are not just simple hidden layers: there is a set of weights within each unit for its raw
inputs, the representation from the previous time step, and its outputs. Different RNN variants
have different operations for integrating past representations with present inputs. Recurrent neural
networks may be directed acyclic graphs or may feed back on themselves. Commonly used variants include Gated Recurrent Unit networks (GRU) and Long Short-Term Memory networks (LSTM) (Cho et al., 2014; Hochreiter and Schmidhuber, 1997).
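As a small illustration of how this looks in code, the PyTorch sketch below passes a toy batch of covariate sequences through a GRU; each time step's output is a representation that a downstream head could use to predict outcomes or treatments. All shapes here are hypothetical.

```python
import torch
import torch.nn as nn

# Toy input: 8 units, 20 time steps, 5 covariates per step
x = torch.randn(8, 20, 5)

gru = nn.GRU(input_size=5, hidden_size=16, batch_first=True)
outcome_head = nn.Linear(16, 1)

reps, _ = gru(x)            # reps[:, s, :] summarizes time steps 0..s
y_hat = outcome_head(reps)  # one outcome prediction per unit per time step
print(y_hat.shape)          # torch.Size([8, 20, 1])
```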
Figure 7: Recurrent neural network.
Appendix D Adversarial Training
Adversarial training approaches include a wide variety of architectures where two networks or loss
functions compete against each other. Adversarial approaches are inspired by Generative Adversarial
Networks (GANs) (Box 9) (Goodfellow et al., 2014). In the machine learning literature on causal
inference, adversarial training has been applied both to trade off outcome modeling and treatment
modeling tasks during representation learning, as well as to trade off estimation and regularization
of IPW weights. GANs have also been used directly as generative models for counterfactual outcomes and treatment effect distributions (Appendix D.1).
Box 9: Generative Adversarial Networks (GAN)
A GAN pits two networks against one another: a generator G that maps random noise Z into synthetic data, and a discriminator D that tries to distinguish real data X from the generator's output (Goodfellow et al., 2014). The generator's objective is to minimize the value function

V(D, G) = \mathbb{E}_{x \sim X}[\log D(x)] + \mathbb{E}_{z \sim Z}[\log(1 - D(G(z)))]

where the first quantity is the discriminator's estimated probability that data from X is indeed real, and the second quantity is the discriminator's estimate that a generated quantity from the distribution Z is real.
Because the discriminator is trying to catch the generator, its objective is to maximize the same function:

\max_D V(D, G)
In practice, the discriminator and the generator are trained either alternatingly or simultane-
ously, with the discriminator increasing its ability to discern between real and fake outcomes
over time, and the generator increasing its ability to deceive the discriminator. The idea is that
the adaptive loss function created by the discriminator can coax the generator out of local min-
ima to generate superior outcomes. Results by these models have been impressive, and many
of the fake portraits and “deepfake” videos circulating online in recent years are generated by
this architecture. The advantage of GANs is that they can impressively learn very complex
generative distributions with limited modeling assumptions. The disadvantage of GANs is that
they are difficult and unreliable to train, often plateauing in local optima.
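The alternating updates can be summarized in a short PyTorch sketch. It assumes generator and discriminator modules (the discriminator outputting a probability) and their optimizers already exist, and it uses the common non-saturating binary cross-entropy losses rather than the exact min-max value function above; everything here is illustrative.

```python
import torch
import torch.nn.functional as F

def gan_step(generator, discriminator, g_opt, d_opt, x_real, noise_dim=16):
    """One alternating GAN update (a schematic sketch)."""
    n = x_real.size(0)
    z = torch.randn(n, noise_dim)

    # Discriminator step: push D(real) toward 1 and D(fake) toward 0
    d_opt.zero_grad()
    x_fake = generator(z).detach()
    d_loss = F.binary_cross_entropy(discriminator(x_real), torch.ones(n, 1)) \
           + F.binary_cross_entropy(discriminator(x_fake), torch.zeros(n, 1))
    d_loss.backward()
    d_opt.step()

    # Generator step: try to make the discriminator label fakes as real
    g_opt.zero_grad()
    g_loss = F.binary_cross_entropy(discriminator(generator(z)), torch.ones(n, 1))
    g_loss.backward()
    g_opt.step()
    return d_loss.item(), g_loss.item()
```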
D.1 GANs as Generative Models of Treatment Effect Distributions (GANITE)
Deep Dive: GANITE (Yoon et al. (2018)) Although a generative model of the treatment effect
distribution is generally unknown, a natural application of GANs is to try to machine learn this
distribution from data. GANITE uses two GANs: GAN1 , consisting of generator G1 and discriminator
D1 , to model the counterfactual distribution and GAN2 , consisting of generator G2 and discriminator
D2 , to model the CAT E distribution (Yoon et al., 2018) (Fig. 8). The training procedure for GAN1
is as follows:
1. Taking X,T , and generative noise Z as input, generator G1 generates both potential outcomes
{Ỹ (T ), Ỹ (1 − T )}. A factual loss M SE(Y (T ), Ỹ (T )) is applied.
2. Create a new vector C = {Y (T ), Ỹ (1 − T )} by combining the observed potential outcome and
the counterfactual predicted by G1 .
3. Taking X and C as inputs, the discriminator D1 rates each value in C for the probability that
it is the observed outcome, using a categorical cross-entropy (CCE) loss.
4. This loss is then fed back to G1 , so that the generator's total loss combines its factual MSE loss with the adversarial loss from D1 .
After generator G1 is trained to completion, the authors use C as a “complete dataset” containing
both a factual outcome and a counterfactual outcome to train GAN2 , which generates treatment
effects:
1. Taking only X and generative noise Z as input, G2 generates a new potential outcome vector
R = {Ŷ (T ), Ŷ (1−T )}. G2 receives an MSE loss to minimize the difference between its predictions
and the “complete dataset” C: M SE(C, R).
2. Discriminator D2 takes X, C, and R as inputs and estimates a probability that C is the “com-
plete” dataset, and that R is the “complete dataset”:
L(D_2) = CCE\Big(\big\{\underbrace{P(C = \text{“CD”})}_{\text{Prob C is “CD”}},\; \underbrace{P(R = \text{“CD”})}_{\text{Prob R is “CD”}}\big\},\; \big\{\underbrace{1}_{\text{1 if idx 0 is C}},\; \underbrace{0}_{\text{1 if idx 1 is C}}\big\}\Big) \qquad (18)

where “CD” denotes the “complete dataset.”
Figure 8: GANITE has two GANs. The first generator G1 generates counterfactuals Ỹ (T ). The
discriminator D1 attempts to discriminate between these predictions and real data (Y (T )). The
second generator proposes pairs of potential outcomes Ŷ (0) and Ŷ (1) (i.e., treatment effects), a vector
we call R. Discriminator D2 attempts to discern between R and a “complete dataset” C created by
pairing each observed/factual outcome Y (T ) with a synthetic outcome Ỹ (1 − T ) proposed by G1 .
Although we do not show gradients in other figures, we make an exception for GANs (red line).
3. This loss is then fed back to the generator G2 , so that the generator's total loss combines its supervised MSE loss with the adversarial loss from D2 .
At the end of training, G2 should be able to predict treatment effects with only covariates X and noise
Z as inputs. An evolution of GANITE, SCIGAN, extends this framework to settings with more than two treatments, including continuous-valued interventions (Bica et al., 2020b).
D.2 Adversarial Representation Balancing
The use of the IPM loss in CFRNet (Shalit et al., 2017) may also be viewed as an adversarial ap-
proach in that the representation layers are forced to maximize performance on two competing tasks:
predicting outcomes and minimizing an IPM. Rather than using an IPM loss, other authors have
trained propensity score estimators that send positive (rather than negative) gradients back to the representation layers, pushing the representation to become less predictive of treatment assignment (Atan et al., 2018; Du et al., 2021).
Bica et al. (2020a) extend this approach to settings with treatment over time using a recurrent
neural network. In their medical setting, decorrelating treatment from patient covariates and history helps to remove bias from time-varying confounding.