Lagemann Et Al - 2023 - Deep Learning of Causal Structures in High Dimensions Under Data Limitations
Article  https://doi.org/10.1038/s42256-023-00744-z
Kai Lagemann1, Christian Lagemann2, Bernd Taschler1,3 & Sach Mukherjee1,4
Received: 13 April 2022; Accepted: 20 September 2023
Causality remains an important open area in artificial intelligence (AI) research1,2, and the task of identifying causal relationships between variables is key in many scientific domains3. The rich body of work in learning causal structures includes methods such as PC4, LiNGAM5, IDA6, GIES7, RFCI8, ICP9 and MRCL10. Scaling causal structure learning to larger problems has been facilitated by reformulation as a continuous optimization problem11, and recent neural approaches, such as SDI12, DCDI13, DCD-FG14 and ENCO15, have demonstrated state-of-the-art performance (Supplementary section 1 provides a detailed discussion). However, learning causal structures from data remains nontrivial and continues to pose challenges, particularly under the conditions (high dimensionality, limited data sizes and hidden variables, for example) seen in many real-world problems.

In biomedicine, causal networks representing the interplay between entities such as genes or proteins play a central conceptual and practical role. Such networks are increasingly understood to be context-dependent, and are thought to underpin aspects of disease heterogeneity and the variation in therapeutic response (for example, refs. 16–19). A key bottleneck in characterizing such heterogeneity lies in the challenging nature of learning causal structures at scale, because of general methodological issues as well as aspects relevant in the biological domain such as high dimensionality, complex underlying events, the presence of hidden/unmeasured variables, limited data and noise levels.
1Statistics and Machine Learning, German Center for Neurodegenerative Diseases (DZNE), Bonn, Germany. 2Institute of Aerodynamics, RWTH Aachen University, Aachen, Germany. 3Big Data Institute, Li Ka Shing Centre for Health Information and Discovery, University of Oxford, Oxford, UK. 4MRC Biostatistics Unit, University of Cambridge, Cambridge, UK. e-mail: [email protected]; [email protected]
In this Article, we propose a deep architecture for causal learning that is particularly motivated by high-dimensional biomedical problems. The approach we put forward operates within an emerging causal risk paradigm (Methods and Supplementary section 2) that allows us to leverage AI tools and scale to very high-dimensional problems involving thousands of variables. The learners proposed allow for the integration of partial knowledge concerning a subset of causal relationships and then seek to generalize beyond what is initially known to learn relationships between all variables. This corresponds to a common scientific use-case in which some prior knowledge is available at the outset—from previous experiments or scientific background knowledge—but it is desired to go beyond what is known to learn a model spanning all available variables.

A large part of the causal structure learning literature involves learning models that allow an explicit description of the relevant data-generating model (including both observational and interventional distributions) and are in that sense 'generative' (see, for example, ref. 3 and references therein). Taking a different approach, a line of recent work, including refs. 10,20–22, has considered learning indicators of causal relationships between variables (without necessarily learning full details of the underlying data-generating models), and this can be viewed as being related to notions of causal risk23. Such indicators may encode, for example, whether, for a pair of variables A and B, A has a causal influence on B, B on A, or neither.

The approach we propose, called 'deep discriminative causal learning' (D2CL), is in the latter vein. We consider a version of the causal structure learning problem in which the desired output consists of binary indicators of causal relationships between observed variables10,23, that is, a directed graph with nodes identified with the variables. Available multivariate data X are transformed to provide inputs to a neural network (NN), whose outputs are estimates of the causal indicators. D2CL differs from classical causal structure learning approaches both in terms of the underlying framework (based on causal risk rather than generative causal models) and in leveraging NNs. The assumptions underlying the approach are also different in nature from those in classical causal structure learning and concern higher-level regularities in the data-generating processes (Methods). A number of recent papers, including refs. 12–15, also leverage neural approaches for learning causal structures and share a basis in the continuous optimization framework introduced in ref. 11 based on a directed acyclic graph (DAG) framework. D2CL, in contrast, uses a risk-based approach that is not based on DAGs. Eigenmann et al.23 studied causal risk for the assessment of existing learners; instead, we leverage the notion of causal risk to propose a new learner. In common with D2CL, the recently proposed CSIvA method24 seeks to directly map input data to a graph output. The key difference is that, while CSIvA uses a meta-learning scheme based on large-scale synthetic data, D2CL is based on supervised learning using data from a specific system of interest (for example, a biological system; see Supplementary section 1 for a more detailed overview and comparison). We show that context-specific training allows D2CL to successfully learn structures in a range of scenarios, including challenging real-world experimental data (as detailed in the following). Furthermore, D2CL is demonstrably scalable to large numbers of variables (we show examples ranging up to p = 50,000 nodes) and applicable in regimes where very large sample data or strong simulation engines are not available.

Framework overview
We propose an end-to-end neural approach to learn causal networks from a combination of empirical data X and prior causal knowledge Π. The general D2CL workflow and its application to biomolecular problems are summarized in Fig. 1. Here we provide a very brief, high-level summary of the main ideas. A detailed presentation of the methodology and associated discussion (including of causal semantics and assumptions) are provided in the Methods and Supplementary section 2.

Suppose X1, …, Xp is a set of variables whose mutual causal relationships are of interest. Let G* denote an (unknown) graph whose directed edges encode these causal relationships. D2CL seeks to learn G* from two inputs: (1) empirical data X containing measurements on each of the variables of interest and (2) prior causal knowledge Π concerning a subset of causal relationships. This corresponds to a common paradigm in real-world scientific settings, where some data are measured on variables of interest, but only limited knowledge about causal relationships is available at the outset (for example, from prior scientific knowledge or specific experiments).

We formalize the task in the following way. For each ordered pair of variables with indices (i, j) whose causal status is not known via Π, our goal is to learn an indicator of whether or not Xi has a causal influence on Xj. D2CL treats these causal indicators as 'labels' in a machine learning sense, using the available inputs to learn a suitable mapping. The goal of the mapping is to minimize discrepancy with respect to the true, unknown causal status; this learning task can be viewed through the lens of causal risk23. In all experiments, the learner never has access to data in which the parent node of an unknown edge was intervened on. This makes learning challenging, as we require generalization to interventional regimes/distributions that are entirely unseen.

Learning is carried out using a flexible, neural model Fθ with a set of trainable parameters θ. The model is trained in a specific fashion that leverages the input information Π as a supervision/training signal to allow the model to learn representations suitable for generalization to novel causal relationships (the Methods provides details and a discussion of the assumptions). The network Fθ combines a convolutional neural network (CNN) and a graph neural network (GNN) to resolve distributional and graph structural regularities (Fig. 2). In image processing, CNNs make use of certain properties, such as spatial invariance, that exploit the notion of an image as a function on the plane. Here we leverage the CNN toolkit to capture distributional information in data X, represented as images. We create these visual representations for two-tuples of nodes. Specifically, for a variable pair (i, j) we use the n × 2 submatrix X(⋅, [ij]) to form a bivariate kernel density estimate fij = KDE(X(⋅, [ij])) that is treated as an image input. Note that this is in general asymmetric in the sense that fij ≠ fji. This is important, as we want to learn ordered/directed relationships (symmetry here would imply an inability to distinguish the causal direction). The GNN is aimed at capturing graph structural regularities and to this end learns a state embedding hj that contains the information of the neighbourhood for each node j. The GNN requires a graph as input; we provide an initial input graph Ĝ0 via computationally lightweight routines solely based on the available data, X (Methods).

Finally, following training, the model F—with parameters now fixed as a function of inputs X and Π—can be used to assign causal status to any pair via an inference step. In the experiments described in the following, the global model output is tested systematically at large scale against either the true graph G* (in simulations) or against entirely unseen interventional experiments (for real biological examples).

Our focus is on causal learning for real-world, high-dimensional problems with thousands of nodes and limited data, motivated by large-scale biomedical problems. Within the causal risk paradigm10,23 we use here, acyclicity (of the directed graphs to be learned) is not assumed, nor is availability of any standard factorization of the joint probability distribution. It is not required that data samples in X are drawn from a single distribution; instead, data can be drawn from, for example, a mix of observational and interventional distributions, and the causal characteristics of these regimes (for example, which node(s) or latents were intervened on) need not be known in advance. Nor is it required that we have interventional data or prior information concerning all nodes. On the contrary, in all experiments, the learner never has access to data in which the parent node of an unknown edge was intervened on nor prior information concerning the unknown edge. This is a common real-world set-up, in particular for emerging experimental designs in biology (examples are described in the following). We emphasize that the NNs used are
[Figure graphics appear here: a workflow schematic (data matrix X; causal knowledge from perturbation experiments; KDE-image inputs for node pairs {Xi, Xj} passed to the CNN; training, inference and the learned graph) and the Fig. 3 plots of ROC curves and causal AUC as a function of SNR. The Fig. 3b inset reports AUC = 0.78 for IDA, 0.85 for D2CL and 0.85 for SCL. The values plotted in Fig. 3c,d are:

Fig. 3c (direct causal relationships), causal AUC by SNR:
SNR      10.0  6.0   4.0   2.0   1.0   0.75  0.5   0.25  0.1
Pearson  0.72  0.71  0.68  0.63  0.59  0.58  0.55  0.53  0.51
IDA      0.77  0.78  0.77  0.75  0.74  0.72  0.69  0.66  0.63
D2CL     0.85  0.86  0.85  0.84  0.80  0.79  0.77  0.73  0.72
SCL      0.85  0.84  0.81  0.76  0.72  0.68  0.65  0.58  0.56

Fig. 3d (indirect causal relationships), causal AUC by SNR:
SNR      10.0  6.0   4.0   2.0   1.0   0.75  0.5   0.25  0.1
Pearson  0.51  0.48  0.46  0.42  0.40  0.40  0.40  0.44  0.48
IDA      0.90  0.89  0.88  0.85  0.81  0.79  0.77  0.75  0.74
D2CL     0.94  0.94  0.94  0.94  0.89  0.89  0.86  0.83  0.81
SCL      0.90  0.89  0.88  0.85  0.79  0.78  0.74  0.63  0.59]
Fig. 3 | Results for large-scale simulated data. a, Overview of the experimental workflow. Data were simulated from known, gold-standard causal graphs, and the output of the learners was compared with the true, underlying graph to quantify the ability to recover the causal structure. Finite-sample empirical data were generated using a directed causal graph of specified dimension p, specifically via linear and nonlinear structural equation models with noise. b, ROC curves for an illustrative nonlinear case (the hyperbolic tangent), with an SNR of 10.0, for direct causal relations in a graph with p = 1,500 nodes. D2CL (black) is compared against Pearson correlation coefficients (orange), IDA (cyan), SCL (green), ENCO (blue) and DCD-FG (brown). The ROC curve and the area under the ROC curve (AUC) are given for algorithms providing a continuous output (Pearson, IDA, SCL and D2CL). The binary graph estimates of ENCO and DCD-FG are represented by single markers for five different runs. c, Results for an illustrative nonlinear case (the hyperbolic tangent), at varying noise levels, for direct causal relationships. The causal area under the ROC curve (AUC; with respect to the causal ground truth graph, see Methods and Supplementary section 3 for details) is shown as a function of SNR for an experiment with p = 1,500 variables and a sample size of n = 1,024. Results for other linear and nonlinear functions are provided in Supplementary section 4. D2CL (blue) is compared with Pearson correlations (orange; this is a non-causal baseline), IDA (cyan) and SCL (green). d, Results for indirect causal relationships, with other settings as in c. Here, causal AUC is shown with respect to a graph encoding causal, but potentially indirect, relationships. (Results shown are averages over five datasets at each specified SNR.)
restricted this comparison to a subset of the simulations. Illustrative results are provided in Fig. 3b. We find that neither approach is effective in this case, possibly due to the limited data and the presence of latent variables.

In addition, we tested the effectiveness of D2CL for additive and multiplicative Gaussian noise with varying SNRs under settings with hard deterministic and stochastic interventions. We refer the interested reader to Supplementary section 3 for a definition of an intervention and the types used. The test results (AUC and AUPRC values) are summarized in Supplementary Tables 8 and 9 and support the notion that D2CL is robust to different types of noise.

The graph G* in the above examples encodes direct causal relationships as there is an edge from one node to another if the former appears in the equation for the latter. However, in many real-world examples, interest focuses also on indirect effects, which may be mediated by other nodes.
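One natural formalization of such indirect relationships is reachability in the direct graph: the indirect (ancestral) graph has an edge i → j whenever j can be reached from i along directed edges. A small illustrative sketch (not taken from the paper's code) of deriving this graph from a direct adjacency matrix is:

# Illustrative sketch: indirect (ancestral) relation graph as the transitive closure
# of a direct causal graph, i.e. an edge i -> j whenever j is reachable from i.
import numpy as np

def transitive_closure(g_direct):
    """g_direct: binary adjacency matrix (p x p) of direct causal edges."""
    reach = g_direct.astype(bool)
    p = reach.shape[0]
    for k in range(p):                                  # Floyd-Warshall-style reachability
        reach = reach | np.outer(reach[:, k], reach[k, :])
    np.fill_diagonal(reach, False)                      # drop self-relations
    return reach.astype(int)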
[Fig. 4 graphics appear here: causal ROC curves (true positive rate versus false positive rate) with per-method AUC values. Panels a–c vary the number of interventions m, panels d–f vary the sample size n (n = 100, 300 and 706) and panels g–k cover the p = 1,000 and p = 5,535 settings. Methods shown: D2CL - CNN tower, D2CL - GNN tower (Lasso), D2CL - GNN tower (Pearson), D2CL (Pearson), D2CL (Lasso), IDA, LVIDA, Kendall and SCL.]
Fig. 4 | Results for the yeast gene deletion experiments. Causal learning methods, including D2CL, were applied to gene expression measurements from yeast cells. Performance was quantified using causal ROC curves (and AUCs) computed with respect to a causal ground truth obtained from entirely unseen interventional experiments (see main text for details). a–c, The number of interventions m whose effects are available to the learner was varied (with the problem dimension fixed to p = 1,000 and the sample size to n = 706): m = 100 (a), 500 (b) and 753 (c). d–f, The sample size n of the data matrix X was varied (with the problem dimension fixed to p = 1,000 and the number of available interventions fixed to m = 753): n = 100 (d), 300 (e) and 706 (f). g–k, Analogous results for a higher-dimensional setting covering all available genes (roughly the full yeast genome) with p = 5,535 (with n = 706 and m = 753) for the indicated arrangements. Here, only D2CL variants are shown, as the other methods could not be run due to the computational burden in this higher-dimensional case. Comparison with the corresponding p = 1,000 case demonstrates the scalability of D2CL, with performance broadly maintained in the higher-dimensional setting. The D2CL variants shown include a CNN tower alone (g), GNN tower alone (h,i) and the entire D2CL architecture (j,k); methods compared against include IDA, LVIDA, Kendall correlations (as a non-causal baseline) and SCL (see main text and Supplementary sections 1 and 3 for details and references). For D2CL and its variants, two different initial graph estimates were used based respectively on Pearson correlation coefficients ('Pearson') and on a lightweight regression ('Lasso'); details are provided in the main text.
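The 'Pearson' and 'Lasso' initial graph estimates mentioned in the caption are lightweight constructions from the data X alone that only seed the GNN's neighbourhood structure. As an illustration, a thresholded-correlation variant might look as follows (the threshold value is an assumption, not taken from the paper):

# Illustrative sketch: lightweight initial input graph from the data matrix X alone,
# via thresholded absolute Pearson correlations (the threshold is an assumed value).
import numpy as np

def initial_graph_pearson(X, threshold=0.3):
    """X: n x p data matrix. Returns a binary p x p matrix used only to seed the GNN."""
    corr = np.corrcoef(X, rowvar=False)      # p x p Pearson correlation matrix
    g0 = (np.abs(corr) > threshold).astype(int)
    np.fill_diagonal(g0, 0)                  # no self-edges
    return g0

A Lasso-based variant would instead regress each variable on the remaining ones with an l1 penalty and keep the nonzero coefficients as candidate edges.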
For example, if node A has a direct effect on B, and B on C, intervention on A may change C, even though A does not itself appear in the equation for C. To test the ability to learn indirect edges, we proceeded as above, but with the inputs Π being indirect edges and the output tested against the true indirect graph. Results are presented in Fig. 3d. D2CL outperforms existing methods across a range of SNRs and also in other linear/nonlinear problem configurations (Supplementary Tables 4 and 5). IDA performs well in the case of a linear SEM, but not for functions based on nonlinear multilayer perceptrons. D2CL appears to be the most noise-robust of the methods tested. These results show that D2CL can learn indirect causal edges over many variables under conditions of noise and nonlinearity.

Out-of-system, out-of-distribution evaluation. D2CL is trainable using (limited) data from a specific system (for example, a specific biological system, such as cells of a particular kind, or a disease state). However, it is interesting to see whether it is possible to generalize to different systems. To this end, we trained D2CL on a dataset from a certain system and cross-evaluated the trained model on data from another system (a different simulation regime). The results are provided in Supplementary Tables 10 and 11. Some generalization appears possible, suggesting that D2CL can find signals that are causally informative in a cross-system sense, although performance is always worse relative to in-system training (this is expected in our framework, and we emphasize that we do not claim any general ability to achieve out-of-system generalization).
[Fig. 5 graphics appear here: a, causal AUC for the D2CL variants (CNN tower, GNN tower (Pearson), GNN tower (Lasso), D2CL (Pearson), D2CL (Lasso)) under corrupted causal inputs; b, causal AUC with a zeroed embedding or Gaussian noise of σ = 1.0, 2.0 and 5.0 applied to one tower's embedding; c, low-dimensional feature-map embeddings.]
Fig. 5 | Sensitivity to incorrect causal inputs and additional results on causal direction. a, Robustness to incorrect causal inputs. The sensitivity of D2CL to errors in prior/input causal knowledge Π was studied by artificially introducing errors into Π, with 10% of inputs corrupted (experiments used the yeast gene deletion data; see main text for details). Results quantified via causal AUC (as in the main results, computed with respect to an experimentally defined ground truth), shown for several D2CL variants. b, An ablation-like study in which failures of either the CNN (orange) or the GNN (blue) tower within D2CL were artificially introduced. The relevant embedding was either set to zero or to zero-mean Gaussian noise (with scale as shown). The unaffected case is given as a dashed black line. c, Causal direction analysis (see main text for details). Low-dimensional representations of latent feature maps of the converged CNN tower at two different layer depths. Edges A → B are shown as filled circles and reverse edges B → A as x-shaped markers. An edge and its corresponding reverse are shown in the same colour. For improved readability, only ten random pairs are highlighted in colours and bigger markers. We see that the embedding is not invariant with respect to causal direction and is able to effectively identify the correct direction (as shown also in an additional experiment, see main text). The different D2CL variants include a CNN tower alone, a GNN tower for two different initial graph estimates, and the complete network for the same two initial graph estimates. Initial graph estimates for the GNN and combined models were either based on Pearson correlation coefficients ('Pearson') or a lightweight regression ('Lasso'; see main text for details).
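The corruption procedure underlying panel a can be sketched as follows (illustrative names; the fraction corrupted matches the 10% used in the experiments):

# Illustrative sketch: corrupt a fraction of the prior causal labels in Pi, relabelling
# causal pairs as non-causal and vice versa, before training.
import numpy as np

def corrupt_prior(labels, fraction=0.10, seed=0):
    """labels: dict mapping ordered pair (i, j) -> 0/1 causal indicator from Pi."""
    rng = np.random.default_rng(seed)
    corrupted = dict(labels)
    pairs = list(corrupted)
    flip = rng.choice(len(pairs), size=int(fraction * len(pairs)), replace=False)
    for k in flip:
        corrupted[pairs[k]] = 1 - corrupted[pairs[k]]   # flip the binary label
    return corrupted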
Nevertheless, these results broadly support the notion of large-scale meta-learning for causal structures24.

Large-scale evaluation. Finally, to test the scalability of D2CL to high-dimensional problems, we considered a problem with p = 50,000 variables (that is, p = 50,000 nodes in the ground-truth graph; note that none of the compared methods can practically scale to this setting). We considered learning of direct causal relationships; the results are shown in Supplementary Table 6 and support the notion that D2CL can scale to problems spanning many thousands of variables.

Large-scale biological data
To study performance in the context of real biological data, we leveraged a large set of gene deletion experiments in yeast25, which have previously been used for causal learning9,10,26. The experiments involve measuring gene expression in yeast cells under each of a large number of interventions (gene deletions; Supplementary section 3 provides further details).

In biological experiments, causal effects may be indirect, and we sought to learn a directed graph with nodes corresponding to p observed genes and edges representing (possibly indirect) causal influences. Such edges are scientifically interesting and amenable to experimental verification, as noted in refs. 22,27. Cycles can arise in systems biology (see, for example, ref. 28) and we do not enforce acyclicity (see ref. 29 and references therein for a discussion of cyclic causality). A fuller discussion of the causal interpretation of laboratory experiments is beyond the scope of this Article, but relevant work includes refs. 29–31, and we direct the interested reader to these references for further discussion.
[Fig. 6 graphics appear here: a, D2CL ROC curves and AUC on human data (K562, AUC = 0.650; RPE, AUC = 0.733); b, AUC of baseline methods (CAM, RFCI, IDA, GES, GIES, LiNGAM) as a function of problem dimension (10^1 to 10^3 variables), with a dashed line marking D2CL on the full p = 8,552 problem.]
Fig. 6 | Results for high-dimensional human data. Single-cell CRISPR-based experiments (due to ref. 32) were used to illustrate the use of the proposed approaches in a high-dimensional human cell setting. Performance was quantified using causal ROC curves (and AUC) computed with respect to a causal ground truth obtained from entirely unseen interventional experiments (see main text for details). a, Results from D2CL applied to data obtained from RPE cells and a cancer cell line (K562) in problems spanning more than 8,000 variables (other methods could not be practically run in this case due to the computational burden). b, Performance of existing causal learning approaches (on K562 data) as a function of problem dimension. The dashed line indicates D2CL performance on the full problem (p = 8,552 variables).
Because causal background knowledge is an input for our approach, it is relevant to consider performance as a function of the amount of such input. To this end, we fixed the problem size to p = 1,000 and varied the number of interventions m whose effects were available to the learner (Supplementary section 3 provides details). As each experiment involves only a subset of the entire yeast genome, latent variables are present by design. The input prior knowledge Π is derived from the causal status, but, as in all experiments, is strictly disjoint with respect to any test edges.

Results are presented in Fig. 4a–c, including the area under the receiver operating characteristic (ROC) curve (AUC; computed with respect to an experimentally determined gold standard; Supplementary section 3). Overall, the proposed methods perform well, achieving good results in this high-dimensional, limited-data regime. Next, to shed light on data efficiency, we varied the sample size n of the data matrix X (Fig. 4d–f).

Finally, we tested the performance in a higher-dimensional example spanning all p = 5,535 available genes (Fig. 4g–k) and found that D2CL remains effective at the genome scale. Interestingly, although the CNN tower performs particularly well, the GNN tower degrades more. This may be because larger p leads to a larger number of variable pairs (which is helpful for the CNN), but also to a (rapid) increase in the number of nodes and edges in the GNN subgraphs and hence a harder GNN learning task in practice.

D2CL leverages prior causal knowledge. However, in practice, the available causal inputs Π may be incorrect, for example, due to flawed initial experiments or errors in the known science. To study sensitivity to flawed causal inputs, we introduced errors into Π. This was done by perturbing 10% of the inputs (that is, labelling causal pairs as non-causal and vice versa) at the outset. The results are shown in Fig. 5a and demonstrate a level of robustness to such perturbation. We also see a benefit of the dual network variants; this is investigated further in Fig. 5b. For the latter, in general, the embedding of either tower is modified immediately before the fusion layer. We considered several different modifications: setting the embedding of one tower to zero and hence effectively removing all information from this tower, or applying Gaussian noise with magnitude σ = 1.0, σ = 2.0 and σ = 5.0.

Causal relations are in general directed and asymmetric, so it is interesting to explore model behaviour with respect to causal direction. Given an image representation, the CNN tower is designed to extract feature maps that are unique for ordered node pairs, that is, such that in general features differ depending on edge direction. To empirically study learning of causal direction, we constructed additional test data as follows: for each truly causal edge k → l in the test set, we also included the reverse direction l → k. This means that any learner estimating undirected links would have an AUC score of 0.5 (because the output k → l entails also l → k, one of which is a false positive). Supplementary Table 4 shows that D2CL is indeed capable of accurately identifying causal direction. In addition, Fig. 5c shows a low-dimensional representation of the feature maps of the converged CNN tower. These feature maps differ by causal direction (k → l versus l → k) throughout the forward pass, supporting the foregoing arguments.

High-dimensional CRISPR-based perturbations
Finally, we used recent, single-cell clustered regularly interspaced short palindromic repeats (CRISPR)-based interventional experiments32 to illustrate the use of the proposed approaches in very high-dimensional data from human cells. The experimental protocol (see ref. 32 for full details) includes a large number of interventions in a leukaemia cell line (K562) and in retinal pigment epithelial (RPE) cells. The K562 and RPE experiments include gene-expression levels for a total of, respectively, p = 8,552 and p = 8,833 genes (Supplementary section 3 provides details). This is a challenging setting due to the known complexity of regulatory events in human cells and high levels of variability and noise in single-cell protocols. The results are presented in Fig. 6 and demonstrate good performance for RPE, and slightly worse performance, but still nontrivial consistency with the experimental gold standard, for K562. Additional plots in Fig. 6 and Supplementary Fig. 3 show the performance and runtime for a set of baseline algorithms. These results demonstrate two key points. First, the runtime for many available algorithms grows so rapidly with increasing number of variables as to render them unsuitable for problems at this scale. Second, for existing methods that are at all able to scale to larger problems, performance is considerably less effective than D2CL in this setting.

Conclusions
Emerging experimental protocols, involving combinations of perturbations and high-dimensional readouts, are allowing for new, scalable ways to query molecular networks in a context-specific fashion. Combined with scalable causal learning tools, these approaches have the potential to strongly impact disease biology by allowing global networks, spanning thousands or tens of thousands of variables, to be investigated across many different contexts.

Networks learned in this way could, in the future, be leveraged to allow for the prediction of disease phenotypes or drug response under novel perturbations (this is a different task from standard supervised learning, because the test case involves an unseen perturbation to the system). Furthermore, in the area of personalized medicine, such an
approach could even allow for rational optimization over potential therapeutic strategies, because the latter are often interventions targeted at molecular nodes.

Our model leverages deep learning tools to learn causal relationships between variables at large scale. However, and in contrast to well-established approaches based on causal graphical models, it provides only a structural output rather than a probability model of the underlying system. It is also interesting to contrast D2CL with the recently proposed CSIvA24. Both approaches pursue, in a sense, a 'direct' mapping of data inputs to graph outputs, with a key difference being that CSIvA uses meta-learning and seeks to generalize across systems, whereas D2CL uses supervised learning to generalize to new interventions on a given system (for example, a biological system of interest). An interesting direction for future work may be to combine both approaches, for example by using CSIvA to provide the initial input to D2CL; this would combine general, simulation-based learning and data-efficient, system-specific training.

At present, rigorous theory and an understanding of the theoretical properties of the kind of approach studied here remain lacking. A key direction for future theoretical work will be to understand the precise conditions for the underlying system that are needed to ensure that direct mapping approaches can guarantee the recovery of specific causal structures. An interesting observation is that the proposed approach may benefit from a 'blessing of dimensionality', because the learning problem will typically enjoy a larger number of examples as dimension p grows. Conversely, and in contrast to established statistical causal models, our approach (at the current stage) cannot be used in the small-p regime, because the number of examples will be too small for deep learning.

Methods
In this section, we provide information on the causal interpretation of our learning scheme, as well as a more detailed presentation of the architecture and associated implementation.

Notation
Observed variables with index set V = {1, …, p} are denoted X1, …, Xp. The variables will be identified with vertices in a directed graph G whose vertex and edge sets are denoted V(G) and E(G), respectively. We occasionally overload G to refer also to the corresponding binary adjacency matrix, using Gij to refer to the entry (i, j) of the adjacency matrix, as will be clear from context. We use linear indexing of variable pairs to aid formulation as a machine learning problem. Specifically, an ordered pair (i, j) ∈ V × V has an associated linear index k ∈ 𝒦 = {1, …, K}, where K is the total number of variable pairs of interest. Where useful, we make the mapping explicit, denoting the linear index corresponding to a pair (i, j) as k(i, j) and the variable pair corresponding to a linear index k as (i(k), j(k)). The linear indices of pairs whose causal relationships are unknown and of interest are 𝒰 ⊂ 𝒦, and those pairs known in advance via input knowledge Π are 𝒯(Π) ⊂ 𝒦. In all experiments, 𝒯(Π) and 𝒰 are disjoint; that is, no prior causal information is available on the pairs 𝒰 of interest.

Problem statement
We focus on the setting in which the available inputs are

• (I1) Empirical data: an n × p data matrix X whose columns correspond to variables X1, …, Xp.
• (I2) Causal background knowledge Π providing information on a subset 𝒯(Π) ⊂ 𝒦 of causal relationships.

For (I2), we assume that information is available concerning the causal status of a subset of variable pairs. That is, for some variable pairs (Xi, Xj) the correct binary indicator G∗ij, representing the presence/absence of an edge in the target graphical object, is provided as an input. In terms of linear indexing, these can be viewed as available 'labels' of causal status for the pairs 𝒯(Π) ⊂ 𝒦. No specific assumption is made on the data X, but, in line with our focus on generalizing to unseen causal relationships, it is assumed that it does not contain interventional data corresponding to the pairs in 𝒰. Furthermore, in all experiments, not only are the sets 𝒯 and 𝒰 disjoint, but we enforce the stronger requirement that for u ∈ 𝒰 there is no j such that k(i(u), j) ∈ 𝒯. This means that all interventions on which models are tested are entirely novel, that is, unrepresented in the inputs to the learner, either as data or prior input. This also means that the learner has no access whatsoever to samples from the test interventional distributions, and all experiments are out-of-distribution in this sense.

The learning task can thus be formulated as follows: given inputs (I1) and (I2), the goal is to estimate, for each ordered pair of variables (Xi, Xj) with unknown causal relationship, whether or not Xi has a causal influence on Xj.

Summary of the learning scheme
With the notation above, our goal is to learn a graph whose nodes correspond to the variables X1, …, Xp and whose edges represent causal relationships. To this end, we train a parameterized network Fθ, that is, a nonlinear function F with a set of unknown, trainable parameters θ. This is possible, because we know for each pair k ∈ 𝒯 the causal status G∗ij based on input information Π. The architecture we use as Fθ is detailed below, but for now assume this has been specified. Then, given the data X and the training labels Yk = G∗i(k),j(k) for all pairs k ∈ 𝒯(Π), we train the set of parameters θ̂(X, Π) under a loss that is supervised by the (causal) labels Yk.

At this stage, the trained network Fθ̂(X,Π) allows assignment of causal status to any pair, because it gives an estimate of the entire graph including those pairs whose causal status was unknown. The output is given by

Ĝij(X, Π) = Fθ̂(X,Π)(i, j; X) if k(i, j) ∉ 𝒯(Π), and Ĝij(X, Π) = Yk(i,j)(Π) otherwise,   (1)

where (i, j) are ordered variable pairs. Note that the overall estimate depends solely on the data X and causal information Π. By default, no change is made for pairs in 𝒯(Π) whose status was known at the outset. Reference 23 studied causal notions of risk based on loss functions of the form that compare a graph estimate Ĝ with ground truth G*. In our setting, we consider a classification-type loss on the variable pairs k, where the causal status of known pairs 𝒯(Π) provides the training 'labels'. We therefore use the corresponding binary cross-entropy loss, augmented by additional terms that, for example, prevent exploding weights.
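The following schematic sketch shows the shape of this scheme, with binary cross-entropy supervision on the known pairs followed by the assignment rule of equation (1). It is illustrative only: the pair featurization, the model and all names are placeholders, and batching and the additional regularization terms are omitted.

# Schematic sketch of the edge-wise supervised scheme (illustrative, not the released code).
import torch
import torch.nn as nn

def train_and_infer(model, pair_features, known, unknown_pairs, epochs=100, lr=1e-4):
    """model: nn.Module mapping a pair's input tensor to a single logit.
    pair_features: dict (i, j) -> input tensor for that ordered pair.
    known: dict (i, j) -> 0/1 label from Pi for pairs with known causal status.
    unknown_pairs: ordered pairs whose causal status is to be estimated."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        for (i, j), y in known.items():                 # supervision only via Pi
            opt.zero_grad()
            logit = model(pair_features[(i, j)].unsqueeze(0)).squeeze()
            loss = loss_fn(logit, torch.tensor(float(y)))
            loss.backward()
            opt.step()
    # Assignment rule of equation (1): model output for unknown pairs,
    # known labels from Pi retained unchanged for the remaining pairs.
    g_hat = dict(known)
    with torch.no_grad():
        for (i, j) in unknown_pairs:
            g_hat[(i, j)] = torch.sigmoid(model(pair_features[(i, j)].unsqueeze(0))).item()
    return g_hat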
Causal interpretation of the learning scheme
D2CL outputs a directed graph. The discriminative nature of D2CL means that the notion of causal influence encoded by the edges is rooted in the application setting and input information Π, because causal semantics are inherited via the problem setting rather than specified by a generative model (see ref. 10 for related discussions). Indeed, in the experiments we showed that D2CL could be used to successfully learn either direct or indirect/ancestral causal relationships.

Here we provide some intuition as to why discriminative learning can be effective in this setting. However, we note that the following arguments are not intended to constitute a rigorous theory at this stage, but rather to help gain an understanding of the conditions under which discriminative causal structure learning may be expected to be effective. We start with a general causal framework and then introduce assumptions for D2CL (the meta-generator assumption (MGA) and the dominant cause under single intervention (DCSI), described in the following sections). Following refs. 1,33, we assume decomposition of the underlying system into modular and independent mechanisms:

Independent causal mechanisms (ICMs). The causal generative process of a system's variables is composed of autonomous modules that do not inform or influence each other.

For variables Xi assume a structural causal model with equations Xi = fi(PaG*(Xi), UXi), i = 1, …, p, where PaG*(Xi) denotes the set of parents in the ground-truth graph G* for node i, and fi is a node-specific function. Exogenous noise terms UXi are assumed jointly independent and distributed as UXi ∼ pi, where pi is a node-specific density.

Our approach treats the fi and pi as unknown, but assumes they are related at a higher level. This can be formalized as a meta-generator assumption as follows.

Meta-generator assumption (MGA). For a specific system W, the functions fi and noise distributions pi are (independently) generated as fi ∼ ℱW and pi ∼ 𝒫W, where ℱW denotes a function generator and 𝒫W a stochastic generator, that are specific to the applied problem setting W.

MGA is motivated by the notion that in any particular real-world system, underlying (biological, physical, social and so on) processes tend to share some functional and stochastic aspects, which impart some higher-level regularity. That is, MGA states that, in a given applied context, functions fi and (independent causal mechanism-consistent) noise terms UXi, while unknown, varied and potentially complex, are nonetheless related at a 'meta'-level. The generators ℱW, 𝒫W are random processes, representing, respectively, a 'distribution over functions' and a 'distribution over distributions', whose role here is to capture the notion of relatedness among fi functions (respectively pi) in a given setting W. Note that ℱW, 𝒫W are treated as unknown and never directly estimated.

As mentioned in the problem statement, we focus on the causal status of variable pairs (Xi, Xj) (rather than general tuples), which denotes the simplest possible case under MGA. Furthermore, in both our work and the majority of interventional studies in applications such as biology, single interventions (rather than joint interventions on multiple nodes) are the norm. This requires the additional assumption, DCSI.

Dominant cause under single interventions (DCSI). A sufficiently large change in one of potentially multiple causes leads to a change with respect to the effect. Therefore, single interventions are sufficient to drive variation in the child distribution.

From MGA and DCSI to discriminative causal structure learning. Consider an applied problem W with underlying causal graph G∗W, treated as fixed but unknown. The associated functions and noise terms are also unknown but assumed to follow MGA. Then, under DCSI, we have that all pairs of the form (Xi, Xj) have underlying relationships of the form Xj = fj(Xi, UXj) with components following the MGA (that is, drawn from generators ℱW, 𝒫W). This in turn suggests that within the setting W, identification of causal pairs can be treated as a classification problem, as all pairs share the same generators. In other words, MGA restricts the distribution over relations of variables and noise terms to system-specific generators.

Note that no particular assumption is made on the individual functions fj, only that they are mutually related on a higher level. Furthermore, the generators themselves need not be known nor directly estimated; rather, it is only important that they are shared across the applied setting W. Note that a model learned for setting W will not in general be able to classify pairs in an entirely different applied setting W′ (because the generators may then differ strongly); that is, we do not seek to learn 'universal' patterns that apply to all causal relations in any system whatsoever. The classification task of D2CL aims to tell apart causal relationships (assumed drawn from the system-specific generators) from non-causal ones. We note that, in real systems, fi functions may be coupled via constraints on global functionality, and are thus non-independent; however, the good performance seen in the results empirically justifies the approach. Despite the initial theoretical ideas described above, rigorous theory and the theoretical properties of the kind of approach studied here remain to be understood, in particular the precise conditions for the underlying system needed to ensure that the classification-type approach can guarantee recovery of specific causal structures. We emphasize also that in contrast to classical causal learning schemes, for example, based on causal DAGs, we cannot at this stage make theoretical statements concerning underlying multivariate distributions and their link to edges estimated by our models. Our goal is good performance in an edge-wise sense (as detailed above), and the core assumptions (formalized above) concern a limited notion of classifiability. We note also that our models at present learn edges separately and do not impose any particular wider/global constraints (such as acyclicity or path constraints), although this could in principle be done within the causal risk framework.

Architecture details
CNN tower. To capture distributional information from empirical data X, a preprocessing step is required. In principle, this could be done via a variety of multidimensional transformations of X. We consider the simplest possible case, namely for a pair (i, j) to consider only the corresponding columns i and j in the data matrix X. Specifically, we use the n × 2 submatrix X(⋅, [ij]) to form a bivariate kernel density estimate fij = KDE(X(⋅, [ij])). Note that this is, in general, asymmetric in the sense that fij ≠ fji, which is important as we want to learn ordered/directed relationships. In other words, this ensures that, in general, the CNN tower can output different probabilities for edges A → B and B → A (for any two nodes A and B). Evaluations of the KDE at equally spaced grid points on the plane (that is, numerical values from the induced density function) are treated as the input to the CNN. The KDE itself is a standard bivariate approach using automated bandwidth selection following refs. 34,35. This provides an 'image' of the data and allows us to leverage existing image analysis ideas. Furthermore, we concatenate channelwise the numerical KDE values on the regularly spaced grid with a positional encoding of the grid points.

The concrete network architecture of our CNN tower is inspired by a ResNet-54 architecture36. From a high-level perspective, it consists of a stem, five stages with [3, 4, 6, 3, 3] ResNet blocks and multiple fully connected layers that transform the high-level feature maps into a latent space that is merged with the output of the GNN tower. The first ResNet block at each stage downsamples the spatial dimensions of the output of the previous stage by a factor of two. To enhance the computational efficiency of the bottleneck layers in each ResBlock, channel down- and upsampling exploiting 1 × 1 convolutions is performed before and after each feature-extraction CNN layer37. We replaced ReLU activations by the parametric counterpart PReLU38, allowing us to learn the slope of the negative part at negligible additional computational costs, which addresses the problem of dying neurons. Following ref. 39, we chose a full pre-activation of the convolutional layers, normalization–activation–convolution.
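As an illustration of the preprocessing step described above, a minimal sketch of turning a column pair of X into an image-like input is given below (scipy's Gaussian KDE is used here as a stand-in for the bandwidth-selection routines of refs. 34,35, and the grid size is an assumed value):

# Illustrative sketch: image-like input for an ordered pair (i, j) from the n x 2
# submatrix X[:, [i, j]], via a bivariate KDE on a regular grid plus a positional encoding.
import numpy as np
from scipy.stats import gaussian_kde

def kde_image(X, i, j, grid_size=64):
    """X: n x p data matrix. Returns an array of shape (3, grid_size, grid_size):
    channel 0 holds KDE values, channels 1 and 2 a positional encoding of the grid."""
    pair = X[:, [i, j]]                        # order fixes which variable maps to which axis
    kde = gaussian_kde(pair.T)                 # bivariate density estimate
    xs = np.linspace(pair[:, 0].min(), pair[:, 0].max(), grid_size)
    ys = np.linspace(pair[:, 1].min(), pair[:, 1].max(), grid_size)
    gx, gy = np.meshgrid(xs, ys, indexing="ij")
    density = kde(np.vstack([gx.ravel(), gy.ravel()])).reshape(grid_size, grid_size)
    return np.stack([density, gx, gy])         # channel-wise concatenation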
GNN tower. Our GNN tower builds on the SEAL architecture of ref. 40 and the resulting graph convolutional neural network (GCNN) for link prediction. The underlying notion is that a heuristic function predicts scores for the existence of a link. However, instead of employing predefined heuristics (such as the Katz coefficient or PageRank), an adaptive function is learned in an end-to-end fashion, which is formulated as a graph classification problem on enclosing subgraphs. Reference 40 showed that a γ-decaying heuristic can be approximated from an h-hop neighbourhood with an approximation error that decreases at least exponentially. These findings suggest that it is possible to learn high-order graph structure features from local enclosing subgraphs instead of the entire graph, which can be exploited for link prediction. Consider the pair of nodes of interest (i, j); the GNN tower is intended to infer causally interesting node features and state embeddings based on a local 1-hop enclosing subgraph extracted from the approximated input graph Ĝ0. For node pair (i, j), we first extract a set of nodes 𝒩 with all nodes that are connected to either node i or node j based on the adjacency matrix of the approximated input graph Ĝ0. The edge structure within the subgraph Gi,j is then reconstructed by pulling out all edges from Ĝ0 for which the parent and child node are in 𝒩. The order of the nodes is shuffled for each subgraph. The node features in every input subgraph consist of structural node labels that are assigned by a double-radius node labelling (DRNL) heuristic40 and the individual data features. In a first step, the distances between node i and all other nodes of the local subgraph except node j are computed. The same is repeated for node j. A hashing function then transforms the two distance labels into a DRNL label that assigns the same label to nodes that are on the same 'orbit' around the centre nodes i and j. During the training process, the DRNL label is transformed into a one-hot encoded vector and passed to the first graph convolutional layer. In contrast to traditional CNNs, GCNNs do not benefit strongly from very deep architecture design41,42. Therefore, our GNN tower consists of only four sequentially stacked graph convolutional layers. The activation function is defined as the hyperbolic tangent. Because the number of nodes in the enclosing subgraph for each pair of variables (i, j) is different, a SortPooling layer43 is applied to select the top k nodes according to their structural role within the graph. Afterwards, one-dimensional convolutions extract features from the selected state embeddings.

Embedding fusion. Each tower outputs a high-dimensional embedding of the individual features found. These embeddings are concatenated and further processed by multiple fully connected layers. Finally, the last layers output the log-likelihood of a directed edge from node i to node j.

Implementation details. All network architectures are implemented in the open-source framework PyTorch44. The GNN is coded based on the Deep Graph Library45. All modules are initialized from scratch using random weights. During training, we apply an Adam optimizer46 starting at an initial learning rate of ϵ0 = 0.0001. The learning rate is reduced by a factor of five once the evaluation metrics stop improving for 15 consecutive epochs. The minimum learning rate is set to ϵmin = 10−8. The training predictions are supervised on the binary cross-entropy loss between estimated and ground-truth edge labels. The evaluation metric is the (held-out) area under the ROC curve. Every network architecture is trained for 100 epochs. All computations are run on multiple graphics processing unit (GPU) nodes simultaneously, each equipped with eight Nvidia Tesla V100 GPUs.
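In PyTorch terms, the optimization set-up described above can be sketched roughly as follows (ReduceLROnPlateau and the names are stand-ins; the released Code Ocean capsule is the authoritative implementation):

# Illustrative sketch of the optimizer and learning-rate schedule described above.
import torch

def make_optimizer(model):
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)           # initial learning rate
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="max", factor=0.2, patience=15, min_lr=1e-8)    # reduce lr by 5x on plateau
    return opt, sched

# per epoch: compute the held-out AUC (the evaluation metric) and call sched.step(val_auc)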
Data availability
Data files are publicly available as follows. Yeast gene deletion data are from ref. 25. CRISPR perturbation data are from ref. 32. The pseudocode for data simulation is provided in Supplementary section 5.

Code availability
A Code Ocean compute capsule, which contains a pre-built compute environment and the source code of D2CL, is available at https://codeocean.com/capsule/4465854/tree/v1 (ref. 47).

References
1. Peters, J., Janzing, D. & Schölkopf, B. Elements of Causal Inference: Foundations and Learning Algorithms (MIT Press, 2017).
2. Arjovsky, M., Bottou, L., Gulrajani, I. & Lopez-Paz, D. Invariant risk minimization. Preprint at https://arxiv.org/abs/1907.02893 (2019).
3. Heinze-Deml, C., Maathuis, M. H. & Meinshausen, N. Causal structure learning. Annu. Rev. Stat. Appl. 5, 371–391 (2018).
4. Spirtes, P., Glymour, C. & Scheines, R. Causation, Prediction and Search (MIT Press, 2000).
5. Shimizu, S., Hoyer, P. O., Hyvärinen, A. & Kerminen, A. A linear non-Gaussian acyclic model for causal discovery. J. Mach. Learn. Res. 7, 2003–2030 (2006).
6. Maathuis, M. H., Kalisch, M. & Bühlmann, P. Estimating high-dimensional intervention effects from observational data. Ann. Stat. 37, 3133–3164 (2009).
7. Hauser, A. & Bühlmann, P. Characterization and greedy learning of interventional Markov equivalence classes of directed acyclic graphs. J. Mach. Learn. Res. 13, 2409–2464 (2012).
8. Colombo, D., Maathuis, M. H., Kalisch, M. & Richardson, T. S. Learning high-dimensional directed acyclic graphs with latent and selection variables. Ann. Stat. 40, 294–321 (2012).
9. Peters, J., Bühlmann, P. & Meinshausen, N. Causal inference using invariant prediction: identification and confidence intervals. J. R. Stat. Soc. 78, 947–1012 (2016).
10. Hill, S. M., Oates, C. J., Blythe, D. A. & Mukherjee, S. Causal learning via manifold regularization. J. Mach. Learn. Res. 20, 127 (2019).
11. Zheng, X., Aragam, B., Ravikumar, P. K. & Xing, E. P. DAGs with no tears: continuous optimization for structure learning. In Proc. Advances in Neural Information Processing Systems Vol. 31, 9472–9483 (eds Bengio, S. et al.) (Curran Associates, 2018).
12. Ke, N. R. et al. Learning neural causal models from unknown interventions. Preprint at https://arxiv.org/abs/1910.01075 (2019).
13. Brouillard, P., Lachapelle, S., Lacoste, A., Lacoste-Julien, S. & Drouin, A. Differentiable causal discovery from interventional data. Adv. Neural Inf. Process. Syst. 33, 21865–21877 (2020).
14. Lopez, R., Hütter, J.-C., Pritchard, J. & Regev, A. Large-scale differentiable causal discovery of factor graphs. Adv. Neural Inf. Process. Syst. 35, 19290–19303 (2022).
15. Lippe, P., Cohen, T. & Gavves, E. Efficient neural causal discovery without acyclicity constraints. In International Conference on Learning Representations (2022).
16. Ideker, T. & Krogan, N. J. Differential network biology. Mol. Syst. Biol. 8, 565 (2012).
17. Hill, S. M. et al. Inferring causal molecular networks: empirical assessment through a community-based effort. Nat. Methods 13, 310–318 (2016).
18. Hill, S. M. et al. Context specificity in causal signaling networks revealed by phosphoprotein profiling. Cell Syst. 4, 73–83 (2017).
19. Kuenzi, B. M. & Ideker, T. A census of pathway maps in cancer systems biology. Nat. Rev. Cancer 20, 233–246 (2020).
20. Lopez-Paz, D., Muandet, K., Schölkopf, B. & Tolstikhin, I. Towards a learning theory of cause-effect inference. In Proc. 32nd International Conference on Machine Learning Vol. 37, 1452–1461 (eds Bach, F. et al.) (PMLR, 2015).
21. Mooij, J. M., Peters, J., Janzing, D., Zscheischler, J. & Schölkopf, B. Distinguishing cause from effect using observational data: methods and benchmarks. J. Mach. Learn. Res. 17, 1–102 (2016).
22. Noè, U., Taschler, B., Täger, J., Heutink, P. & Mukherjee, S. Ancestral causal learning in high dimensions with a human genome-wide application. Preprint at https://arxiv.org/abs/1905.11506 (2019).
23. Eigenmann, M., Mukherjee, S. & Maathuis, M. Evaluation of causal structure learning algorithms via risk estimation. In Proc. 36th Conference on Uncertainty in Artificial Intelligence, UAI 2020 Vol. 124, 151–160 (eds Peters, J. et al.) (PMLR, 2020).
24. Ke, N. R. et al. Learning to induce causal structure. Preprint at https://arxiv.org/abs/2204.04875 (2022).
25. Kemmeren, P. et al. Large-scale genetic perturbations reveal regulatory networks and an abundance of gene-specific repressors. Cell 157, 740–752 (2014).
26. Meinshausen, N. et al. Methods for causal inference from gene perturbation experiments and validation. Proc. Natl Acad. Sci. USA 113, 7361–7368 (2016).
27. Zhang, J. Causal reasoning with ancestral graphs. J. Mach. Learn. Res. 9, 1437–1474 (2008).
28. Alon, U. An Introduction to Systems Biology: Design Principles of Biological Circuits (CRC Press, 2019).
29. Hyttinen, A., Eberhardt, F. & Hoyer, P. O. Learning linear cyclic causal models with latent variables. J. Mach. Learn. Res. 13, 3387–3439 (2012).
30. Eberhardt, F. & Scheines, R. Interventions and causal inference. Philos. Sci. 74, 981–995 (2007).
31. Kocaoglu, M., Shanmugam, K. & Bareinboim, E. Experimental design for learning causal graphs with latent variables. In Proc. Advances in Neural Information Processing Systems Vol. 30, 7018–7028 (eds Guyon, I. et al.) (Curran Associates, 2017).
32. Replogle, J. M. et al. Mapping information-rich genotype-phenotype landscapes with genome-scale Perturb-seq. Cell 185, 2559–2575 (2022).
33. Schölkopf, B. et al. On causal and anticausal learning. In Proc. 29th International Conference on Machine Learning, ICML 2012 459–466 (eds Langford, J. et al.) (icml.cc/Omnipress, 2012).
34. Silverman, B. W. Density Estimation for Statistics and Data Analysis (Chapman & Hall, 1986).
35. Turlach, B. Bandwidth selection in kernel density estimation: a review. Technical Report (1999).
36. He, K., Zhang, X., Ren, S. & Sun, J. Deep residual learning for image recognition. In Proc. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 770–778 (IEEE, 2016).
37. Szegedy, C. et al. Going deeper with convolutions. In Proc. 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 1–9 (IEEE, 2015).
38. He, K., Zhang, X., Ren, S. & Sun, J. Delving deep into rectifiers: surpassing human-level performance on ImageNet classification. In Proc. 2015 IEEE International Conference on Computer Vision (ICCV) 1026–1034 (IEEE, 2015).
39. Xie, S., Girshick, R., Dollár, P., Tu, Z. & He, K. Aggregated residual transformations for deep neural networks. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 5998–5995 (IEEE, 2017).
40. Zhang, M. & Chen, Y. Link prediction based on graph neural networks. In Proc. Advances in Neural Information Processing Systems 2018 Vol. 31, 5165–5175 (eds Bengio, S. et al.) (Curran Associates, 2018).
41. Chen, D. et al. Measuring and relieving the over-smoothing problem for graph neural networks from the topological view. Computing Research Repository (CoRR) https://doi.org/10.1609/aaai.v34i04.5747 (2019).
42. Li, Q., Han, Z. & Wu, X.-M. Deeper insights into graph convolutional networks for semi-supervised learning. In Proc. 32nd AAAI Conference on Artificial Intelligence 3538–3545 (eds McIlraith, S. et al.) (AAAI, 2018).
43. Zhang, M., Cui, Z., Neumann, M. & Chen, Y. An end-to-end deep learning architecture for graph classification. In Proc. 32nd AAAI Conference on Artificial Intelligence 4438–4445 (eds McIlraith, S. et al.) (AAAI, 2018).
44. Paszke, A. et al. PyTorch: an imperative style, high-performance deep learning library. In Proc. Advances in Neural Information Processing Systems Vol. 32, 8026–8037 (eds Wallach, H. et al.) (Curran Associates, 2019).
45. Wang, M. et al. Deep Graph Library: a graph-centric, highly-performant package for graph neural networks. Preprint at https://arxiv.org/abs/1909.01315 (2019).
46. Kingma, D. P. & Ba, J. Adam: a method for stochastic optimization. In 3rd International Conference on Learning Representations (2015).
47. Lagemann, K., Lagemann, C., Taschler, B. & Mukherjee, S. Deep learning of causal structures in high dimensions under data limitations. Code Ocean https://codeocean.com/capsule/4465854/tree/v1 (2023).

Acknowledgements
This work was partly supported by the German Federal Ministry of Education and Research (BMBF) project 'LODE', the UK Medical Research Council (MC-UU-00002/17) and the National Institute for Health Research (Cambridge Biomedical Research Centre at the Cambridge University Hospitals NHS Foundation Trust).

Author contributions
Methods were developed by K.L. and S.M. Implementation and experiments were performed by K.L., supported by C.L. B.T. contributed to the design and implementation of experiments using the baseline algorithms. The manuscript was written by K.L. and S.M., with input from C.L. and B.T. The research was supervised by S.M.

Funding
Open access funding provided by Deutsches Zentrum für Neurodegenerative Erkrankungen e.V. (DZNE) in der Helmholtz-Gemeinschaft.

Competing interests
The authors declare no competing interests.

Additional information
Supplementary information The online version contains supplementary material available at https://doi.org/10.1038/s42256-023-00744-z.

Correspondence and requests for materials should be addressed to Kai Lagemann or Sach Mukherjee.

Peer review information Nature Machine Intelligence thanks the anonymous reviewers for their contribution to the peer review of this work. Primary Handling Editor: Liesbeth Venema, in collaboration with the Nature Machine Intelligence team.

Reprints and permissions information is available at www.nature.com/reprints.

Publisher's note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons license, and indicate if changes were made. The images or other third party material in this article are included in the article's Creative Commons license, unless indicated otherwise in a credit line to the material. If material is not included in the article's Creative Commons license and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/.

© The Author(s) 2023