0% found this document useful (0 votes)
16 views12 pages

Improving Performance of Deep Learning Models

Uploaded by

a1173492690
Copyright
© © All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
0% found this document useful (0 votes)
16 views12 pages

Improving Performance of Deep Learning Models

Uploaded by

a1173492690
Copyright
© © All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as PDF, TXT or read online on Scribd
You are on page 1/ 12

Articles

https://doi.org/10.1038/s42256-021-00343-w

Improving performance of deep learning models


with axiomatic attribution priors and expected
gradients
Gabriel Erion1,2,4, Joseph D. Janizek1,2,4, Pascal Sturmfels1,4, Scott M. Lundberg1,3 and Su-In Lee 1 ✉

Recent research has demonstrated that feature attribution methods for deep networks can themselves be incorporated into
training; these attribution priors optimize for a model whose attributions have certain desirable properties—most frequently,
that particular features are important or unimportant. These attribution priors are often based on attribution methods that
are not guaranteed to satisfy desirable interpretability axioms, such as completeness and implementation invariance. Here we
introduce attribution priors to optimize for higher-level properties of explanations, such as smoothness and sparsity, enabled
by a fast new attribution method formulation called expected gradients that satisfies many important interpretability axioms.
This improves model performance on many real-world tasks where previous attribution priors fail. Our experiments show that
the gains from combining higher-level attribution priors with expected gradients attributions are consistent across image, gene
expression and healthcare datasets. We believe that this work motivates and provides the necessary tools to support the wide-
spread adoption of axiomatic attribution priors in many areas of applied machine learning. The implementations and our results
have been made freely available to academic communities.

R
ecent work on interpreting machine learning (ML) models greatly expand both the number of ways that a human-in-the-loop
has focused on feature attribution methods. Given an input can influence deep models and the precision with which they can
datum, a model and a prediction, such methods assign a num- do so. However, two drawbacks limit this method’s applicability to
ber to each input feature that represents how important the feature real-world problems. First, gradients do not satisfy the same theo-
was for making the prediction. Current research also investigates retical guarantees as modern feature attribution methods. This
the axioms that attribution methods should satisfy1–4 and how they leads to well-known problems such as saturation: operations, such
provide insight into model behaviour5–8. Feature attribution meth- as rectified linear units (ReLUs) and sigmoids, which have large
ods often reveal problems in a model or dataset. For example, a flat ‘saturated’ regions, can lead to zero gradient attribution even
model may place too much importance on undesirable features, for important features2. Second, it can be difficult to specify which
rely on many features when sparsity is desired or be sensitive to features should be important in a binary manner.
high-frequency noise. In such cases, humans often have a prior Additional recent work discusses the need for priors that incor-
belief about how a model should treat input features but find it dif- porate human intuition to develop robust and interpretable mod-
ficult to mathematically encode this prior for neural networks in els11. Still, it remains challenging to encode priors such as ‘have
terms of the model parameters. smoother attributions across an image’ or ‘treat this group of fea-
One method to address such problems is what we call an attri- tures similarly’ by penalizing a model’s input gradients or param-
bution prior: if it is possible for explanations to reveal problems in eters. Some recent attribution priors have proposed regularizing
a model, then constraining the model’s explanations during train- integrated gradients (IG) attributions12,13. While promising, this
ing can help the model avoid such problems. It is worth noting work suffers from three major weaknesses: it does not clearly dem-
that the vast majority of feature attribution methods focus exclu- onstrate improvements over gradient-based attribution priors,
sively on explaining why a given prediction was made. Only a very it penalizes attribution deviation from a target value rather than
small number of papers have investigated incorporating attribu- encoding sophisticated priors such as those we mention above,
tions themselves into model training. The first such paper, by Ross and it imposes a large computational cost by training with tens to
et al.9, used a binary indicator of whether each feature should or hundreds of reference samples per batch. A contemporary method
should not be important for making predictions on each sample in called contextual decomposition explanation penalization (CDEP)
the dataset and penalized the gradients of unimportant features. A uses a framework similar to attribution priors and penalizes expla-
very recent publication successfully used the gradient-based prior nations generated by the contextual decomposition method14.
of Ross et al. as part of a human-in-the-loop strategy to improve Unlike all other interpretability methods discussed in this paper,
model generalization performance and user trust, as well as con- CDEP penalizes explanations for pre-specified groups of features,
tributing their own model-agnostic method for penalizing feature meaning it is best suited for a different set of problems than we con-
importances10. Such results create a clear synergy with our study, sider. More discussion of CDEP can be found in ‘Attribution pri-
which improves the quality of calculated feature importances and ors are a flexible framework for encoding domain knowledge’ and
develops new forms of attribution priors. This has the potential to Supplementary Sections A and B.

Paul G. Allen School of Computer Science and Engineering, University of Washington, Seattle, WA, USA. 2Medical Scientist Training Program,
1

University of Washington, Seattle, WA, USA. 3Microsoft Research, Seattle, WA, USA. 4These authors contributed equally: Gabriel Erion, Joseph D. Janizek,
Pascal Sturmfels. ✉e-mail: [email protected]

620 Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell


NaTuRe MacHine InTelligence Articles
The main contribution of this work is a broadened interpretation that calculating Φ with attribution methods that satisfy previously
of attribution priors that includes any case in which the training established interpretability axioms improves performance (see ‘EG
objective incorporates differentiable functions of a model’s feature outperforms other attribution methods’ and ‘Expected gradients’ in
attributions. This can be seen as a generalization of gradient-based Methods for further discussion of interpretability axioms). Second,
regularization9,15–18 and it can be used to encode meaningful domain rather than simply encouraging each feature’s attribution to be near
knowledge more effectively than existing methods. Whereas previ- a target value as in previous work, we enforce high-level priors over
ous attribution priors generally took the form of ‘encourage feature the relationships between features.
i’s attribution to be near a pre-determined target value’, the priors we In image data, we use a Laplace zero-mean prior on the differ-
present here consider relative importance among multiple features ence between attributions of adjacent pixels, which encourages a
and do not require pre-determined target values for any feature’s low total variation (high smoothness) of attributions:
attribution. Specifically, we introduce an image prior enforcing that ∑∑ ℓ
neighbouring pixels have similar attributions, a graph prior for bio- Ωpixel (Φ(θ, X)) = |ϕi+1,j − ϕℓi,j | + |ϕℓi,j+1 − ϕℓi,j |,
logical data enforcing that related genes have similar attributions, ℓ i,j
and a sparsity prior enforcing that a few features have large attribu-
tions while all others have near-zero attributions. where i, j indexes the pixels of an image by rows and columns,
We also introduce a new general-purpose feature attribution respectively and ℓ indexes each image.
method to enforce these priors, expected gradients (EG). As men- In gene expression data, we use a Gaussian zero-mean prior on
tioned above, virtually all attribution methods are designed to the difference between mean absolute attributions ϕ̄i of functionally
explain a model’s prediction to humans, not to be penalized during related genes, which encourages such similar genes to have similar
training. This means many such methods may be computationally attributions:
difficult to incorporate into the training process. EG is an attribu- ∑ 2 T
tion method explicitly designed for regularization as an attribution Ωgraph (Φ(θ, X)) = Wi,j (ϕ̄i − ϕ̄j ) = ϕ̄ LG ϕ̄,
prior (Fig. 1a); it can be efficiently regularized during training due i,j
to its formulation as an expectation, which naturally lends itself to
batched estimates of the attribution. It also eliminates a hyperpa- where T represents a vector transpose, Wi,j is the weight of connec-
rameter choice required by IG2. Since these attributions are used tion between two genes in a biological graph, and LG is the graph
not only to interpret trained models but also as part of the training Laplacian.
objective itself, it is essential to guarantee that the attributions will Finally, in health data where sparsity is desired, we use a prior
be of high quality. We therefore show that our attribution method on the Gini coefficient of the mean absolute attributions ϕ̄i, which
satisfies important interpretability axioms. encourages a small number of features to have a large percentage of
Across three different prediction tasks, we show that training the total attribution while others are near zero:
with EG outperforms training with previous, more limited versions
of attribution priors. On images, our image prior produces a model p ∑
∑ p
that is more interpretable and generalizes better to noisy data. On |ϕ̄i − ϕ̄j |
i = 1 j= 1
gene expression data, our graph prior reduces prediction error and Ωsparse (Φ(θ, X)) = − = −2G(ϕ̄),
p

better captures biological signal. Finally, on a patient mortality pre-
n ϕ̄i
diction task, our sparsity prior yields a sparser model and improves i=1
performance when learning from limited training data.
where G is the Gini coefficient.
Results None of these priors requires specifying target values for features,
Attribution priors are a flexible framework for encoding domain and all improve performance over simpler baselines. For more details
knowledge. Let X ∈ Rn×p denote a dataset with labels y ∈ Rn×o, on our priors, see ‘Specific priors’ in Methods, and for more details
where n is the number of samples, p is the number of features and o on previous attribution priors, see ‘Previous attribution priors’ in
is the number of outputs. In standard deep learning, we find optimal Methods. We also note that these priors involve the relationships
parameters θ by minimizing loss L, with a regularization term Ω′(θ) between the attributions for all features in the dataset. Gradients, IG
weighted by λ′ on the parameters: and our method (EG) discussed below are all designed for calculat-
ing such attributions. The CDEP method discussed above differs in
θ = argminθ L(θ;X, y) + λ′ Ω′ (θ ). that it penalizes the attributions of a single pre-specified group of
features14; while CDEP has reported better performance with cer-
Attribution priors involve a model’s attributions, represented by tain types of prior than EG and gradients, we believe this is due to
the matrix Φ(θ, X), where each entry ϕℓi is the importance of fea- the fact that the methods are inherently best suited to different types
ture i in the model’s output for sample ℓ. The attribution prior is a of prior. Using CDEP with the specific priors proposed in this work
scalar-valued penalty function of the feature attributions Ω(Φ(θ, X)), would require several orders of magnitude more backward passes of
which represents a log-transformed prior probability distribution the model during training than our approach. CDEP also uses addi-
over possible attributions (λ is the regularization strength). The tional preprocessing steps that are not necessary in our approach,
attribution prior is modular and agnostic to the particular attribu- which further distinguishes the scenarios in which each method is
tion method. This results in the optimization: most applicable. For further discussion of related work, including
a discussion of specific cases for which our method and CDEP are
θ = argminθ L(θ;X, y) + λΩ(Φ(θ, X)), best suited, see Supplementary Sections A and B.

where the standard regularization term has simply been replaced EG outperforms other attribution methods. Attribution priors
with an arbitrary, differentiable penalty function on the feature involve using feature attributions not just as a post-hoc analy-
attributions. sis method but also as a key part of the training objective. Thus,
While feature attributions have previously been used in train- it is essential to guarantee as much as possible that the attribution
ing (more details in ‘Previous attribution priors’ in Methods)9,12, method used will produce high-quality attributions and run fast
our approach offers two novel components. First, we demonstrate enough to be calculated for each training batch. We propose an

Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell 621


Articles NaTuRe MacHine InTelligence

a Post-hoc explanation During training

x1

x2

N training samples
Single Fixed Input
reference interpolation sample
x3
IG

xn

Random Random Input


references interpolation sample x1

x2

N training samples
x3
x
EG

xn

b Trombone Gradients IG EG

c Digits Gradients IG EG

Fig. 1 | EG is a feature attribution method designed to be regularized during training. a, A comparison of our method, EG, to IG as both a post-hoc explanation
method (left), and as a differentiable feature attribution to be penalized during training to enforce attribution priors (right). b, Comparison of saliency maps
generated by three different attribution methods on an image from the ImageNet dataset. The saliency maps demonstrate how the IG attribution method fails
to highlight black pixels as important when black is used as a baseline input, while EG is capable of highlighting the black pixels in these images as important.
c, Comparison of saliency maps for the same three attribution methods for two MNIST digits. Again, IG fails to highlight potentially relevant image regions
(like the empty middle of the 0 or the empty region at the top of the 4, which might make the digit resemble a 9 if it were filled in).

axiomatic feature attribution method called expected gradients output for a given sample) and implementation invariance (the
(EG), which avoids problems with existing methods and is natu- attributions are identical for any of the infinite possible implemen-
rally suited to being incorporated into training. EG extends the IG tations of the same function). Because these methods satisfy com-
method2, and like IG, satisfies a variety of desirable interpretability pleteness, they are not subject to the problems with input saturation
axioms such as completeness (the feature attributions sum to the that affect gradient attributions. Because these methods satisfy

622 Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell


NaTuRe MacHine InTelligence Articles
of the required background reference x′. For example, in image
Table 1 | Synthetic data benchmark results for attribution tasks, the image of all zeros is often chosen as a baseline, but doing
methods so implies that black pixels will not be highlighted as important
(Fig. 1b,c). This problem can be solved by integrating gradients over
Method Remove Remove Remove Convergence
multiple references. However, calculating multiple Riemann inte-
positive negative absolute time
grals is expensive in terms of time and memory, probably prohibi-
EG 3.612 3.759 0.897 0.150 tively so if calculated during every batch of training (Fig. 1a, right).
IG 3.539 3.687 0.872 0.989 EG naturally accommodates multiple references by performing the
Gradients 0.035 0.110 0.729 0.250 Monte Carlo integral with samples from multiple references and
interpolation points (here, x is the sample, x′ is a reference and D is
Random −0.053 0.034 0.400 –
the reference distribution):
Larger numbers mean a better feature attribution method for all metrics other than convergence
time, for which a smaller number indicates faster convergence. The first three metrics measure
∂f(x′ + α × (x − x′ ))
the quality of the method for correctly identifying important features, whereas convergence time EGi (x) = Ex′ ∼D,α∼U(0,1) [ (xi − x′ ) × ]
indicates how effectively the method is regularized during training as an attribution prior. The ∂xi
‘remove positive’ metric measures the average magnitude change in model output when the
features identified as having the largest positive impact by each method are masked by the feature In principle, any distribution D over reference samples could be
mean, whereas ‘remove negative’ measures the average magnitude change in model output when
the features identified as having the largest negative impact by each method are masked by the used to calculate EG attributions; choosing which distribution to
feature mean. The ‘remove absolute’ metric measures the average increase in model loss when the use depends on the nature of the attribution problem. For exam-
features identified as having the largest magnitude impact on the model are masked by the feature ple, setting D to be a single sample recovers single-reference EG:
mean. Each model is trained on 900 samples and tested using 100 samples. EG attains the best
benchmark scores of all of the tested attribution methods (P = 7.2 × 10−5, one-tailed binomial test,
the same reference setup as IG but with the Monte Carlo speedup
tested across all 18 attribution performance metrics, see Supplementary Section D for details on of EG (Supplementary Section D.1). By default, we do not choose
exact calculation of these metrics and exhaustive list of metrics considered). D to be a single sample but rather a uniform distribution over the
entire training set. This tells us which features cause x’s output to be
different from the output at all other points in the dataset, on aver-
implementation invariance, they are straightforward to practically age. In certain cases, we may want to use a different distribution D.
apply to any differentiable model, regardless of specific network For example, we might want to distinguish between subgroups and
architectures (see ‘Expected gradients’ in Methods for an extended understand why a digit is classified as a ‘seven’ rather than a ‘one’
discussion of the interpretability axioms satisfied by EG). by choosing references only from the ‘one’-labelled training sam-
IG generates feature attributions by integrating the gradients of ples. We could also account for baseline subgroup characteristics by
the model’s output f between the sample of interest and a reference explaining, for example, an 80-year-old patient’s mortality risk rela-
sample x′ (Fig. 1a, left). tive to other 80 year olds; this could prevent age and age-correlated
∫ 1 features from being trivially listed as the most important. While our
∂f(x′ + α(x − x′ )) formulation and implementation of EG support any choice of distri-
IGi (x) := dα
α =0 ∂xi bution D, the examples in this paper do not focus on subgroup anal-
ysis, so we set D to be a uniform distribution over the training set
where ∂ represents a partial derivative and α represents progress (see ‘Expected gradients’ in Methods and Supplementary Section C
along the integration path. If the attribution function Φ in our for implementation details and pseudocode).
attribution prior Ω(Φ(θ, X)) is IG, regularizing Φ would require In a simple experiment using synthetic data to assess the impact
hundreds of extra gradient calls every training step (the original of k on the convergence time of model training (rather than the
IG paper2 recommends 20 to 300 gradient calls to compute attribu- convergence of a single explanation), we found that regularizing
tions). This makes training with IG prohibitively slow—in fact, ref. 12 EG with k = 1 was more effective at removing a model’s depen-
finds that using IG can take up to 30 times longer than standard dency on one of two correlated features than gradients or even IG
training even when only back-propagating gradients through part of with more than k samples (Table 1, Supplementary Section D and
the network. However, most deep learning models today are trained Supplementary Fig. 4). The k = 1 setting also appeared optimal for
using some variant of batch gradient descent, where the gradient EG; setting k > 1 required more total gradient calls for convergence
of a loss function is approximated over many training steps using (Supplementary Section D.1 and Supplementary Fig. 3). We also
mini-batches of data. We can dramatically improve speed over an compare EG to other feature attribution methods using synthetic
IG attribution prior by using a similar idea and formulating the IG data benchmarks introduced in ref. 5 (Table 1), which are available
integral as an expectation over integration path steps α drawn from as part of the SHAP software package. These benchmark metrics
a uniform distribution U (see Table 1 and Supplementary Section evaluate whether each attribution method finds the most important
D.1 for more details on convergence time benchmark). This Monte features for a given dataset and model. EG significantly outperforms
Carlo estimate of the integral is the core of our EG method, defined the next best feature attribution method (P = 7.2 × 10−5, one-tailed
below for a single reference x′: binomial test). We believe this demonstrates another benefit of
EG; by averaging attributions over multiple reference samples,
∂f(x′ + α × (x − x′ )) it becomes more robust to the wide array of patterns of missing-
SingleRefEGi (x) = Eα∼U(0,1) [ (xi − x′ ) × ]
∂xi ness and re-imputation tested in the benchmark. We provide more
details and additional benchmarks in Supplementary Section D.
Just like the gradient of the loss, EG attributions can be calcu-
lated in a batched manner during training (Fig. 1a, right). We let k A pixel attribution prior improves robustness to image noise.
be the number of samples we draw for this Monte Carlo integral at Previous work on interpreting image models has focused on creating
each mini-batch. Remarkably, because the variance in each batched pixel attribution maps, which assign a value to each pixel indicating
EG attribution will be smoothed over thousands of batches dur- how important that pixel was for a model’s prediction2,19. Attribution
ing training, we find that as small as k = 1 suffices to regularize the maps can be noisy and difficult to understand due to their tendency
explanations. to highlight seemingly unimportant background pixels, indicating
This expectation formulation also enables us to solve a long- the model may be vulnerable to adversarial attacks20. Although we
standing problem with IG as an attribution method—the choice may prefer a model with smoother attributions, existing methods

Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell 623


Articles NaTuRe MacHine InTelligence

a Pixel
b
1.0
Original attribution Method
image Baseline prior 0.9 Pixel attribution prior
(total variation of EG)
0.8 Total variation of gradients
Baseline (no prior)
0.7

Test accuracy
0.6

0.5

0.4

0.3

0.2

0.1

0
0 0.2 0.4 0.6 0.8 1.0 1.2 1.4 1.6 1.8 2.0

Standard deviation of Gaussian noise

Fig. 2 | Pixel attribution prior improves saliency map smoothness and increases robustness of MNIST classifier to noise. a, EG attributions (from 100
samples) on MNIST for both an unregularized model and a model trained with an attribution prior regularized using EG. The latter achieves visually
smoother attributions, and it better highlights how the network classifies digits (for example, the top part of the 4 being very important). Unlike previous
methods that take additional steps to smooth saliency maps after training21,22, these are unmodified saliency maps directly from the learned model.
b, Training with an attribution prior on total variance of EG attributions induces robustness to Gaussian noise without specifically training for robustness.
This robustness greatly exceeds that provided by an attribution prior on the total variance of model gradients. Shaded bars around each line indicate
standard deviation of the accuracy results; however, the bars are small enough to be indistinguishable in this plot.

only post-process attribution maps but do not change model method, we trained five models with different random initializa-
behaviour19,21,22. Such techniques may not be faithful to the original tions. In Figs. 2 and 3, we plot the mean and standard deviation of
model11. In this section, we describe how we applied our framework test accuracy on MNIST and CIFAR-10, respectively, as a function
to train image models with naturally smoother attributions. of standard deviation of added Gaussian noise. The figures show
To regularize pixel-level attributions, we used the following intu- that our regularized model is more robust to noise than both the
ition: neighbouring pixels should have a similar impact on an image baseline and gradient-based models.
model’s output. To encode this intuition, we chose a total variation Both the robustness and more intuitive saliency maps our
loss on pixel-level attributions (see ‘Specific priors’ in Methods for method provides come at the cost of reduced test set accuracy
more detail). We applied this pixel smoothness attribution prior (0.93 ± 0.002 for the baseline versus 0.85 ± 0.003 for pixel attribu-
to the Modified National Institute of Standards and Technology tion prior model on CIFAR-10). Mathematically, adding a penalty
(MNIST) database, containing handwritten digits classified from 0 term to the optimization objective should only ever reduce train-
to 9, and the Canadian Institute for Advanced Research (CIFAR)-10 ing set performance; it is reasonable that in many cases this can
dataset, containing colour images classified into 10 categories such lead to a reduction in test-set performance as well. However, test
as cats, dogs and cars15,23. On MNIST we trained a two-layer convo- accuracy is not the only metric of interest for image classifiers.
lutional neural network; for CIFAR-10 we trained a VGG16 network The trade-off between robustness and accuracy that we observe
from scratch (see ‘Image model experimental settings’ in Methods is consistent with previous work that suggests image classifi-
for more details)24. In both cases, we optimized hyperparameters for ers trained solely to maximize test accuracy rely on features that
the baseline model without an attribution prior. To choose λ, we are brittle and difficult to interpret11,26,27. Despite this trade-off,
searched over values in [10−20, 10−1] and chose the λ that minimized we find that at a stricter hyperparameter cutoff for λ on CIFAR-
the attribution prior penalty and achieved a test accuracy within 1% 10—within 1% test accuracy of the baseline, rather than 10%—our
of the baseline model for MNIST and 10% for CIFAR-10. Figures 2 methods still achieve modest but significant robustness relative
and 3 show EG attribution maps for both the baseline and the model to the baseline. We also evaluated our method against several
regularized with an attribution prior on five randomly selected test other attribution priors including IG and, for ablation purposes,
images on MNIST and CIFAR-10, respectively. In all examples, the single-reference EG (Supplementary Figs. 10 and 11). We found
attribution prior yields a model with visually smoother attributions. that the pixel attribution prior outperformed standard IG and
Remarkably, in many instances, smoother attributions better high- that most of this additional performance was due to our random
light the target object’s structure. interpolation. Both the pixel attribution prior and single-reference
Recent work has suggested that image classifiers are brittle EG were much more robust than all other methods; however, only
to small domain shifts: small changes in the underlying distribu- the pixel attribution prior, which used multiple references, could
tion of the training and test set can lead to large reductions in test highlight important foreground and background regions in addi-
accuracy25. To simulate a domain shift, we applied Gaussian noise tion to providing robustness and smoothness. For details of the EG
to images in the test set and re-evaluated the performance of the versus IG comparison, results at different hyperparameter thresh-
regularized and baseline models. As an adaptation of ref. 9, we also olds, more details on our training procedure and additional experi-
compared the attribution prior model with regularizing the total ments on MNIST, CIFAR-10 and ImageNet, see Supplementary
variation of gradients with the same criteria for choosing λ. For each Sections E–H.

624 Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell


NaTuRe MacHine InTelligence Articles
a Pixel
b 1.0
Original attribution Method
image Baseline prior 0.9 Pixel attribution prior
(total variation of EG)
0.8 Total variation of gradients
Dog Baseline (no prior)
0.7

Test accuracy
0.6

Cat 0.5

0.4

0.3
Frog
0.2

0.1

Deer 0
0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Ship

Standard deviation of Gaussian noise

Fig. 3 | Pixel attribution prior improves saliency map smoothness and increases robustness of CIFAR-10 classifier to noise. a, EG attributions (from 100
samples) on CIFAR-10 for both the baseline model and the model trained with an attribution prior for five randomly selected images classified correctly by
both models. Training with an attribution prior generates visually smoother attribution maps in all cases. Notably, these smoothed attributions also appear
more localized towards the object of interest. b, Training with an attribution prior on total variance of EG attributions induces robustness to Gaussian noise,
achieving more than double the accuracy of the baseline at high noise levels. This robustness is not achievable by choosing total variation of gradients as
the attribution function. Shaded bars around each line indicate standard deviation of the accuracy results.

A graph attribution prior improves anticancer drug response model’s gradients, rather than a penalty on the axiomatically correct
prediction. In the image domain, our attribution prior took the expected gradient feature attribution, does not perform significantly
form of a penalty encouraging smoothness over adjacent pixels. In better than a baseline neural network. We also observe substantially
other domains, there may be prior information about specific rela- improved test performance when using the prior graph information
tionships between features that can be encoded as a graph (such as to regularize a linear LASSO model. Finally, we note that our graph
social networks, knowledge graphs or protein–protein interactions). attribution prior neural network significantly outperforms graph
For example, previous work in bioinformatics has shown that pro- convolutional neural networks, a recent method for utilizing graph
tein–protein interaction networks contain valuable information for information in deep neural networks31.
improving performance on biological prediction tasks28. Therefore, To find out whether our model’s attributions match biological
in this domain, we regularized attributions to be smooth over the domain knowledge, we first compared the list of top genes gener-
protein–protein feature graph analogously to the regular graph of ated by our network trained with a graph attribution prior (ranked
pixels in the image. by mean absolute feature attribution) to a ‘ground truth’ list of
Incorporating the Ωgraph attribution prior not only led to a model AML-relevant genes found by querying the GeneCards database
with more reasonable attributions but also improved predictive per- (Fig. 4b). When we count the number of AML-relevant genes at
formance by letting us incorporate prior biological knowledge into each position in our network’s top gene list and compare this to the
the training process. We downloaded publicly available gene expres- number of AML-relevant genes at each position in a standard neu-
sion and drug response data for patients with acute myeloid leu- ral network’s top gene list, we see that the graph attribution prior
kaemia (AML, a type of blood cancer) and tried to predict patients’ network captures significantly more biologically relevant genes.
drug response from their gene expression29. For this regression In addition, to check for biological pathway-level enrich-
task, an input sample was a patient’s gene expression profile plus ments, we conducted gene set enrichment analysis (a modified
a one-hot encoded vector indicating which drug was tested in that Kolmogorov–Smirnov test). We measured whether our top genes,
patient, while the label we tried to predict was drug response (mea- ranked by mean absolute feature attribution, were enriched for
sured by IC50, a continuous value representing the concentration of membership in any pathways (see ‘Biological experiments’ in
the drug required to kill half of the patient’s tumour cells). To define Methods and Supplementary Section I for more detail, including
the graph used by our prior, we downloaded the tissue-specific the top pathways for each model)32. We find that the neural net-
gene-interaction graph for the tissue most closely related to AML in work with the tissue-specific graph attribution prior captures far
the HumanBase database30. more biologically relevant pathways (increased number of signifi-
A two-layer neural network trained with our graph attribu- cant pathways after false discovery rate correction) than a neural
tion prior (Ωgraph) significantly outperforms all other methods in network without attribution priors33. Furthermore, the pathways
terms of test set performance as measured by R2, which indicates our model uses more closely match biological expert knowledge,
the fraction of the variance in the output explained by the model that is, they included prognostically useful AML gene expression
(Fig. 4, see ‘Biological experiments’ in Methods for significance profiles as well as important AML-related transcription factors
testing). Unsurprisingly, when we replace the biological graph (Supplementary Section I)34,35. These results are expected, given that
from HumanBase with a randomized graph, we find that the test neural networks trained without priors can learn a relatively sparse
performance is no better than the performance of a neural network basis of genes that will not enrich for specific pathways (for exam-
trained without any attribution prior. Extending the method pro- ple, a single gene from each correlated pathway), while those trained
posed in ref. 9 by applying our new graph prior as a penalty on the with our graph prior will spread credit among functionally related

Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell 625


Articles NaTuRe MacHine InTelligence

a b
350
0.40
300

0.39
250
With graph

AML-related genes
0.38 attribution prior
200
R2

0.37
150
Standard
neural network
0.36 100

0.35 50

0.34 0
Graph Graph prior Random Graph Neural Graph LASSO 200 400 600 800 1,000 1,200 1,400 1,600 1,800
attribution with attribution convolution network LASSO
Model top genes
prior gradients prior
Existing methods

Fig. 4 | Graph attribution prior improves test accuracy and biological relevance of anticancer drug response prediction model. a, A neural network
trained with our graph attribution prior (bold) attains the best test performance, while one trained with the same graph penalty on the gradients (italics,
adapted from ref. 9) does not perform significantly better than a standard neural network (error bars indicate the extent of the bootstrapped 95%
confidence interval of the mean test set R2 value, over ten retrainings of the model on random re-splits of the data). b, A neural network trained with our
graph attribution prior gives more weight to AML-relevant genes than a standard neural network trained without the graph attribution prior (solid line
indicates average over ten random re-splits of the data and retrainings of the model, error bands indicate the extent of the bootstrapped 95%
confidence interval).

genes. This demonstrates the graph prior’s value as an accurate and be trained with very few labelled patient samples or reduce cost by
efficient way to encourage neural networks to treat functionally accurately risk-stratifying patients using few lab tests. We randomly
related genes similarly. sampled training and validation sets of only 100 patients each, plac-
ing all other patients in the test set, and ran each experiment 200
A sparsity prior improves performance with limited training times with a new random sample to average out variance. We built
data. Feature selection and sparsity are popular ways to alleviate three-layer binary classifier neural networks regularized using L1,
the curse of dimensionality, facilitate interpretability and improve SGL and sparse attribution prior penalties to predict patient sur-
generalization by building models that use a small number of input vival, as well as an L1 penalty on gradients adapted for global spar-
features. A straightforward way to build a sparse deep model is to sity from refs. 9,38. The regularization strength was tuned from 10−7
apply an L1 penalty to the first layer (and possibly subsequent lay- to 105 using the validation set for all methods (see ‘Sparsity experi-
ers) of the network. Similarly, the sparse group lasso (SGL) method ments’ in Methods and Supplementary Section J.2).
penalizes all weights connected to a given feature36,37, while a simple The sparse attribution prior enables more accurate test predic-
existing attribution prior approach38 penalizes the gradients of each tions (Fig. 5a) and sparser models (Fig. 5c) when limited training
feature in the model. data is available, with P < 10−4 and t ≥ 4.314 by paired-samples t-test
These approaches suffer from two problems. First, a feature for all comparisons. We also plot the average cumulative impor-
with small gradients or first-layer weights may still strongly tance of sorted features and find that the sparse attribution prior
affect the model’s output39. A feature whose attribution value more effectively concentrates importance in the top few features
(for example, IG or EG) is zero is much less likely to have any (Fig. 5d). In particular, we observe that L1 penalizing the model’s
effect on predictions. Second, successfully minimizing penalties gradients as in ref. 38 rather than its EG attributions performs poorly
such as L1—regardless of attribution type—is not necessarily the in terms of both sparsity and performance. A Gini penalty on gra-
best way to create a sparse model. A model that puts weight w on dients improves sparsity but does not outperform other baselines
1 feature is penalized more than one that puts weight w/2p on each such as SGL and L1 in area under a receiver operating characteristic
of p features. Previous work on sparse linear regression has shown curve (ROC AUC). Finally, we plot the average sparsity of the mod-
that the Gini coefficient G of the weights, proportional to 0.5 minus els (Gini coefficient) against their validation ROC AUC across the
the area under the cumulative distribution function of sorted val- full range of regularization strengths. The sparse attribution prior
ues, avoids such problems and corresponds more directly to a sparse exhibits higher sparsity than other models and a smooth trade-off
model40,41. We extend this analysis to deep models by noting that the between sparsity and ROC AUC (Fig. 5b). Details and results for
Gini coefficient can be written differentiably and used as an attribu- other penalties, including L2, dropout and other attribution priors,
tion prior. are in Supplementary Section J.
Here we show that the Ωsparse attribution prior can build sparser
models that perform better in settings with limited training data. Discussion
We use a publicly available healthcare mortality prediction dataset The immense popularity of deep learning has driven its application
of 13,000 patients42, whose 35 features (118 after one-hot encod- in many areas with diverse, complicated domain knowledge. While
ing) represent medical data such as a patient’s age, vital signs and it is in principle possible to hand-design network architectures to
laboratory measurements. The binary outcome is survival after ten encode this knowledge, a more practical approach involves the
years. Sparse models in this setting may enable accurate models to use of attribution priors, which penalize the importance a model

626 Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell


NaTuRe MacHine InTelligence Articles
a b
0.775 0.80

Validation ROC AUC


ROC AUC (average)

0.750 0.75
0.725 0.70
0.700 Smooth sparsity–
0.65 AUC trade-off
0.675
0.60
0.650
0.625 0.55
Sparse L1: SGL: SGL: Gini: L1: grad Unreg 0.5 0.6 0.7 0.8 0.9
attribution all 1st all grad (ref. 9)
Sparsity
prior
c d
0.85
1.0
Sparsity (Gini coefficient,

Sparse attribution prior SGL: 1st

Cumulative fraction of
0.80

feature importance
0.8 Gini: grad L1: grad
0.75
SGL: all Unreg
average)

0.70 0.6
L1: all
0.65 0.4
0.60
0.2
0.55
0
0.50
Sparse L1: SGL: SGL: Gini: L1: grad Unreg 0 20 40 60 80 100 120
attribution all 1st all grad
Features by increasing importance
prior

Fig. 5 | Sparse attribution prior builds sparser and more accurate healthcare mortality models. a,c, A sparse attribution prior enables more accurate test
predictions (a) and sparser models (c) across 200 small subsampled datasets (100 training and 100 validation samples, all other samples used for test
set) than other penalties, including gradients. b, Across the full range of tuned parameters, the sparse attribution prior achieves the greatest sparsity and
a smooth sparsity–validation performance trade-off. d, A sparse attribution prior concentrates a larger fraction of global feature importance in the top few
features. ‘Gini’, ‘L1’ and ‘SGL’ indicate the Gini, L1 and SGL penalties, respectively, ‘grad’ indicates a penalty on the gradients, ‘all’ indicates a penalty on all
weights in the model and ‘1st’ indicates a penalty on only the first weight layer. ‘Unreg’ indicates an unregularized model.

places on each of its input features when making predictions. attribution priors: as tools to achieve domain-specific goals without
Unfortunately, previous attribution priors have been limited, both sacrificing efficiency.
theoretically and computationally. Binary penalties only specify
whether features should or should not be important and fail to Methods
capture relationships among features. Approaches that focus only Previous attribution priors. The first instance of what we now call an attribution
on a model’s input gradients change the local decision boundary prior was proposed in ref. 9, where the regularization term was modified to place a
constant penalty on the gradients of undesirable features:
but often fail to impact a model’s underlying decision-making.
Attribution priors on more complicated attributions, such as IG, θ = argminθ L(θ;X, y) + λ′′ ||A ⊙
∂L 2
|| .
have proven computationally difficult. ∂X F
Our work advances previous work both by introducing novel, Here the attribution method is the gradients of the model, represented by the
flexible attribution priors for multiple domains and by enabling matrix ∂∂X
L
whose ℓ, ith entry is the gradient of the loss at the ℓth sample with
the training of such priors with a newly defined feature attribution respect to the ith feature. A is a binary matrix indicating that features should be
method. Our priors lead to smoother and more interpretable image penalized in which samples, and F is the Frobenius norm.
A more general interpretation of attribution priors is that any function of any
models, biological predictive models that incorporate graph-based
feature attribution method could be used to penalize a loss function, thus encoding
prior knowledge and sparser healthcare models that perform better prior knowledge about what properties the attributions of a model should have.
in data-scarce scenarios. Our attribution method not only enables For some model parameters θ, let Φ(θ, X) be a feature attribution method, which
the training of said priors but also outperforms its predecessor— is a function of θ and the data X. Let ϕℓi be the feature importance of feature i
IG—in terms of reliably identifying the features models use to make in sample ℓ. We formally define an attribution prior as a scalar-valued penalty
function of the feature attributions Ω(Φ(θ, X)), which represents a log-transformed
predictions. prior probability distribution over possible attributions:
There remain many avenues for future work in this area. We
chose to base our prior on an improved version of IG because it is θ = argminθ L(θ;X, y) + λΩ(Φ(θ, X)),
the most prominent differentiable feature attribution method we are
where λ is the regularization strength. Note that the attribution prior function Ω is
aware of, but a wide array of other attribution methods exist. Our agnostic to the attribution method Φ.
framework makes it straightforward to substitute any other attribu- Previous attribution priors9,12 required specifying an exact target value for
tion method as long as it is differentiable, and studying the effec- the model’s attributions, but often we do not know in advance which features are
tiveness of other attribution methods as priors would be valuable. important in advance. In general, there is no requirement that Φ(θ, X) constrain
In addition, while we develop new, more sophisticated attribution attributions to particular values. The ‘Results’ section presented three newly
developed attribution priors for different tasks that improve performance without
priors and show their value, there is ample room to improve on our requiring pre-specified attribution targets for any particular feature.
priors and evaluate entirely new ones for other tasks. Determining
the best attribution priors for particular tasks opens a further ave- Expected gradients. EG is an extension of IG2 with fewer hyperparameter
nue of research. We believe that surveys of domain experts to estab- choices. Like several other attribution methods, IG aims to explain the difference
lish model desiderata for particular applications will help to develop between a model’s current prediction and the prediction that the model would
make when given a baseline input. This baseline input is meant to represent
the best priors for any given situation while offering a valuable some uninformative reference input that represents not knowing the value of the
opportunity to put humans in the loop. Overall, the dual advances input features. Although choosing such an input is necessary for several feature
of sophisticated attribution priors and EG enable a broader view of attribution methods2,39,43, the choice is often made arbitrarily. For example, for

Nature Machine Intelligence | VOL 3 | July 2021 | 620–631 | www.nature.com/natmachintell 627


Articles NaTuRe MacHine InTelligence
image tasks, the image of all zeros is often chosen as a baseline, but doing so An additive regularization term is equivalent to adding a multiplicative
implies that black pixels will not be highlighted as important by existing feature (independent) prior to yield a maximum a posteriori (MAP) estimate:
attribution methods. In many domains, it is not clear how to choose a baseline that
correctly represents a lack of information. argminθ ||fθ (X) − y||22 + λ||θ ||22
Our method avoids an arbitrary choice of baseline; it models not knowing
= argmaxθ exp (−||fθ (X) − y||22 ) exp (− λ||θ ||22 ) = θ MAP ,
the value of a feature by integrating over a dataset. For a model f, the IG value for
feature i is defined as:
Here, adding an L2 penalty is equivalent to MAP for Y = fθ (X) + N (0, σ ) with
∫ 1 a N (0, 1λ ) prior. We next discuss the functional form of the attribution priors
∂f(x′ + α(x − x′ ))
IGi (x, x′ ) := (xi − x′ ) × dα, enforced by our penalties.
α =0 ∂xi

where x is the target input and x′ is baseline input. To avoid specifying x′, we define Pixel attribution prior. Our pixel attribution prior is based on the anisotropic total
the EG value for feature i as: variation loss and is given as follows:
∫ ∫ 1 ∑∑ ℓ
∂f(x′ + α(x − x′ )) Ωpixel (Φ(θ, X)) = |ϕi+1,j − ϕi,j | + |ϕi,j+1 − ϕi,j |,
ℓ ℓ ℓ
EGi (x) := ( (xi − x ) ×

dα ) pD (x′ )dx′ ,
x ′ α =0 ∂xi ℓ i,j

where D is the underlying data distribution. Since EG is also a diagonal path where ϕi,j is the attribution for the i, jth pixel in the ℓ-th training image.

method, it satisfies the same axioms as IG44. Directly integrating over the training Research shows46 that this penalty is equivalent to placing zero-mean, i.i.d.,
distribution is intractable; therefore, we instead reformulate the integrals as Laplace-distributed priors on the differences between adjacent pixel values, that is,
expectations: ϕℓi+1,j − ϕℓi,j ≈ Laplace (0, λ−1 ) and ϕℓi,j+1 − ϕℓi,j ≈ Laplace (0, λ−1 ). Reference 46
∂f(x′ + α × (x − x′ )) does not call our penalty ‘total variation’, but it is in fact the widely used anisotropic
EGi (x) := Ex′ ∼D,α∼U(0,1) [ (xi − x′ ) × ]. version of total variation and is directly implemented in Tensorflow47–49.
∂xi
This expectation-based formulation lends itself to a natural, sampling based Graph attribution prior. For our graph attribution prior, we used a protein–
approximation method: (1) draw samples of x′ from the training dataset and α protein or gene–gene interaction network and represented these networks as
from U(0, 1), (2) compute the value inside the expectation for each sample and a weighted, undirected graph. Formally, assume we have a weighted adjacency
(3) average over samples. For a pseudocode description of EG, see Supplementary matrix W ∈ Rp+×p for an undirected graph, where the entries encode our prior
Section C. belief about the pairwise similarity of the importances between two features. For
EG also satisfies a set of important interpretability axioms: implementation a biological network, Wi,j encodes either the probability or strength of interaction
invariance, sensitivity, completeness, linearity and symmetry preserving. between the ith and jth genes (or proteins). We encouraged similarity along graph
• Implementation invariance states that two networks with outputs that are edges by penalizing the squared Euclidean distance between each pair of feature
equal over all inputs should have equivalent attributions. Any attribution attributions in proportion to how similar we believe them to be. Using the graph
method based on the gradients of a network will satisfy this axiom2, meaning Laplacian (LG = D − W), where D is the diagonal degree matrix of the weighted
that IG, EG and gradients will all be implementation invariant. graph, this becomes:
• Sensitivity (sometimes called dummy) states that when a model does not ∑
Ωgraph (Φ(θ, X)) = Wi,j (ϕ̄i − ϕ̄j )2 = ϕ̄T LG ϕ̄.
depend on a feature at all, it receives zero importance. IG, EG and gradients all
i,j
satisfy sensitivity because the gradient with respect to an irrelevant feature will
be zero everywhere. In this case, we choose to penalize global rather than local feature attributions.
• Completeness states that the attributions should sum to the difference between We define ϕ̄i to be the importance of feature i across all samples in our dataset,
the output of a function at the input to be explained and the output of that where this global attribution is calculated
∑ as the average magnitude of the feature
function at a baseline. Gradients do not satisfy completeness due to saturation attribution across all samples: ϕ̄i = n1 nℓ=1 |ϕℓi |. Just as the image penalty is
at the inputs; elements such as ReLUs may cause gradients to be zero, making equivalent to placing a Laplace prior on adjacent pixels in a regular graph, the
completeness impossible2. IG and EG both satisfy completeness due to the graph penalty Ωgraph is equivalent to placing a Gaussian prior on adjacent features
gradient theorem (fundamental theorem of calculus for line integrals)2. For in an arbitrary graph with Laplacian LG (ref. 46).
EG, the function being integrated is the expectation of the model’s output, so
completeness means that the attributions sum to the difference between the Sparse attribution prior. Our sparsity prior uses the Gini coefficient G as a penalty,
model’s output for the input and the model’s output averaged over all possible which is written:
baselines.
p ∑
∑ p
• Linearity states that for a model that is a linear combination of two submodels |ϕ̄i − ϕ̄j |
f(x) = af1(x) + bf2(x), where a and b are arbitrary scalars, the attributions are a i= 1 j = 1
linear combination of the submodels’ attributions ϕ(x) = aϕ1(x) + bϕ2(x). This Ωsparse (Φ(θ, X)) = − p
= −2G(ϕ̄),

will hold for IG, EG and gradients because gradients are linear. n ϕ̄i
• Symmetry preserving states that symmetric variables with identical values i=1

should achieve identical attributions. IG is symmetry preserving since it is a By taking exponentials of this function, we find that minimizing the sparsity
straight line path method, and EG will also be symmetry preserving, as a sym- regularizer is equivalent to maximizing likelihood under a prior proportional to the
metric function of symmetric functions will itself be symmetrical2. following:
Unlike previous attribution methods, EG is explicitly designed for natural  
batched training. This enables an order of magnitude increase in computational p p  1 
∏ ∏  
efficiency relative to previous approaches for training with attribution priors. We exp  p |ϕ̄i − ϕ̄j | ,
further improve performance by reducing the need for additional data reading. ∑ 
Specifically, for each input in a batch of inputs, we need k additional inputs to
i=1 j=1 ϕ̄i
i= 1
calculate EG attributions for that input batch. As long as k is smaller than the batch
size, we can avoid any additional data reading by re-using the same batch of input To our knowledge, this prior does not directly correspond to a named distribution.
data as a reference batch, as in ref. 45. We accomplish this by shifting the batch of However, we observe that its maximum value occurs when one ϕ̄i is 1 and all
input k times, such that each input in the batch uses k other inputs from the batch others are 0, and that its minimum occurs when all ϕ̄i are equal. This is similar
as its reference values. to the total variation penalty Ωimage, but it is normalized and has a flipped sign to
encourage differences. The corresponding attribution prior is maximized when
Specific priors. Here we elaborate on the explicit form of the attribution priors global attributions are zero for all but one feature and minimized when attributions
we used in this paper. In general, minimizing the error of a model corresponds to are uniform across features.
maximizing the likelihood of the data under a generative model consisting of the
learned model plus parametric noise. For example, minimizing mean squared error Image model experimental settings. We trained a VGG16 model from scratch
in a regression task corresponds to maximizing the likelihood of the data under the modified for the CIFAR-10 dataset, containing 60,000 coloured 32 × 32-pixel
learned model, assuming Gaussian-distributed errors: images divided into 10 categories, as in ref. 50. To train this network, we used
stochastic gradient descent with an initial learning rate of 0.1 and an exponential
Image model experimental settings. We trained a VGG16 model from scratch modified for the CIFAR-10 dataset, containing 60,000 coloured 32 × 32-pixel images divided into 10 categories, as in ref. 50. To train this network, we used stochastic gradient descent with an initial learning rate of 0.1 and an exponential decay of 0.5 applied every 20 epochs. Additionally, we used a momentum level of 0.9. For augmentation, we shifted each image horizontally and vertically by a pixel shift uniformly drawn from the range [−3, 3], and we randomly rotated each image by an angle uniformly drawn from the range [−15, 15]. We used a batch size of 128.
Before training, we normalized the training dataset to have zero mean and unit variance, and standardized the test set with the mean and variance of the training set. We used k = 1 background reference samples for our attribution prior while training. When training with attributions over images, we first normalized the per-pixel attribution maps by dividing by the standard deviation before computing the total variation; otherwise, the total variation can be made arbitrarily small without changing model predictions by scaling the pixel attributions down close to zero. See Supplementary Section F for more details.
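For reference, a minimal sketch of a total variation penalty of this kind, applied to standard-deviation-normalized attribution maps, is shown below (PyTorch, anisotropic form, with our own names; the exact variant used for Ωimage may differ, see refs 48 and 49 for isotropic and anisotropic formulations).

import torch

def image_attribution_penalty(phi, eps=1e-8):
    # phi: per-pixel EG attributions of shape [batch, height, width].
    phi = phi / (phi.std() + eps)                            # normalize so the penalty cannot be shrunk by scaling attributions towards zero
    dh = torch.abs(phi[:, 1:, :] - phi[:, :-1, :]).sum()     # differences between vertically adjacent pixels
    dw = torch.abs(phi[:, :, 1:] - phi[:, :, :-1]).sum()     # differences between horizontally adjacent pixels
    return (dh + dw) / phi.shape[0]                          # average total variation over the batch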
We repeated the same experiment as above on MNIST, which contains 60,000 black-and-white 28 × 28-pixel images of handwritten digits. We trained a convolutional neural network with two convolutional layers and a single hidden layer. The convolutional layers each had 5 × 5 filters, a stride length of 1, and 32 and 64 filters total. Each convolutional layer was followed by a max pooling layer of size 2 with stride length 2. The hidden layer had 1,024 units and a dropout rate of 0.5 during training51. Dropout was turned off when calculating the gradients with respect to the attributions. We trained with the Adam optimizer with the default parameters (learning rate α = 0.001, gradient average decay rate β1 = 0.9, squared gradient average decay rate β2 = 0.999, and numerical stability constant ϵ = 10−8)52. We trained with an initial learning rate of 0.0001, with an exponential decay of 0.95 for every epoch, for a total of 60 epochs. For all models, we trained with a batch size of 50 images and used k = 1 background reference sample per attribution while training. See Supplementary Section G for more details.

Biological experiments. Significance testing of results. To test the difference in R2 attained by each method, we used a t-test for the means of two independent samples of scores (as implemented in SciPy)53. This is a two-sided test and can be applied to R2 because R2 is a linear transformation of mean squared error, which satisfies normality assumptions by the central limit theorem. When we compare the R2 attained from ten independent retrainings of the neural network to the R2 attained from ten independent retrainings of the attribution prior model, we find that predictive performance is significantly higher for the model with the graph attribution prior (t statistic = 3.59, P = 2.06 × 10−3).
To ensure that the increased performance in the attribution prior model was
due to real biological information, we replaced the gene-interaction graph with a Graph convolutional networks. We followed the implementation of graph
randomized graph (symmetric matrix with identical number of non-zero entries convolution described in ref. 31. The architectures were searched as follows: in
to the real graph, but entries placed in random positions). We then compared the every network, we first had a single graph convolutional layer (we were limited to
R2 attained from ten independent retrainings of a neural network with no graph one graph convolution layer due to memory constraints on each Nvidia GTX 1080
attribution prior to ten independent retrainings of an neural network regularized Ti GPU that we used), followed by two fully connected layers of sizes (512, 256),
with the random graph and found that test error was not significantly different (512, 128) or (256, 128). We tuned over a wide range of hyperparameters, including
between these two models (t statistic = 1.25, P = 0.23). We also compared to graph L2 penalties on the weights ranging from 10−5 to 10−2, L1 penalties on the weights
convolutional neural networks, and found that our network with a graph attribution ranging from 10−5 to 10−2, learning rates of 10−5 to 10−3 and dropout rates ranging
prior outperformed the graph convolutional neural network (t statistic = 3.30, from 0.2 to 0.8. We found the optimal hyperparameters based on validation set
P = 4.0 × 10−3). Finally, we compared to an L2 penalty applied uniformly across error were two hidden layers of size 512 and size 256, an L2 penalty on the weights
all attributions, and found that this attribution prior did not significantly increase of 10−5, a learning rate of 10−5 and a dropout rate of 0.6. We again used an early
performance from baseline (t statistic = 1.7, P = 0.12, see Supplementary Fig. 15). stopping parameter and found that 47 epochs was the optimal number.

Train/validation/test-set allocation. To increase the number of samples in our Sparsity experiments. Data description and processing. Our sparsity experiments
dataset, we used as a feature the identity of the drug being tested, rather than one used data from the National Health and Nutrition Examination I Survey
of a number of possible output tasks in a multi-task prediction. This follows from (NHANES I)42 and contained 35 variables (expanded to 118 features by
previous literature on training neural networks to predict drug response54. This one-hot encoding of categorical variables) gathered from 13,000 patients. The
yielded 30,816 samples (covering 218 patients and 145 anticancer drugs). Defining measurements included demographic information such as age, sex and BMI as well
a sample as a drug and a patient, however, meant we had to choose carefully how as physiological measurements such as blood, urine and vital sign measurements.
to stratify samples into our train, validation and test sets. While it is perfectly The prediction task was a binary classification of whether the patient was still alive
legitimate in general to randomly stratify samples into these sets, we wanted to (1) or not (0) ten years after data were gathered.
specifically focus on how well our model could learn trends from gene expression Data were mean-imputed and standardized so that each feature had zero
data that would generalize to new patients. Therefore, we stratified samples at a mean and unit variance. For each of the 200 experimental replicates, 100 train and
patient level rather than at the level of individual samples (for example, no samples 100 validation points were sampled uniformly at random; all other points were
from any patient in the test set ever appeared in the training set). We split 20% of allocated to the test set.
the total patients into a test set (6,155 samples) and then split 20% of the training
data into a validation set for hyperparameter selection (4,709 samples). Model. We trained a range of neural networks to predict survival in the NHANES
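One way to implement this patient-level split (not necessarily the authors' own code) is with a grouped splitter, as in the scikit-learn sketch below; the arrays are hypothetical stand-ins for the study's data.

import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical stand-ins: X holds gene expression plus drug identity, y holds drug
# response, and patient_ids maps each (patient, drug) sample to its patient.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 50))
y = rng.normal(size=1000)
patient_ids = rng.integers(0, 100, size=1000)

outer = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_val_idx, test_idx = next(outer.split(X, y, groups=patient_ids))      # hold out 20% of patients for testing

inner = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, val_idx = next(inner.split(X[train_val_idx], y[train_val_idx],
                                      groups=patient_ids[train_val_idx]))  # 20% of remaining patients for validation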
Model class implementations and hyperparameters tested. LASSO. We used the scikit-learn implementation of the LASSO55,56. We tested a range of α parameters from 10−9 to 1, and we found that the optimal value for α was 10−2 by mean squared error on the validation set.

Graph LASSO. For our graph LASSO, we used the Adam optimizer in TensorFlow47, with a learning rate of 10−5, to optimize the following loss function:

L(w; X, y) = ∥Xw − y∥₂² + λ′∥w∥₁ + ν′wᵀLGw,  (1)

where w ∈ R^d is the weights vector of our linear model and LG is the graph Laplacian of our HumanBase network30. In particular, we downloaded the 'Top Edges' version of the haematopoietic stem cell network, which was thresholded to only have non-zero values for pairwise interactions that had a posterior probability greater than 0.1. We used the value of λ′ selected as optimal in the regular LASSO model (10−2, which corresponds to the α parameter in scikit-learn) and then tuned over ν′ values ranging from 10−3 to 100. We found that a value of 10 was optimal according to MSE on the validation set.
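A minimal sketch of this optimization in TensorFlow 2 is shown below. The data here are synthetic stand-ins (in the paper, X is the expression/drug feature matrix, y the drug response and L_G the HumanBase graph Laplacian), and the variable names are our own.

import tensorflow as tf

# Synthetic stand-ins for X, y and the graph Laplacian L_G.
n, d = 100, 20
X = tf.random.normal([n, d])
y = tf.random.normal([n])
A = tf.abs(tf.random.normal([d, d]))
A = (A + tf.transpose(A)) / 2.0                                   # stand-in symmetric adjacency matrix
L_G = tf.linalg.diag(tf.reduce_sum(A, axis=1)) - A                # graph Laplacian: degree matrix minus adjacency

w = tf.Variable(tf.zeros([d]))
lam, nu = 1e-2, 10.0                                              # lambda' and nu' as selected on the validation set
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)

def loss_fn():
    residual = tf.linalg.matvec(X, w) - y
    return (tf.reduce_sum(residual ** 2)                          # squared-error term
            + lam * tf.reduce_sum(tf.abs(w))                      # L1 penalty on the weights
            + nu * tf.tensordot(w, tf.linalg.matvec(L_G, w), 1))  # graph Laplacian penalty w^T L_G w

for _ in range(1000):
    opt.minimize(loss_fn, var_list=[w])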
Neural networks. We tested a variety of hyperparameter settings and network architectures via validation set performance to choose our best neural networks, including the following feed-forward network architectures (where each element in a list denotes the size of a hidden layer): [512, 256], [256, 128], [256, 256] and [1,000, 100]. We tested a range of L1 penalties on all of the weights of the network, from 10−7 to 10−2. All models attempted to optimize a least squares loss using the Adam optimizer, with learning rates again selected by hyperparameter tuning, ranging from 10−5 to 10−3. Finally, we implemented an early stopping parameter of 20 rounds to select the number of epochs of training (training was stopped after no improvement on validation error for 20 epochs, and the number of epochs was chosen based on optimal validation set error). We found that the optimal architecture (chosen by lowest validation set error) had two hidden layers of size 512 and 256, an L1 penalty on the weights of 10−3 and a learning rate of 10−5. We additionally found that 120 was the optimal number of training epochs.

Attribution prior neural networks. We next applied our attribution prior to the neural networks. First, we tuned networks to the optimal conditions described above. We then added extra epochs of fine-tuning during which we ran an alternating minimization of the following objectives:

L(θ; X, y) = ∥fθ(X) − y∥₂² + λ∥θ∥₁  (2)

L(θ; X) = Ωgraph(Φ(θ, X)) = νϕ̄ᵀLGϕ̄  (3)

Following ref. 9, we selected ν to be 100 so that the Ωgraph term would initially be equal in magnitude to the least squares and L1 loss terms. We found that five extra epochs of tuning were optimal by validation set error. We drew k = 10 background samples for our attributions. To test our attribution prior using gradients as the feature attribution method (rather than expected gradients), we followed the exact same procedure, except that we used the average magnitude of the gradients, rather than the EG attributions, as ϕ̄.
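A sketch of one fine-tuning epoch of this alternating scheme is shown below (PyTorch, with our own naming, not the released implementation; expected_gradients_batch refers to the earlier sketch and L_G is the graph Laplacian as a tensor). For each batch, one optimizer step is taken on the prediction objective in equation (2) and one on the attribution objective in equation (3).

import torch

def fine_tune_epoch(model, optimizer, loader, L_G, lam=1e-3, nu=100.0, k=10):
    # Alternating minimization of equations (2) and (3); a sketch under stated assumptions.
    for x, y in loader:
        # Step on equation (2): squared error plus an L1 penalty on the weights.
        optimizer.zero_grad()
        prediction = model(x).squeeze(-1)
        l1 = sum(p.abs().sum() for p in model.parameters())
        loss_pred = ((prediction - y) ** 2).sum() + lam * l1
        loss_pred.backward()
        optimizer.step()

        # Step on equation (3): graph attribution prior on the global attributions.
        optimizer.zero_grad()
        phi = expected_gradients_batch(model, x, k=k, create_graph=True)   # differentiable EG attributions
        phi_bar = phi.abs().mean(dim=0)                                    # mean |attribution| per feature
        loss_attr = nu * phi_bar @ (L_G @ phi_bar)                         # nu * phi_bar^T L_G phi_bar
        loss_attr.backward()
        optimizer.step()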
Graph convolutional networks. We followed the implementation of graph convolution described in ref. 31. The architectures were searched as follows: in every network, we first had a single graph convolutional layer (we were limited to one graph convolution layer due to memory constraints on each Nvidia GTX 1080 Ti GPU that we used), followed by two fully connected layers of sizes (512, 256), (512, 128) or (256, 128). We tuned over a wide range of hyperparameters, including L2 penalties on the weights ranging from 10−5 to 10−2, L1 penalties on the weights ranging from 10−5 to 10−2, learning rates of 10−5 to 10−3 and dropout rates ranging from 0.2 to 0.8. We found the optimal hyperparameters based on validation set error were two hidden layers of size 512 and size 256, an L2 penalty on the weights of 10−5, a learning rate of 10−5 and a dropout rate of 0.6. We again used an early stopping parameter and found that 47 epochs was the optimal number.

Sparsity experiments. Data description and processing. Our sparsity experiments used data from the National Health and Nutrition Examination I Survey (NHANES I)42 and contained 35 variables (expanded to 118 features by one-hot encoding of categorical variables) gathered from 13,000 patients. The measurements included demographic information such as age, sex and BMI, as well as physiological measurements such as blood, urine and vital sign measurements. The prediction task was a binary classification of whether the patient was still alive (1) or not (0) ten years after the data were gathered.

Data were mean-imputed and standardized so that each feature had zero mean and unit variance. For each of the 200 experimental replicates, 100 train and 100 validation points were sampled uniformly at random; all other points were allocated to the test set.

Model. We trained a range of neural networks to predict survival in the NHANES data. The architecture, nonlinearities and training rounds were all held constant at values that performed well on an unregularized network, and the type and degree of regularization were varied. All models used ReLU activations and a single output with binary cross-entropy loss; in addition, all models ran for 100 epochs with a stochastic gradient descent optimizer with learning rate 0.001 on the size-100 training data. The entire 100-sample training set fit in one batch. Because the training set was so small, all of its 100 samples were used for EG attributions during training and evaluation, yielding k = 100. Each model was trained on a single GPU on a desktop workstation with 4 Nvidia 1080 Ti GPUs.

Architecture. We considered a range of architectures, including single-hidden-layer 32-node, 128-node and 512-node networks, two-layer [128, 32]-node and [512, 128]-node networks, and a three-layer [512, 128, 32]-node network; we fixed the [512, 128, 32] architecture for future experiments.

Regularizers. We tested a large array of regularizers in addition to those considered in the main text. For details, see Supplementary Section J.1.

Hyperparameter tuning. We selected the hyperparameters for our models based on validation performance. We searched all L1, L2, SGL and attribution prior penalties with 121 points sampled on a log scale over [10−7, 105] (Supplementary Fig. 18).
Other penalties, not displayed in the main text experiments, are discussed in Supplementary Section J.2.

Main text methods. Performance and sparsity bar plots. The performance bar graph (Fig. 5a) was generated by plotting the mean test ROC AUC of the best model of each type (chosen by validation ROC AUC), averaged over each of the 200 subsampled datasets, with confidence intervals given by 2 times the standard error over the 200 replicates. The sparsity bar graph (Fig. 5c) was constructed using the same process, but with Gini coefficients rather than ROC AUCs.

Feature importance distribution plot. The distribution of feature importances was plotted in the main text as a Lorenz curve (Fig. 5, bottom right): for each model, the features were sorted by global attribution value ϕ̄i, and the cumulative normalized value of the lowest q features was plotted, from 0 at q = 0 to 1 at q = p. A lower area under the curve indicates that more features had relatively small attribution values, indicating that the model was sparser. Because 200 replicates were run on small subsampled datasets, the Lorenz curve for each model was plotted using the averaged mean absolute sorted feature importances over all replicates. Thus, for a given model type, the q = 1 point represented the mean absolute feature importance of the least important feature averaged over each replicate, q = 2 added the mean importance of the second least important feature averaged over each replicate, and so on.
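As an illustration, the Lorenz curve and the Gini coefficient used here can be computed from a vector of global attributions as in the NumPy sketch below (our own function names; the exact plotting code for Fig. 5 may differ).

import numpy as np

def lorenz_curve(phi_bar):
    # Cumulative normalized attribution of the lowest q features, for q = 0, ..., p.
    values = np.sort(np.abs(phi_bar))
    cumulative = np.concatenate([[0.0], np.cumsum(values)])
    return cumulative / cumulative[-1]

def gini_coefficient(phi_bar):
    # Gini coefficient of the global attribution vector; higher values indicate a sparser model.
    values = np.abs(phi_bar)
    pairwise = np.abs(values[:, None] - values[None, :]).sum()
    return pairwise / (2 * len(values) * values.sum())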
Performance versus sparsity plot. Validation ROC AUC and model sparsity were calculated for each of the 121 regularization strengths and averaged over each of the 200 replicates. These were plotted on a scatterplot to show the possible range of model sparsities and ROC AUC performances (Fig. 5, top right), as well as the trade-off between sparsity and performance.

Statistical significance. Statistical significance of the sparse attribution prior performance was assessed by comparing the test ROC AUCs of the sparse attribution prior models on each of the 200 subsampled datasets to those of the other models (L1 gradients, L1 weights, SGL and unregularized). Significance was assessed by a two-sided paired-samples t-test, paired by subsampled dataset. The same process was used to calculate the significance of model sparsity as measured by the Gini coefficient. Detailed tables of the resulting P values and test statistics t are shown in Supplementary Section J.3.

Data availability
The data for all experiments and figures in the paper are publicly available. A downloadable version of the dataset used for the sparsity experiment, as well as links to download the datasets used in the image and graph prior experiments, is available at https://github.com/suinleelab/attributionpriors. Data for the benchmarks were published as part of ref. 57 and can be accessed at https://github.com/suinleelab/treeexplainer-study/tree/master/benchmark.

Code availability
Implementations of attribution priors for Tensorflow and PyTorch are available at https://github.com/suinleelab/attributionpriors. This repository also contains code reproducing main results from the paper. The specific version of code used in this paper is archived at ref. 58.

Received: 23 August 2020; Accepted: 12 April 2021; Published online: 31 May 2021

References
1. Lundberg, S. M. & Lee, S.-I. A unified approach to interpreting model predictions. In Advances in Neural Information Processing Systems Vol. 30, 4765–4774 (NeurIPS, 2017).
2. Sundararajan, M., Taly, A. & Yan, Q. Axiomatic attribution for deep networks. In Proc. 34th International Conference on Machine Learning Vol. 70, 3319–3328 (Journal of Machine Learning Research, 2017).
3. Štrumbelj, E. & Kononenko, I. Explaining prediction models and individual predictions with feature contributions. Knowl. Inf. Syst. 41, 647–665 (2014).
4. Datta, A., Sen, S. & Zick, Y. Algorithmic transparency via quantitative input influence: theory and experiments with learning systems. In 2016 IEEE Symposium on Security and Privacy (SP) 598–617 (IEEE, 2016).
5. Lundberg, S. M. et al. From local explanations to global understanding with explainable AI for trees. Nat. Mach. Intell. 2, 56–67 (2020).
6. Lundberg, S. M. et al. Explainable machine-learning predictions for the prevention of hypoxaemia during surgery. Nat. Biomed. Eng. 2, 749–760 (2018).
7. Sayres, R. et al. Using a deep learning algorithm and integrated gradients explanation to assist grading for diabetic retinopathy. Ophthalmology 126, 552–564 (2019).
8. Zech, J. R. et al. Variable generalization performance of a deep learning model to detect pneumonia in chest radiographs: a cross-sectional study. PLoS Med. 15, e1002683 (2018).
9. Ross, A. S., Hughes, M. C. & Doshi-Velez, F. Right for the right reasons: training differentiable models by constraining their explanations. In Proc. 26th International Joint Conference on Artificial Intelligence 2662–2670 (IJCAI, 2017).
10. Schramowski, P. et al. Making deep neural networks right for the right scientific reasons by interacting with their explanations. Nat. Mach. Intell. 2, 476–486 (2020).
11. Ilyas, A. et al. Adversarial examples are not bugs, they are features. In Advances in Neural Information Processing Systems Vol. 32 (NeurIPS, 2019).
12. Liu, F. & Avci, B. Incorporating priors with feature attribution on text classification. In Proc. of the 57th Annual Meeting of the Association for Computational Linguistics (ACL) 6274–6283 (2019).
13. Chen, J., Wu, X., Rastogi, V., Liang, Y. & Jha, S. Robust attribution regularization. In Advances in Neural Information Processing Systems Vol. 32 (NeurIPS, 2019).
14. Rieger, L., Singh, C., Murdoch, W. J. & Yu, B. Interpretations are useful: penalizing explanations to align neural networks with prior knowledge. In Proc. 37th International Conference on Machine Learning (eds. Daumé III, H. & Singh, A.) 8116–8126 (ICML, 2020).
15. LeCun, Y., Cortes, C. & Burges, C. MNIST Handwritten Digit Database (AT&T Labs) http://yann.lecun.com/exdb/mnist (2010).
16. Yu, F., Xu, Z., Wang, Y., Liu, C. & Chen, X. Towards robust training of neural networks by regularizing adversarial gradients. Preprint at https://arxiv.org/abs/1805.09370 (2018).
17. Jakubovitz, D. & Giryes, R. Improving DNN robustness to adversarial attacks using Jacobian regularization. In Proc. European Conference on Computer Vision (ECCV) (eds. Ferrari, V., Hebert, M., Sminchisescu, C. & Weiss, Y.) 514–529 (ECCV, 2018).
18. Roth, K., Lucchi, A., Nowozin, S. & Hofmann, T. Adversarially robust training through structured gradient regularization. Preprint at https://arxiv.org/abs/1805.08736 (2018).
19. Selvaraju, R. R. et al. Grad-CAM: visual explanations from deep networks via gradient-based localization. In Proc. IEEE International Conference on Computer Vision 618–626 (IEEE, 2017).
20. Ross, A. S. & Doshi-Velez, F. Improving the adversarial robustness and interpretability of deep neural networks by regularizing their input gradients. In Thirty-second AAAI Conference on Artificial Intelligence Vol. 32, 1 (AAAI, 2018).
21. Smilkov, D., Thorat, N., Kim, B., Viégas, F. & Wattenberg, M. Smoothgrad: removing noise by adding noise. Preprint at https://arxiv.org/abs/1706.03825 (2017).
22. Fong, R. C. & Vedaldi, A. Interpretable explanations of black boxes by meaningful perturbation. In Proc. IEEE International Conference on Computer Vision 3429–3437 (IEEE, 2017).
23. Krizhevsky, A. et al. Learning Multiple Layers of Features from Tiny Images Technical Report (Citeseer, 2009).
24. Simonyan, K. & Zisserman, A. Very deep convolutional networks for large-scale image recognition. In 3rd International Conference on Learning Representations (eds. Bengio, Y. & LeCun, Y.) (ICLR, 2015).
25. Recht, B., Roelofs, R., Schmidt, L. & Shankar, V. Do ImageNet classifiers generalize to ImageNet? Proc. of the 36th International Conference on Machine Learning Vol. 97, 5389–5400 (PMLR, 2019).
26. Tsipras, D., Santurkar, S., Engstrom, L., Turner, A. & Madry, A. Robustness may be at odds with accuracy. In 7th International Conference on Learning Representations (ICLR, 2019).
27. Zhang, H. et al. Theoretically principled trade-off between robustness and accuracy. In Proc. 36th International Conference on Machine Learning Vol. 97, 7472–7482 (PMLR, 2019).
28. Cheng, W., Zhang, X., Guo, Z., Shi, Y. & Wang, W. Graph-regularized dual Lasso for robust eQTL mapping. Bioinformatics 30, i139–i148 (2014).
29. Tyner, J. W. et al. Functional genomic landscape of acute myeloid leukaemia. Nature 562, 526–531 (2018).
30. Greene, C. S. et al. Understanding multicellular function and disease with human tissue-specific networks. Nat. Genet. 47, 569–576 (2015).
31. Kipf, T. N. & Welling, M. Semi-supervised classification with graph convolutional networks. In 5th International Conference on Learning Representations (ICLR, 2017).
32. Subramanian, A. et al. Gene set enrichment analysis: a knowledge-based approach for interpreting genome-wide expression profiles. Proc. Natl Acad. Sci. USA 102, 15545–15550 (2005).
33. Benjamini, Y. & Hochberg, Y. Controlling the false discovery rate: a practical and powerful approach to multiple testing. J. R. Stat. Soc. B 57, 289–300 (1995).
34. Liu, J. et al. Meis1 is critical to the maintenance of human acute myeloid leukemia cells independent of MLL rearrangements. Ann. Hematol. 96, 567–574 (2017).
35. Valk, P. J. M. et al. Prognostically useful gene-expression profiles in acute myeloid leukemia. N. Engl. J. Med. 350, 1617–1628 (2004).
36. Feng, J. & Simon, N. Sparse-input neural networks for high-dimensional nonparametric regression and classification. Preprint at https://arxiv.org/abs/1711.07592 (2017).
37. Scardapane, S., Comminiello, D., Hussain, A. & Uncini, A. Group sparse regularization for deep neural networks. Neurocomputing 241, 81–89 (2017).
38. Ross, A., Lage, I. & Doshi-Velez, F. The neural lasso: local linear sparsity for interpretable explanations. In Workshop on Transparent and Interpretable Machine Learning in Safety Critical Environments, 31st Conference on Neural Information Processing Systems (2017).
39. Shrikumar, A., Greenside, P. & Kundaje, A. Learning important features through propagating activation differences. In Proc. 34th International Conference on Machine Learning Vol. 70, 3145–3153 (Journal of Machine Learning Research, 2017).
40. Hurley, N. & Rickard, S. Comparing measures of sparsity. IEEE Trans. Inf. Theory 55, 4723–4741 (2009).
41. Zonoobi, D., Kassim, A. A. & Venkatesh, Y. V. Gini index as sparsity measure for signal reconstruction from compressive samples. IEEE J. Sel. Top. Signal Process. 5, 927–932 (2011).
42. Miller, H. W. Plan and Operation of the Health and Nutrition Examination Survey, United States, 1971–1973 DHEW publication no. 79-55071 (PHS) (Department of Health, Education, and Welfare, 1973).
43. Binder, A., Montavon, G., Lapuschkin, S., Müller, K.-R. & Samek, W. Layer-wise relevance propagation for neural networks with local renormalization layers. In International Conference on Artificial Neural Networks (eds. Villa, A. E. P., Masulli, P. & Rivero, A. J. P.) 63–71 (Springer, 2016).
44. Friedman, E. J. Paths and consistency in additive cost sharing. Int. J. Game Theory 32, 501–518 (2004).
45. Zhang, H., Cisse, M., Dauphin, Y. N. & Lopez-Paz, D. mixup: beyond empirical risk minimization. In 6th International Conference on Learning Representations (ICLR, 2018).
46. Bardsley, J. M. Laplace-distributed increments, the Laplace prior, and edge-preserving regularization. J. Inverse Ill Posed Probl. 20, 271–285 (2012).
47. Abadi, M. et al. Tensorflow: a system for large-scale machine learning. In 12th USENIX Symposium on Operating Systems Design and Implementation (OSDI '16) 265–283 (2016).
48. Lou, Y., Zeng, T., Osher, S. & Xin, J. A weighted difference of anisotropic and isotropic total variation model for image processing. SIAM J. Imaging Sci. 8, 1798–1823 (2015).
49. Shi, Y. & Chang, Q. Efficient algorithm for isotropic and anisotropic total variation deblurring and denoising. J. Appl. Math. 2013, 797239 (2013).
50. Liu, S. & Deng, W. Very deep convolutional neural network based image classification using small training sample size. In 2015 3rd IAPR Asian Conference on Pattern Recognition (ACPR) 730–734 (IEEE, 2015).
51. Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I. & Salakhutdinov, R. Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15, 1929–1958 (2014).
52. Kingma, D. P. & Ba, J. Adam: a method for stochastic optimization. In 3rd International Conference on Learning Representations (eds. Bengio, Y. & LeCun, Y.) (ICLR, 2015).
53. Virtanen, P. et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nat. Methods 17, 261–272 (2020).
54. Preuer, K. et al. DeepSynergy: predicting anti-cancer drug synergy with deep learning. Bioinformatics 34, 1538–1546 (2018).
55. Tibshirani, R. Regression shrinkage and selection via the Lasso. J. R. Stat. Soc. B 58, 267–288 (1996).
56. Pedregosa, F. et al. Scikit-learn: machine learning in Python. J. Mach. Learn. Res. 12, 2825–2830 (2011).
57. Lundberg, S. M. et al. Explainable AI for trees: from local explanations to global understanding. Preprint at https://arxiv.org/abs/1905.04610 (2019).
58. Sturmfels, P., Erion, G. & Janizek, J. D. suinleelab/attributionpriors: Nature Machine Intelligence code. Zenodo https://doi.org/10.5281/zenodo.4608599 (2021).

Acknowledgements
The results published here are partially based on data generated by the Cancer Target Discovery and Development (CTD2) Network (https://ocg.cancer.gov/programs/ctd2/data-portal) established by the National Cancer Institute's Office of Cancer Genomics. The authors received funding from the National Science Foundation (DBI-1759487 (S.-I.L.), DBI-1552309 (J.D.J., G.E., S.-I.L.), DGE-1256082 (S.M.L.)); American Cancer Society (RSG-14-257-01-TBG (J.D.J., P.S., S.-I.L.)); and National Institutes of Health (R01AG061132 (J.D.J., P.S., S.-I.L.), R35GM128638 (G.E., S.-I.L.), F30HL151074-01 (G.E., S.-I.L.), 5T32GM007266-46 (J.D.J., G.E.)).

Author contributions
G.E., J.D.J., P.S. and S.M.L. conceived the study. G.E., J.D.J. and P.S. designed algorithms and experiments. P.S. and J.D.J. implemented core libraries for the research. G.E., J.D.J. and P.S. wrote code for and ran the experiments, plotted figures and contributed to the writing. S.M.L. contributed to the writing. S.-I.L. supervised research and method development, and contributed to the writing.

Competing interests
The authors declare no competing interests.

Additional information
Supplementary information The online version contains supplementary material available at https://doi.org/10.1038/s42256-021-00343-w.
Correspondence and requests for materials should be addressed to S.-I.L.
Peer review information Nature Machine Intelligence thanks Ronny Luss, Andrew Ross and the other, anonymous, reviewer(s) for their contribution to the peer review of this work.
Reprints and permissions information is available at www.nature.com/reprints.
Publisher's note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
© The Author(s), under exclusive licence to Springer Nature Limited 2021