Improving Performance of Deep Learning Models
https://doi.org/10.1038/s42256-021-00343-w
Recent research has demonstrated that feature attribution methods for deep networks can themselves be incorporated into
training; these attribution priors optimize for a model whose attributions have certain desirable properties—most frequently,
that particular features are important or unimportant. These attribution priors are often based on attribution methods that
are not guaranteed to satisfy desirable interpretability axioms, such as completeness and implementation invariance. Here we
introduce attribution priors to optimize for higher-level properties of explanations, such as smoothness and sparsity, enabled
by a fast new attribution method formulation called expected gradients that satisfies many important interpretability axioms.
This improves model performance on many real-world tasks where previous attribution priors fail. Our experiments show that
the gains from combining higher-level attribution priors with expected gradients attributions are consistent across image, gene
expression and healthcare datasets. We believe that this work motivates and provides the necessary tools to support the widespread adoption of axiomatic attribution priors in many areas of applied machine learning. The implementations and our results
have been made freely available to academic communities.
Recent work on interpreting machine learning (ML) models has focused on feature attribution methods. Given an input datum, a model and a prediction, such methods assign a number to each input feature that represents how important the feature was for making the prediction. Current research also investigates the axioms that attribution methods should satisfy1–4 and how they provide insight into model behaviour5–8. Feature attribution methods often reveal problems in a model or dataset. For example, a model may place too much importance on undesirable features, rely on many features when sparsity is desired or be sensitive to high-frequency noise. In such cases, humans often have a prior belief about how a model should treat input features but find it difficult to mathematically encode this prior for neural networks in terms of the model parameters.

One method to address such problems is what we call an attribution prior: if it is possible for explanations to reveal problems in a model, then constraining the model's explanations during training can help the model avoid such problems. It is worth noting that the vast majority of feature attribution methods focus exclusively on explaining why a given prediction was made. Only a very small number of papers have investigated incorporating attributions themselves into model training. The first such paper, by Ross et al.9, used a binary indicator of whether each feature should or should not be important for making predictions on each sample in the dataset and penalized the gradients of unimportant features. A very recent publication successfully used the gradient-based prior of Ross et al. as part of a human-in-the-loop strategy to improve model generalization performance and user trust, as well as contributing their own model-agnostic method for penalizing feature importances10. Such results create a clear synergy with our study, which improves the quality of calculated feature importances and develops new forms of attribution priors. This has the potential to greatly expand both the number of ways that a human-in-the-loop can influence deep models and the precision with which they can do so. However, two drawbacks limit this method's applicability to real-world problems. First, gradients do not satisfy the same theoretical guarantees as modern feature attribution methods. This leads to well-known problems such as saturation: operations, such as rectified linear units (ReLUs) and sigmoids, which have large flat 'saturated' regions, can lead to zero gradient attribution even for important features2. Second, it can be difficult to specify which features should be important in a binary manner.

Additional recent work discusses the need for priors that incorporate human intuition to develop robust and interpretable models11. Still, it remains challenging to encode priors such as 'have smoother attributions across an image' or 'treat this group of features similarly' by penalizing a model's input gradients or parameters. Some recent attribution priors have proposed regularizing integrated gradients (IG) attributions12,13. While promising, this work suffers from three major weaknesses: it does not clearly demonstrate improvements over gradient-based attribution priors, it penalizes attribution deviation from a target value rather than encoding sophisticated priors such as those we mention above, and it imposes a large computational cost by training with tens to hundreds of reference samples per batch. A contemporary method called contextual decomposition explanation penalization (CDEP) uses a framework similar to attribution priors and penalizes explanations generated by the contextual decomposition method14. Unlike all other interpretability methods discussed in this paper, CDEP penalizes explanations for pre-specified groups of features, meaning it is best suited for a different set of problems than we consider. More discussion of CDEP can be found in 'Attribution priors are a flexible framework for encoding domain knowledge' and Supplementary Sections A and B.
1Paul G. Allen School of Computer Science and Engineering, University of Washington, Seattle, WA, USA. 2Medical Scientist Training Program, University of Washington, Seattle, WA, USA. 3Microsoft Research, Seattle, WA, USA. 4These authors contributed equally: Gabriel Erion, Joseph D. Janizek, Pascal Sturmfels. ✉e-mail: [email protected]
where the standard regularization term has simply been replaced with an arbitrary, differentiable penalty function on the feature attributions.
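For illustration, the following minimal TensorFlow sketch (not the released implementation) shows one way such an objective can be optimized: a task loss plus λ times a penalty Ω applied to per-batch feature attributions. The names `attribution_fn` and `omega` are hypothetical stand-ins for any differentiable attribution method and any attribution prior penalty.

```python
import tensorflow as tf

def attribution_prior_step(model, optimizer, x_batch, y_batch,
                           attribution_fn, omega, lam=0.1):
    """One training step with an attribution prior: task loss + lam * Omega(attributions).

    attribution_fn(model, x_batch) must be differentiable with respect to the model
    weights (for example, computed with a nested tf.GradientTape), and omega maps a
    batch of attributions to a scalar penalty.
    """
    with tf.GradientTape() as tape:
        preds = model(x_batch, training=True)
        task_loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y_batch, preds))
        phi = attribution_fn(model, x_batch)   # per-feature attributions for this batch
        loss = task_loss + lam * omega(phi)    # attribution prior penalty term
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```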
While feature attributions have previously been used in training (more details in 'Previous attribution priors' in Methods)9,12, our approach offers two novel components. First, we demonstrate

[Fig. 1 graphics: a, schematic comparing IG (interpolation to a single fixed reference input) with EG (references drawn from N training samples); b, trombone image saliency maps for gradients, IG and EG; c, MNIST digit saliency maps for gradients, IG and EG.]

Fig. 1 | EG is a feature attribution method designed to be regularized during training. a, A comparison of our method, EG, to IG as both a post-hoc explanation method (left), and as a differentiable feature attribution to be penalized during training to enforce attribution priors (right). b, Comparison of saliency maps generated by three different attribution methods on an image from the ImageNet dataset. The saliency maps demonstrate how the IG attribution method fails to highlight black pixels as important when black is used as a baseline input, while EG is capable of highlighting the black pixels in these images as important. c, Comparison of saliency maps for the same three attribution methods for two MNIST digits. Again, IG fails to highlight potentially relevant image regions (like the empty middle of the 0 or the empty region at the top of the 4, which might make the digit resemble a 9 if it were filled in).

EG outperforms other attribution methods. Attribution priors involve using feature attributions not just as a post-hoc analysis method but also as a key part of the training objective. Thus, it is essential to guarantee as much as possible that the attribution method used will produce high-quality attributions and run fast enough to be calculated for each training batch. We propose an axiomatic feature attribution method called expected gradients (EG), which avoids problems with existing methods and is naturally suited to being incorporated into training. EG extends the IG method2, and like IG, satisfies a variety of desirable interpretability axioms such as completeness (the feature attributions sum to the output for a given sample) and implementation invariance (the attributions are identical for any of the infinite possible implementations of the same function). Because these methods satisfy completeness, they are not subject to the problems with input saturation that affect gradient attributions. Because these methods satisfy
[Fig. 2 graphics: a, original MNIST images with attribution maps for the baseline model and the pixel attribution prior model; b, test accuracy versus the standard deviation of added Gaussian noise (0 to 2.0) for the pixel attribution prior (total variation of EG), total variation of gradients, and baseline (no prior) models.]
Fig. 2 | Pixel attribution prior improves saliency map smoothness and increases robustness of MNIST classifier to noise. a, EG attributions (from 100
samples) on MNIST for both an unregularized model and a model trained with an attribution prior regularized using EG. The latter achieves visually
smoother attributions, and it better highlights how the network classifies digits (for example, the top part of the 4 being very important). Unlike previous
methods that take additional steps to smooth saliency maps after training21,22, these are unmodified saliency maps directly from the learned model.
b, Training with an attribution prior on total variation of EG attributions induces robustness to Gaussian noise without specifically training for robustness. This robustness greatly exceeds that provided by an attribution prior on the total variation of model gradients. Shaded bars around each line indicate
standard deviation of the accuracy results; however, the bars are small enough to be indistinguishable in this plot.
only post-process attribution maps but do not change model behaviour19,21,22. Such techniques may not be faithful to the original model11. In this section, we describe how we applied our framework to train image models with naturally smoother attributions.

To regularize pixel-level attributions, we used the following intuition: neighbouring pixels should have a similar impact on an image model's output. To encode this intuition, we chose a total variation loss on pixel-level attributions (see 'Specific priors' in Methods for more detail). We applied this pixel smoothness attribution prior to the Modified National Institute of Standards and Technology (MNIST) database, containing handwritten digits classified from 0 to 9, and the Canadian Institute for Advanced Research (CIFAR)-10 dataset, containing colour images classified into 10 categories such as cats, dogs and cars15,23. On MNIST we trained a two-layer convolutional neural network; for CIFAR-10 we trained a VGG16 network from scratch (see 'Image model experimental settings' in Methods for more details)24. In both cases, we optimized hyperparameters for the baseline model without an attribution prior. To choose λ, we searched over values in [10−20, 10−1] and chose the λ that minimized the attribution prior penalty and achieved a test accuracy within 1% of the baseline model for MNIST and 10% for CIFAR-10. Figures 2 and 3 show EG attribution maps for both the baseline and the model regularized with an attribution prior on five randomly selected test images on MNIST and CIFAR-10, respectively. In all examples, the attribution prior yields a model with visually smoother attributions. Remarkably, in many instances, smoother attributions better highlight the target object's structure.

Recent work has suggested that image classifiers are brittle to small domain shifts: small changes in the underlying distribution of the training and test set can lead to large reductions in test accuracy25. To simulate a domain shift, we applied Gaussian noise to images in the test set and re-evaluated the performance of the regularized and baseline models. As an adaptation of ref. 9, we also compared the attribution prior model with regularizing the total variation of gradients with the same criteria for choosing λ. For each method, we trained five models with different random initializations. In Figs. 2 and 3, we plot the mean and standard deviation of test accuracy on MNIST and CIFAR-10, respectively, as a function of standard deviation of added Gaussian noise. The figures show that our regularized model is more robust to noise than both the baseline and gradient-based models.

Both the robustness and more intuitive saliency maps our method provides come at the cost of reduced test set accuracy (0.93 ± 0.002 for the baseline versus 0.85 ± 0.003 for the pixel attribution prior model on CIFAR-10). Mathematically, adding a penalty term to the optimization objective should only ever reduce training set performance; it is reasonable that in many cases this can lead to a reduction in test-set performance as well. However, test accuracy is not the only metric of interest for image classifiers. The trade-off between robustness and accuracy that we observe is consistent with previous work that suggests image classifiers trained solely to maximize test accuracy rely on features that are brittle and difficult to interpret11,26,27. Despite this trade-off, we find that at a stricter hyperparameter cutoff for λ on CIFAR-10 (within 1% test accuracy of the baseline, rather than 10%), our methods still achieve modest but significant robustness relative to the baseline. We also evaluated our method against several other attribution priors including IG and, for ablation purposes, single-reference EG (Supplementary Figs. 10 and 11). We found that the pixel attribution prior outperformed standard IG and that most of this additional performance was due to our random interpolation. Both the pixel attribution prior and single-reference EG were much more robust than all other methods; however, only the pixel attribution prior, which used multiple references, could highlight important foreground and background regions in addition to providing robustness and smoothness. For details of the EG versus IG comparison, results at different hyperparameter thresholds, more details on our training procedure and additional experiments on MNIST, CIFAR-10 and ImageNet, see Supplementary Sections E–H.
[Fig. 3 graphics: a, CIFAR-10 test images (cat, frog, deer, ship) with attribution maps for the baseline and attribution prior models; b, test accuracy versus the standard deviation of added Gaussian noise (0 to 1.0).]
Fig. 3 | Pixel attribution prior improves saliency map smoothness and increases robustness of CIFAR-10 classifier to noise. a, EG attributions (from 100
samples) on CIFAR-10 for both the baseline model and the model trained with an attribution prior for five randomly selected images classified correctly by
both models. Training with an attribution prior generates visually smoother attribution maps in all cases. Notably, these smoothed attributions also appear
more localized towards the object of interest. b, Training with an attribution prior on total variation of EG attributions induces robustness to Gaussian noise,
achieving more than double the accuracy of the baseline at high noise levels. This robustness is not achievable by choosing total variation of gradients as
the attribution function. Shaded bars around each line indicate standard deviation of the accuracy results.
A graph attribution prior improves anticancer drug response prediction. In the image domain, our attribution prior took the form of a penalty encouraging smoothness over adjacent pixels. In other domains, there may be prior information about specific relationships between features that can be encoded as a graph (such as social networks, knowledge graphs or protein–protein interactions). For example, previous work in bioinformatics has shown that protein–protein interaction networks contain valuable information for improving performance on biological prediction tasks28. Therefore, in this domain, we regularized attributions to be smooth over the protein–protein feature graph analogously to the regular graph of pixels in the image.

Incorporating the Ωgraph attribution prior not only led to a model with more reasonable attributions but also improved predictive performance by letting us incorporate prior biological knowledge into the training process. We downloaded publicly available gene expression and drug response data for patients with acute myeloid leukaemia (AML, a type of blood cancer) and tried to predict patients' drug response from their gene expression29. For this regression task, an input sample was a patient's gene expression profile plus a one-hot encoded vector indicating which drug was tested in that patient, while the label we tried to predict was drug response (measured by IC50, a continuous value representing the concentration of the drug required to kill half of the patient's tumour cells). To define the graph used by our prior, we downloaded the tissue-specific gene-interaction graph for the tissue most closely related to AML in the HumanBase database30.

A two-layer neural network trained with our graph attribution prior (Ωgraph) significantly outperforms all other methods in terms of test set performance as measured by R2, which indicates the fraction of the variance in the output explained by the model (Fig. 4, see 'Biological experiments' in Methods for significance testing). Unsurprisingly, when we replace the biological graph from HumanBase with a randomized graph, we find that the test performance is no better than the performance of a neural network trained without any attribution prior. Extending the method proposed in ref. 9 by applying our new graph prior as a penalty on the model's gradients, rather than a penalty on the axiomatically correct expected gradient feature attribution, does not perform significantly better than a baseline neural network. We also observe substantially improved test performance when using the prior graph information to regularize a linear LASSO model. Finally, we note that our graph attribution prior neural network significantly outperforms graph convolutional neural networks, a recent method for utilizing graph information in deep neural networks31.

To find out whether our model's attributions match biological domain knowledge, we first compared the list of top genes generated by our network trained with a graph attribution prior (ranked by mean absolute feature attribution) to a 'ground truth' list of AML-relevant genes found by querying the GeneCards database (Fig. 4b). When we count the number of AML-relevant genes at each position in our network's top gene list and compare this to the number of AML-relevant genes at each position in a standard neural network's top gene list, we see that the graph attribution prior network captures significantly more biologically relevant genes.

In addition, to check for biological pathway-level enrichments, we conducted gene set enrichment analysis (a modified Kolmogorov–Smirnov test). We measured whether our top genes, ranked by mean absolute feature attribution, were enriched for membership in any pathways (see 'Biological experiments' in Methods and Supplementary Section I for more detail, including the top pathways for each model)32. We find that the neural network with the tissue-specific graph attribution prior captures far more biologically relevant pathways (increased number of significant pathways after false discovery rate correction) than a neural network without attribution priors33. Furthermore, the pathways our model uses more closely match biological expert knowledge, that is, they included prognostically useful AML gene expression profiles as well as important AML-related transcription factors (Supplementary Section I)34,35. These results are expected, given that neural networks trained without priors can learn a relatively sparse basis of genes that will not enrich for specific pathways (for example, a single gene from each correlated pathway), while those trained with our graph prior will spread credit among functionally related genes.
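The graph penalty described above can be sketched as the quadratic form ϕ̄ᵀLGϕ̄ defined in Methods ('Graph attribution prior'); the snippet below is an illustration rather than the released code, and assumes `attr_batch` holds per-sample attributions and `W` is a symmetric, non-negative gene–gene adjacency matrix.

```python
import tensorflow as tf

def graph_attribution_penalty(attr_batch, W):
    """Smoothness of global attributions over a feature graph: phi_bar^T L_G phi_bar.

    attr_batch has shape (n_samples, p); W is a p x p symmetric adjacency matrix whose
    entries encode prior belief that two genes should receive similar importance.
    """
    phi_bar = tf.reduce_mean(tf.abs(attr_batch), axis=0)     # global importance per feature
    W = tf.cast(W, phi_bar.dtype)
    degree = tf.linalg.diag(tf.reduce_sum(W, axis=1))        # diagonal degree matrix D
    laplacian = degree - W                                   # graph Laplacian L_G = D - W
    return tf.tensordot(phi_bar, tf.linalg.matvec(laplacian, phi_bar), axes=1)
```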
[Fig. 4 graphics: a, test set R2 for the graph attribution prior, graph prior with gradients, random graph attribution prior, graph convolution, neural network, graph LASSO and LASSO models (the last five labelled as existing methods); b, number of AML-related genes among each model's top genes (up to 1,800) for the network with the graph attribution prior versus a standard neural network.]
Fig. 4 | Graph attribution prior improves test accuracy and biological relevance of anticancer drug response prediction model. a, A neural network
trained with our graph attribution prior (bold) attains the best test performance, while one trained with the same graph penalty on the gradients (italics,
adapted from ref. 9) does not perform significantly better than a standard neural network (error bars indicate the extent of the bootstrapped 95%
confidence interval of the mean test set R2 value, over ten retrainings of the model on random re-splits of the data). b, A neural network trained with our
graph attribution prior gives more weight to AML-relevant genes than a standard neural network trained without the graph attribution prior (solid line
indicates average over ten random re-splits of the data and retrainings of the model, error bands indicate the extent of the bootstrapped 95%
confidence interval).
This demonstrates the graph prior's value as an accurate and efficient way to encourage neural networks to treat functionally related genes similarly.

A sparsity prior improves performance with limited training data. Feature selection and sparsity are popular ways to alleviate the curse of dimensionality, facilitate interpretability and improve generalization by building models that use a small number of input features. A straightforward way to build a sparse deep model is to apply an L1 penalty to the first layer (and possibly subsequent layers) of the network. Similarly, the sparse group lasso (SGL) method penalizes all weights connected to a given feature36,37, while a simple existing attribution prior approach38 penalizes the gradients of each feature in the model.

These approaches suffer from two problems. First, a feature with small gradients or first-layer weights may still strongly affect the model's output39. A feature whose attribution value (for example, IG or EG) is zero is much less likely to have any effect on predictions. Second, successfully minimizing penalties such as L1, regardless of attribution type, is not necessarily the best way to create a sparse model. A model that puts weight w on one feature is penalized more than one that puts weight w/(2p) on each of p features. Previous work on sparse linear regression has shown that the Gini coefficient G of the weights, proportional to 0.5 minus the area under the cumulative distribution function of sorted values, avoids such problems and corresponds more directly to a sparse model40,41. We extend this analysis to deep models by noting that the Gini coefficient can be written differentiably and used as an attribution prior.

Here we show that the Ωsparse attribution prior can build sparser models that perform better in settings with limited training data. We use a publicly available healthcare mortality prediction dataset of 13,000 patients42, whose 35 features (118 after one-hot encoding) represent medical data such as a patient's age, vital signs and laboratory measurements. The binary outcome is survival after ten years. Sparse models in this setting may enable accurate models to be trained with very few labelled patient samples or reduce cost by accurately risk-stratifying patients using few lab tests. We randomly sampled training and validation sets of only 100 patients each, placing all other patients in the test set, and ran each experiment 200 times with a new random sample to average out variance. We built three-layer binary classifier neural networks regularized using L1, SGL and sparse attribution prior penalties to predict patient survival, as well as an L1 penalty on gradients adapted for global sparsity from refs. 9,38. The regularization strength was tuned from 10−7 to 105 using the validation set for all methods (see 'Sparsity experiments' in Methods and Supplementary Section J.2).

The sparse attribution prior enables more accurate test predictions (Fig. 5a) and sparser models (Fig. 5c) when limited training data are available, with P < 10−4 and t ≥ 4.314 by paired-samples t-test for all comparisons. We also plot the average cumulative importance of sorted features and find that the sparse attribution prior more effectively concentrates importance in the top few features (Fig. 5d). In particular, we observe that applying an L1 penalty to the model's gradients as in ref. 38, rather than to its EG attributions, performs poorly in terms of both sparsity and performance. A Gini penalty on gradients improves sparsity but does not outperform other baselines such as SGL and L1 in area under a receiver operating characteristic curve (ROC AUC). Finally, we plot the average sparsity of the models (Gini coefficient) against their validation ROC AUC across the full range of regularization strengths. The sparse attribution prior exhibits higher sparsity than other models and a smooth trade-off between sparsity and ROC AUC (Fig. 5b). Details and results for other penalties, including L2, dropout and other attribution priors, are in Supplementary Section J.
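A sketch of the differentiable Gini-style sparsity penalty described above follows; it uses the standard Gini normalization, while the exact scaling used in the paper is given in Methods ('Sparse attribution prior').

```python
import tensorflow as tf

def sparse_attribution_penalty(attr_batch, eps=1e-12):
    """Negative Gini coefficient of global feature attributions (lower = sparser).

    attr_batch has shape (n_samples, p). Global importance is the mean absolute
    attribution per feature; the Gini coefficient is large when importance is
    concentrated on a few features, so minimizing its negative favours sparse models.
    """
    phi_bar = tf.reduce_mean(tf.abs(attr_batch), axis=0)              # shape (p,)
    pairwise = tf.abs(phi_bar[:, None] - phi_bar[None, :])            # |phi_i - phi_j|
    p = tf.cast(tf.shape(phi_bar)[0], phi_bar.dtype)
    gini = tf.reduce_sum(pairwise) / (2.0 * p * tf.reduce_sum(phi_bar) + eps)
    return -2.0 * gini
```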
[Fig. 5 graphics: a, test ROC AUC and c, sparsity (Gini coefficient, averaged) for the sparse attribution prior, L1 (all weights), SGL (first layer), SGL (all weights), Gini penalty on gradients, L1 penalty on gradients (ref. 9) and unregularized models; b, sparsity versus validation ROC AUC, showing a smooth sparsity–AUC trade-off; d, cumulative fraction of feature importance over the 118 features sorted by increasing importance.]
Fig. 5 | Sparse attribution prior builds sparser and more accurate healthcare mortality models. a,c, A sparse attribution prior enables more accurate test
predictions (a) and sparser models (c) across 200 small subsampled datasets (100 training and 100 validation samples, all other samples used for test
set) than other penalties, including gradients. b, Across the full range of tuned parameters, the sparse attribution prior achieves the greatest sparsity and
a smooth sparsity–validation performance trade-off. d, A sparse attribution prior concentrates a larger fraction of global feature importance in the top few
features. ‘Gini’, ‘L1’ and ‘SGL’ indicate the Gini, L1 and SGL penalties, respectively, ‘grad’ indicates a penalty on the gradients, ‘all’ indicates a penalty on all
weights in the model and ‘1st’ indicates a penalty on only the first weight layer. ‘Unreg’ indicates an unregularized model.
Discussion

The immense popularity of deep learning has driven its application in many areas with diverse, complicated domain knowledge. While it is in principle possible to hand-design network architectures to encode this knowledge, a more practical approach involves the use of attribution priors, which penalize the importance a model places on each of its input features when making predictions. Unfortunately, previous attribution priors have been limited, both theoretically and computationally. Binary penalties only specify whether features should or should not be important and fail to capture relationships among features. Approaches that focus only on a model's input gradients change the local decision boundary but often fail to impact a model's underlying decision-making. Attribution priors on more complicated attributions, such as IG, have proven computationally difficult.

Our work advances previous work both by introducing novel, flexible attribution priors for multiple domains and by enabling the training of such priors with a newly defined feature attribution method. Our priors lead to smoother and more interpretable image models, biological predictive models that incorporate graph-based prior knowledge and sparser healthcare models that perform better in data-scarce scenarios. Our attribution method not only enables the training of said priors but also outperforms its predecessor, IG, in terms of reliably identifying the features models use to make predictions.

There remain many avenues for future work in this area. We chose to base our prior on an improved version of IG because it is the most prominent differentiable feature attribution method we are aware of, but a wide array of other attribution methods exist. Our framework makes it straightforward to substitute any other attribution method as long as it is differentiable, and studying the effectiveness of other attribution methods as priors would be valuable. In addition, while we develop new, more sophisticated attribution priors and show their value, there is ample room to improve on our priors and evaluate entirely new ones for other tasks. Determining the best attribution priors for particular tasks opens a further avenue of research. We believe that surveys of domain experts to establish model desiderata for particular applications will help to develop the best priors for any given situation while offering a valuable opportunity to put humans in the loop. Overall, the dual advances of sophisticated attribution priors and EG enable a broader view of attribution priors: as tools to achieve domain-specific goals without sacrificing efficiency.

Methods

Previous attribution priors. The first instance of what we now call an attribution prior was proposed in ref. 9, where the regularization term was modified to place a constant penalty on the gradients of undesirable features:

\theta = \operatorname{argmin}_{\theta} \, L(\theta; X, y) + \lambda'' \left\lVert A \odot \frac{\partial L}{\partial X} \right\rVert_F^2 .

Here the attribution method is the gradients of the model, represented by the matrix \frac{\partial L}{\partial X}, whose ℓ, ith entry is the gradient of the loss at the ℓth sample with respect to the ith feature. A is a binary matrix indicating which features should be penalized in which samples, and F is the Frobenius norm.

A more general interpretation of attribution priors is that any function of any feature attribution method could be used to penalize a loss function, thus encoding prior knowledge about what properties the attributions of a model should have. For some model parameters θ, let Φ(θ, X) be a feature attribution method, which is a function of θ and the data X. Let ϕℓi be the feature importance of feature i in sample ℓ. We formally define an attribution prior as a scalar-valued penalty function of the feature attributions Ω(Φ(θ, X)), which represents a log-transformed prior probability distribution over possible attributions:

\theta = \operatorname{argmin}_{\theta} \, L(\theta; X, y) + \lambda \, \Omega(\Phi(\theta, X)),

where λ is the regularization strength. Note that the attribution prior function Ω is agnostic to the attribution method Φ.

Previous attribution priors9,12 required specifying an exact target value for the model's attributions, but often we do not know in advance which features are important. In general, there is no requirement that Φ(θ, X) constrain attributions to particular values. The 'Results' section presented three newly developed attribution priors for different tasks that improve performance without requiring pre-specified attribution targets for any particular feature.

Expected gradients. EG is an extension of IG2 with fewer hyperparameter choices. Like several other attribution methods, IG aims to explain the difference between a model's current prediction and the prediction that the model would make when given a baseline input. This baseline input is meant to represent some uninformative reference input that represents not knowing the value of the input features. Although choosing such an input is necessary for several feature attribution methods2,39,43, the choice is often made arbitrarily. For example, for
where x is the target input and x′ is baseline input. To avoid specifying x′, we define the EG value for feature i as:

\mathrm{EG}_i(x) := \int_{x'} \left( (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial f(x' + \alpha (x - x'))}{\partial x_i} \, \mathrm{d}\alpha \right) p_D(x') \, \mathrm{d}x',

where D is the underlying data distribution. Since EG is also a diagonal path method, it satisfies the same axioms as IG44. Directly integrating over the training distribution is intractable; therefore, we instead reformulate the integrals as expectations:

\mathrm{EG}_i(x) := \mathbb{E}_{x' \sim D,\ \alpha \sim U(0,1)} \left[ (x_i - x'_i) \times \frac{\partial f(x' + \alpha \times (x - x'))}{\partial x_i} \right].

This expectation-based formulation lends itself to a natural, sampling-based approximation method: (1) draw samples of x′ from the training dataset and α from U(0, 1), (2) compute the value inside the expectation for each sample and (3) average over samples. For a pseudocode description of EG, see Supplementary Section C.
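Steps (1)–(3) translate directly into a few lines of TensorFlow. The sketch below is an illustration of the estimator rather than the released implementation, and assumes a model returning one scalar per sample (for example, a selected class logit), with one reference and one α drawn per sample per iteration.

```python
import tensorflow as tf

def expected_gradients(model, x, background, k=100):
    """Monte Carlo estimate of EG attributions for a batch of inputs x.

    background is a tensor of training samples used as references x'; k is the
    number of (x', alpha) draws averaged per input.
    """
    attributions = tf.zeros_like(x)
    n_background = tf.shape(background)[0]
    batch = tf.shape(x)[0]
    for _ in range(k):
        idx = tf.random.uniform([batch], maxval=n_background, dtype=tf.int32)
        x_ref = tf.gather(background, idx)                     # x' drawn from the data
        alpha = tf.reshape(tf.random.uniform([batch]),
                           [-1] + [1] * (x.shape.rank - 1))    # one alpha per sample
        point = x_ref + alpha * (x - x_ref)                    # interpolated input
        with tf.GradientTape() as tape:
            tape.watch(point)
            output = tf.reduce_sum(model(point))               # summing keeps per-sample gradients
        grad = tape.gradient(output, point)
        attributions += (x - x_ref) * grad / float(k)          # (x_i - x'_i) * df/dx_i, averaged
    return attributions
```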
EG also satisfies a set of important interpretability axioms: implementation invariance, sensitivity, completeness, linearity and symmetry preserving.
• Implementation invariance states that two networks with outputs that are equal over all inputs should have equivalent attributions. Any attribution method based on the gradients of a network will satisfy this axiom2, meaning that IG, EG and gradients will all be implementation invariant.
• Sensitivity (sometimes called dummy) states that when a model does not depend on a feature at all, it receives zero importance. IG, EG and gradients all satisfy sensitivity because the gradient with respect to an irrelevant feature will be zero everywhere.
• Completeness states that the attributions should sum to the difference between the output of a function at the input to be explained and the output of that function at a baseline. Gradients do not satisfy completeness due to saturation at the inputs; elements such as ReLUs may cause gradients to be zero, making completeness impossible2. IG and EG both satisfy completeness due to the gradient theorem (fundamental theorem of calculus for line integrals)2. For EG, the function being integrated is the expectation of the model's output, so completeness means that the attributions sum to the difference between the model's output for the input and the model's output averaged over all possible baselines (a numerical sketch of this property is given after this list).
• Linearity states that for a model that is a linear combination of two submodels f(x) = af1(x) + bf2(x), where a and b are arbitrary scalars, the attributions are a linear combination of the submodels' attributions ϕ(x) = aϕ1(x) + bϕ2(x). This will hold for IG, EG and gradients because gradients are linear.
• Symmetry preserving states that symmetric variables with identical values should achieve identical attributions. IG is symmetry preserving since it is a straight-line path method, and EG will also be symmetry preserving, as a symmetric function of symmetric functions will itself be symmetrical2.
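As a self-contained numerical illustration of the completeness axiom (a toy NumPy ReLU network with the sampling estimator, not the paper's code), the summed EG attributions approximately match the model output minus the average output over the baselines:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy one-hidden-layer ReLU network with a scalar output.
W1, b1 = rng.normal(size=(8, 5)), rng.normal(size=8)
w2 = rng.normal(size=8)

def f(x):
    return w2 @ np.maximum(W1 @ x + b1, 0.0)

def grad_f(x):
    active = (W1 @ x + b1 > 0.0).astype(float)   # ReLU derivative
    return (w2 * active) @ W1

def expected_gradients(x, background, n_draws=100_000):
    total = np.zeros_like(x)
    for _ in range(n_draws):
        x_ref = background[rng.integers(len(background))]
        alpha = rng.uniform()
        total += (x - x_ref) * grad_f(x_ref + alpha * (x - x_ref))
    return total / n_draws

background = rng.normal(size=(64, 5))            # stand-in reference distribution
x = rng.normal(size=5)
eg = expected_gradients(x, background)

# Completeness: sum of attributions ~= f(x) - average of f(x') over baselines.
print(eg.sum(), f(x) - np.mean([f(b) for b in background]))
```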
Unlike previous attribution methods, EG is explicitly designed for natural batched training. This enables an order of magnitude increase in computational efficiency relative to previous approaches for training with attribution priors. We further improve performance by reducing the need for additional data reading. Specifically, for each input in a batch of inputs, we need k additional inputs to calculate EG attributions for that input batch. As long as k is smaller than the batch size, we can avoid any additional data reading by re-using the same batch of input data as a reference batch, as in ref. 45. We accomplish this by shifting the batch of input k times, such that each input in the batch uses k other inputs from the batch as its reference values.
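One way to realize this batch-shifting idea is sketched below, under the assumption that references are taken from the current minibatch as described above; the released implementation may differ in detail.

```python
import tensorflow as tf

def in_batch_references(x_batch, k):
    """Build k references per sample by rolling the minibatch.

    Returns a tensor of shape (k, batch_size, ...): reference j for sample i is
    sample (i + j + 1) mod batch_size, so no additional data reading is required
    as long as k is smaller than the batch size.
    """
    shifted = [tf.roll(x_batch, shift=j + 1, axis=0) for j in range(k)]
    return tf.stack(shifted, axis=0)
```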
Specific priors. Here we elaborate on the explicit form of the attribution priors we used in this paper. In general, minimizing the error of a model corresponds to maximizing the likelihood of the data under a generative model consisting of the learned model plus parametric noise. For example, minimizing mean squared error in a regression task corresponds to maximizing the likelihood of the data under the learned model, assuming Gaussian-distributed errors:

\operatorname{argmin}_{\theta} \lVert f_{\theta}(X) - y \rVert_2^2 = \operatorname{argmax}_{\theta} \exp\left(-\lVert f_{\theta}(X) - y \rVert_2^2\right) = \theta_{\mathrm{MLE}},

where θMLE is the maximum-likelihood estimate (MLE) of θ under the model Y = f_{\theta}(X) + \mathcal{N}(0, \sigma).

Pixel attribution prior. Our pixel attribution prior is based on the anisotropic total variation loss and is given as follows:

\Omega_{\mathrm{pixel}}(\Phi(\theta, X)) = \sum_{\ell} \sum_{i,j} \left| \phi^{\ell}_{i+1,j} - \phi^{\ell}_{i,j} \right| + \left| \phi^{\ell}_{i,j+1} - \phi^{\ell}_{i,j} \right|,

where \phi^{\ell}_{i,j} is the attribution for the i, jth pixel in the ℓth training image. Research shows46 that this penalty is equivalent to placing zero-mean, i.i.d., Laplace-distributed priors on the differences between adjacent pixel values, that is, \phi^{\ell}_{i+1,j} - \phi^{\ell}_{i,j} \approx \mathrm{Laplace}(0, \lambda^{-1}) and \phi^{\ell}_{i,j+1} - \phi^{\ell}_{i,j} \approx \mathrm{Laplace}(0, \lambda^{-1}). Reference 46 does not call our penalty 'total variation', but it is in fact the widely used anisotropic version of total variation and is directly implemented in Tensorflow47–49.

Graph attribution prior. For our graph attribution prior, we used a protein–protein or gene–gene interaction network and represented these networks as a weighted, undirected graph. Formally, assume we have a weighted adjacency matrix W \in \mathbb{R}_{+}^{p \times p} for an undirected graph, where the entries encode our prior belief about the pairwise similarity of the importances between two features. For a biological network, W_{i,j} encodes either the probability or strength of interaction between the ith and jth genes (or proteins). We encouraged similarity along graph edges by penalizing the squared Euclidean distance between each pair of feature attributions in proportion to how similar we believe them to be. Using the graph Laplacian (L_G = D − W), where D is the diagonal degree matrix of the weighted graph, this becomes:

\Omega_{\mathrm{graph}}(\Phi(\theta, X)) = \sum_{i,j} W_{i,j} (\bar{\phi}_i - \bar{\phi}_j)^2 = \bar{\phi}^{T} L_G \bar{\phi}.

In this case, we choose to penalize global rather than local feature attributions. We define \bar{\phi}_i to be the importance of feature i across all samples in our dataset, where this global attribution is calculated as the average magnitude of the feature attribution across all samples: \bar{\phi}_i = \frac{1}{n} \sum_{\ell=1}^{n} |\phi^{\ell}_i|. Just as the image penalty is equivalent to placing a Laplace prior on adjacent pixels in a regular graph, the graph penalty Ωgraph is equivalent to placing a Gaussian prior on adjacent features in an arbitrary graph with Laplacian L_G (ref. 46).

Sparse attribution prior. Our sparsity prior uses the Gini coefficient G as a penalty, which is written:

\Omega_{\mathrm{sparse}}(\Phi(\theta, X)) = - \frac{\sum_{i=1}^{p} \sum_{j=1}^{p} |\bar{\phi}_i - \bar{\phi}_j|}{n \sum_{i=1}^{p} \bar{\phi}_i} = -2 G(\bar{\phi}).

By taking exponentials of this function, we find that minimizing the sparsity regularizer is equivalent to maximizing likelihood under a prior proportional to the following:

\prod_{i=1}^{p} \prod_{j=1}^{p} \exp\left( \frac{1}{\sum_{i=1}^{p} \bar{\phi}_i} \, |\bar{\phi}_i - \bar{\phi}_j| \right).

To our knowledge, this prior does not directly correspond to a named distribution. However, we observe that its maximum value occurs when one \bar{\phi}_i is 1 and all others are 0, and that its minimum occurs when all \bar{\phi}_i are equal. This is similar to the total variation penalty Ωimage, but it is normalized and has a flipped sign to encourage differences. The corresponding attribution prior is maximized when global attributions are zero for all but one feature and minimized when attributions are uniform across features.

Image model experimental settings. We trained a VGG16 model from scratch modified for the CIFAR-10 dataset, containing 60,000 coloured 32 × 32-pixel images divided into 10 categories, as in ref. 50. To train this network, we used stochastic gradient descent with an initial learning rate of 0.1 and an exponential decay of 0.5 applied every 20 epochs. Additionally, we used a momentum level of 0.9. For augmentation, we shifted each image horizontally and vertically by a pixel shift uniformly drawn from the range [−3, 3], and we randomly rotated each image by an angle uniformly drawn from the range [−15, 15]. We used a batch size of 128.
be applied to R2 as R2 is a linear transformation of mean squared error, which satisfies normality assumptions by the central limit theorem. When we compare the R2 attained from ten independent retrainings of the neural network to the R2 attained from ten independent retrainings of the attribution prior model, we find that predictive performance is significantly higher for the model with the graph attribution prior (t statistic = 3.59, P = 2.06 × 10−3).
To ensure that the increased performance in the attribution prior model was due to real biological information, we replaced the gene-interaction graph with a randomized graph (symmetric matrix with identical number of non-zero entries to the real graph, but entries placed in random positions). We then compared the R2 attained from ten independent retrainings of a neural network with no graph attribution prior to ten independent retrainings of a neural network regularized with the random graph and found that test error was not significantly different between these two models (t statistic = 1.25, P = 0.23). We also compared to graph convolutional neural networks, and found that our network with a graph attribution prior outperformed the graph convolutional neural network (t statistic = 3.30, P = 4.0 × 10−3). Finally, we compared to an L2 penalty applied uniformly across all attributions, and found that this attribution prior did not significantly increase performance from baseline (t statistic = 1.7, P = 0.12, see Supplementary Fig. 15).
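One way to construct such a randomized control graph is sketched below in NumPy, assuming the real adjacency matrix has a zero diagonal; this is an illustration of the described procedure, not the released code.

```python
import numpy as np

def randomized_graph(W, seed=0):
    """Symmetric matrix with the same number of non-zero (upper-triangle) entries
    as W, but with those entries placed at uniformly random positions."""
    rng = np.random.default_rng(seed)
    upper = np.triu_indices(W.shape[0], k=1)          # off-diagonal upper-triangle slots
    values = W[upper]
    nonzero = values[values != 0]
    slots = rng.permutation(len(values))[:len(nonzero)]
    new_values = np.zeros_like(values)
    new_values[slots] = rng.permutation(nonzero)      # keep the same edge weights
    R = np.zeros_like(W)
    R[upper] = new_values
    return R + R.T
```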
Train/validation/test-set allocation. To increase the number of samples in our dataset, we used as a feature the identity of the drug being tested, rather than one of a number of possible output tasks in a multi-task prediction. This follows from previous literature on training neural networks to predict drug response54. This yielded 30,816 samples (covering 218 patients and 145 anticancer drugs). Defining a sample as a drug and a patient, however, meant we had to choose carefully how to stratify samples into our train, validation and test sets. While it is perfectly legitimate in general to randomly stratify samples into these sets, we wanted to specifically focus on how well our model could learn trends from gene expression data that would generalize to new patients. Therefore, we stratified samples at a patient level rather than at the level of individual samples (for example, no samples from any patient in the test set ever appeared in the training set). We split 20% of the total patients into a test set (6,155 samples) and then split 20% of the training data into a validation set for hyperparameter selection (4,709 samples).
Model class implementations and hyperparameters tested. LASSO. We used the scikit-learn implementation of the LASSO55,56. We tested a range of α parameters from 10−9 to 1, and we found that the optimal value for α was 10−2 by mean squared error on the validation set.
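The α sweep can be reproduced in outline with scikit-learn; the snippet below uses synthetic stand-in data in place of the expression and drug-response matrices, so it illustrates the selection procedure rather than the reported result.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the gene expression + drug indicator features and IC50 labels.
X, y = make_regression(n_samples=500, n_features=100, noise=10.0, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

val_mse = {}
for alpha in np.logspace(-9, 0, num=10):              # alpha grid from 1e-9 to 1
    lasso = Lasso(alpha=alpha, max_iter=10000).fit(X_train, y_train)
    val_mse[alpha] = mean_squared_error(y_val, lasso.predict(X_val))

best_alpha = min(val_mse, key=val_mse.get)            # select alpha by validation MSE
print(best_alpha)
```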
Graph LASSO. For our graph LASSO, we used the Adam optimizer in TensorFlow47, with a learning rate of 10−5 to optimize the following loss function:

L(w; X, y) = \lVert Xw - y \rVert_2^2 + \lambda' \lVert w \rVert_1 + \nu' \, w^{T} L_G w, \quad (1)

where w \in \mathbb{R}^d is the weights vector of our linear model and L_G is the graph Laplacian of our HumanBase network30. In particular, we downloaded the 'Top Edges' version of the haematopoietic stem cell network, which was thresholded to only have non-zero values for pairwise interactions that had a posterior probability greater than 0.1. We used the value of λ′ selected as optimal in the regular LASSO model (10−2, which corresponds to the α parameter in scikit-learn) and then tuned over ν′ values ranging from 10−3 to 100. We found that a value of 10 was optimal according to MSE on the validation set.

equal in magnitude to the least squares and L1 loss terms. We found that five extra epochs of tuning were optimal by validation set error. We drew k = 10 background samples for our attributions. To test our attribution prior using gradients as the feature attribution method (rather than expected gradients), we followed the exact same procedure, only we replaced ϕ̄ with the average magnitude of the gradients rather than the EG.

Graph convolutional networks. We followed the implementation of graph convolution described in ref. 31. The architectures were searched as follows: in every network, we first had a single graph convolutional layer (we were limited to one graph convolution layer due to memory constraints on each Nvidia GTX 1080 Ti GPU that we used), followed by two fully connected layers of sizes (512, 256), (512, 128) or (256, 128). We tuned over a wide range of hyperparameters, including L2 penalties on the weights ranging from 10−5 to 10−2, L1 penalties on the weights ranging from 10−5 to 10−2, learning rates of 10−5 to 10−3 and dropout rates ranging from 0.2 to 0.8. We found the optimal hyperparameters based on validation set error were two hidden layers of size 512 and size 256, an L2 penalty on the weights of 10−5, a learning rate of 10−5 and a dropout rate of 0.6. We again used an early stopping parameter and found that 47 epochs was the optimal number.

Sparsity experiments. Data description and processing. Our sparsity experiments used data from the National Health and Nutrition Examination I Survey (NHANES I)42 and contained 35 variables (expanded to 118 features by one-hot encoding of categorical variables) gathered from 13,000 patients. The measurements included demographic information such as age, sex and BMI as well as physiological measurements such as blood, urine and vital sign measurements. The prediction task was a binary classification of whether the patient was still alive (1) or not (0) ten years after data were gathered. Data were mean-imputed and standardized so that each feature had zero mean and unit variance. For each of the 200 experimental replicates, 100 train and 100 validation points were sampled uniformly at random; all other points were allocated to the test set.

Model. We trained a range of neural networks to predict survival in the NHANES data. The architecture, nonlinearities and training rounds were all held constant at values that performed well on an unregularized network, and the type and degree of regularization were varied. All models used ReLU activations and a single output with binary cross-entropy loss; in addition, all models ran for 100 epochs with a stochastic gradient descent optimizer with learning rate 0.001 on the size-100 training data. The entire 100-sample training set fit in one batch. Because the training set was so small, all of its 100 samples were used for EG attributions during training and evaluation, yielding k = 100. Each model was trained on a single GPU on a desktop workstation with 4 Nvidia 1080 Ti GPUs.

Architecture. We considered a range of architectures, including single-hidden-layer 32-node, 128-node and 512-node networks, two-layer [128, 32]-node and [512, 128]-node networks, and a three-layer [512, 128, 32]-node network; we fixed the [512, 128, 32] architecture for future experiments.

Regularizers. We tested a large array of regularizers in addition to those considered in the main text. For details, see Supplementary Section J.1.

Hyperparameter tuning. We selected the hyperparameters for our models based on validation performance. We searched all L1, L2, SGL and attribution prior penalties with 121 points sampled on a log scale over [10−7, 105] (Supplementary Fig. 18).