GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks



Zhao Chen¹, Vijay Badrinarayanan¹, Chen-Yu Lee¹, Andrew Rabinovich¹

arXiv:1711.02257v4 [cs.CV] 12 Jun 2018

Abstract

Deep multitask networks, in which one neural network produces multiple predictive outputs, can offer better speed and performance than their single-task counterparts but are challenging to train properly. We present a gradient normalization (GradNorm) algorithm that automatically balances training in deep multitask models by dynamically tuning gradient magnitudes. We show that for various network architectures, for both regression and classification tasks, and on both synthetic and real datasets, GradNorm improves accuracy and reduces overfitting across multiple tasks when compared to single-task networks, static baselines, and other adaptive multitask loss balancing techniques. GradNorm also matches or surpasses the performance of exhaustive grid search methods, despite only involving a single asymmetry hyperparameter $\alpha$. Thus, what was once a tedious search process that incurred exponentially more compute for each task added can now be accomplished within a few training runs, irrespective of the number of tasks. Ultimately, we will demonstrate that gradient manipulation affords us great control over the training dynamics of multitask networks and may be one of the keys to unlocking the potential of multitask learning.

¹Magic Leap, Inc. Correspondence to: Zhao Chen <zchen@magicleap.com>.
Proceedings of the 35th International Conference on Machine Learning, Stockholm, Sweden, PMLR 80, 2018. Copyright 2018 by the author(s).

1. Introduction

Single-task learning in computer vision has enjoyed much success in deep learning, with many single-task models now performing at or beyond human accuracies for a wide array of tasks. However, an ultimate visual system for full scene understanding must be able to perform many diverse perceptual tasks simultaneously and efficiently, especially within the limited compute environments of embedded systems such as smartphones, wearable devices, and robots/drones. Such a system can be enabled by multitask learning, where one model shares weights across multiple tasks and makes multiple inferences in one forward pass. Such networks are not only scalable, but the shared features within these networks can induce more robust regularization and boost performance as a result. In the ideal limit, we can thus have the best of both worlds with multitask networks: more efficiency and higher performance.

In general, multitask networks are difficult to train; different tasks need to be properly balanced so network parameters converge to robust shared features that are useful across all tasks. Methods in multitask learning thus far have largely tried to find this balance by manipulating the forward pass of the network (e.g. through constructing explicit statistical relationships between features (Long & Wang, 2015) or optimizing multitask network architectures (Misra et al., 2016), etc.), but such methods ignore a key insight: task imbalances impede proper training because they manifest as imbalances between backpropagated gradients. A task that is too dominant during training, for example, will necessarily express that dominance by inducing gradients which have relatively large magnitudes. We aim to mitigate such issues at their root by directly modifying gradient magnitudes through tuning of the multitask loss function.

In practice, the multitask loss function is often assumed to be linear in the single task losses $L_i$, $L = \sum_i w_i L_i$, where the sum runs over all $T$ tasks. In our case, we propose an adaptive method, and so $w_i$ can vary at each training step $t$: $w_i = w_i(t)$. This linear form of the loss function is convenient for implementing gradient balancing, as $w_i$ very directly and linearly couples to the backpropagated gradient magnitudes from each task. The challenge is then to find the best value for each $w_i$ at each training step $t$ that balances the contribution of each task for optimal model training. To optimize the weights $w_i(t)$ for gradient balancing, we propose a simple algorithm that penalizes the network when backpropagated gradients from any task are too large or too small. The correct balance is struck when tasks are training at similar rates; if task $i$ is training relatively quickly, then its weight $w_i(t)$ should decrease relative to other task weights $w_j(t)|_{j \neq i}$ to allow other tasks more influence on

training. Our algorithm is similar to batch normalization (Ioffe & Szegedy, 2015) with two main differences: (1) we normalize across tasks instead of across data batches, and (2) we use rate balancing as a desired objective to inform our normalization. We will show that such gradient normalization (hereafter referred to as GradNorm) boosts network performance while significantly curtailing overfitting.

Our main contributions to multitask learning are as follows:

1. An efficient algorithm for multitask loss balancing which directly tunes gradient magnitudes.
2. A method which matches or surpasses the performance of very expensive exhaustive grid search procedures, but which only requires tuning a single hyperparameter.
3. A demonstration that direct gradient interaction provides a powerful way of controlling multitask learning.

2. Related Work

Multitask learning was introduced well before the advent of deep learning (Caruana, 1998; Bakker & Heskes, 2003), but the robust learned features within deep networks and their excellent single-task performance have spurred renewed interest. Although our primary application area is computer vision, multitask learning has applications in multiple other fields, from natural language processing (Collobert & Weston, 2008; Hashimoto et al., 2016; Søgaard & Goldberg, 2016) to speech synthesis (Seltzer & Droppo, 2013; Wu et al., 2015), from very domain-specific applications such as traffic prediction (Huang et al., 2014) to very general cross-domain work (Bilen & Vedaldi, 2017). Multitask learning has also been explored in the context of curriculum learning (Graves et al., 2017), where subsets of tasks are subsequently trained based on local rewards; we here explore the opposite approach, where tasks are jointly trained based on global rewards such as total loss decrease.

Multitask learning is very well-suited to the field of computer vision, where making multiple robust predictions is crucial for complete scene understanding. Deep networks have been used to solve various subsets of multiple vision tasks, from 3-task networks (Eigen & Fergus, 2015; Teichmann et al., 2016) to much larger subsets as in UberNet (Kokkinos, 2016). Often, single computer vision problems can even be framed as multitask problems, such as in Mask R-CNN for instance segmentation (He et al., 2017) or YOLO-9000 for object detection (Redmon & Farhadi, 2016). Particularly of note is the rich and significant body of work on finding explicit ways to exploit task relationships within a multitask model. Clustering methods have shown success beyond deep models (Jacob et al., 2009; Kang et al., 2011), while constructs such as deep relationship networks (Long & Wang, 2015) and cross-stitch networks (Misra et al., 2016) give deep networks the capacity to search for meaningful relationships between tasks and to learn which features to share between them. Work in (Warde-Farley et al., 2014) and (Lu et al., 2016) uses groupings amongst labels to search through possible architectures for learning. Perhaps the most relevant to the current work, (Kendall et al., 2017) uses a joint likelihood formulation to derive task weights based on the intrinsic uncertainty in each task.

3. The GradNorm Algorithm

3.1. Definitions and Preliminaries

For a multitask loss function $L(t) = \sum_i w_i(t) L_i(t)$, we aim to learn the functions $w_i(t)$ with the following goals: (1) to place gradient norms for different tasks on a common scale through which we can reason about their relative magnitudes, and (2) to dynamically adjust gradient norms so different tasks train at similar rates. To this end, we first define the relevant quantities, first with respect to the gradients we will be manipulating.

• $W$: The subset of the full network weights $W \subset \mathcal{W}$ where we actually apply GradNorm. $W$ is generally chosen as the last shared layer of weights to save on compute costs¹.
• $G_W^{(i)}(t) = \|\nabla_W w_i(t) L_i(t)\|_2$: the $L_2$ norm of the gradient of the weighted single-task loss $w_i(t) L_i(t)$ with respect to the chosen weights $W$.
• $\overline{G}_W(t) = E_{task}[G_W^{(i)}(t)]$: the average gradient norm across all tasks at training time $t$.

We also define various training rates for each task $i$:

• $\tilde{L}_i(t) = L_i(t)/L_i(0)$: the loss ratio for task $i$ at time $t$. $\tilde{L}_i(t)$ is a measure of the inverse training rate of task $i$ (i.e. lower values of $\tilde{L}_i(t)$ correspond to a faster training rate for task $i$)².
• $r_i(t) = \tilde{L}_i(t)/E_{task}[\tilde{L}_i(t)]$: the relative inverse training rate of task $i$.

¹ In our experiments this choice of $W$ causes GradNorm to increase training time by only ∼5% on NYUv2.
² Networks in this paper all had stable initializations and $L_i(0)$ could be used directly. When $L_i(0)$ is sharply dependent on initialization, we can use a theoretical initial loss instead. E.g. for $L_i$ the CE loss across $C$ classes, we can use $L_i(0) = \log(C)$.

With the above definitions in place, we now complete our description of the GradNorm algorithm.

3.2. Balancing Gradients with GradNorm

As stated in Section 3.1, GradNorm should establish a common scale for gradient magnitudes, and also should balance

training rates of different tasks. The common scale for gradients is most naturally the average gradient norm, $\overline{G}_W(t)$, which establishes a baseline at each timestep $t$ by which we can determine relative gradient sizes. The relative inverse training rate of task $i$, $r_i(t)$, can be used to rate balance our gradients. Concretely, the higher the value of $r_i(t)$, the higher the gradient magnitudes should be for task $i$ in order to encourage the task to train more quickly. Therefore, our desired gradient norm for each task $i$ is simply:

$$G_W^{(i)}(t) \mapsto \overline{G}_W(t) \times [r_i(t)]^{\alpha}, \qquad (1)$$

where $\alpha$ is an additional hyperparameter. $\alpha$ sets the strength of the restoring force which pulls tasks back to a common training rate. In cases where tasks are very different in their complexity, leading to dramatically different learning dynamics between tasks, a higher value of $\alpha$ should be used to enforce stronger training rate balancing. When tasks are more symmetric (e.g. the synthetic examples in Section 4), a lower value of $\alpha$ is appropriate. Note that $\alpha = 0$ will always try to pin the norms of backpropagated gradients from each task to be equal at $W$. See Section 5.4 for more details on the effects of tuning $\alpha$.

Equation 1 gives a target for each task $i$'s gradient norms, and we update our loss weights $w_i(t)$ to move gradient norms towards this target for each task. GradNorm is then implemented as an $L_1$ loss function $L_{grad}$ between the actual and target gradient norms at each timestep for each task, summed over all tasks:

$$L_{grad}(t; w_i(t)) = \sum_i \left| G_W^{(i)}(t) - \overline{G}_W(t) \times [r_i(t)]^{\alpha} \right|_1 \qquad (2)$$

where the summation runs through all $T$ tasks. When differentiating this loss $L_{grad}$, we treat the target gradient norm $\overline{G}_W(t) \times [r_i(t)]^{\alpha}$ as a fixed constant to prevent loss weights $w_i(t)$ from spuriously drifting towards zero. $L_{grad}$ is then differentiated only with respect to the $w_i$, as the $w_i(t)$ directly control gradient magnitudes per task. The computed gradients $\nabla_{w_i} L_{grad}$ are then applied via standard update rules to update each $w_i$ (as shown in Figure 1). The full GradNorm algorithm is summarized in Algorithm 1. Note that after every update step, we also renormalize the weights $w_i(t)$ so that $\sum_i w_i(t) = T$ in order to decouple gradient normalization from the global learning rate.

Figure 1. Gradient Normalization. Imbalanced gradient norms across tasks (left) result in suboptimal training within a multitask network. We implement GradNorm through computing a novel gradient loss $L_{grad}$ (right) which tunes the loss weights $w_i$ to fix such imbalances in gradient norms. We illustrate here a simplified case where such balancing results in equalized gradient norms, but in general there may be tasks that require relatively high or low gradient magnitudes for optimal training (discussed further in Section 3).

4. A Toy Example

To illustrate GradNorm on a simple, interpretable system, we construct a common scenario for multitask networks: training tasks which have similar loss functions but different loss scales. In such situations, if we naïvely pick $w_i(t) = 1$

for all loss weights $w_i(t)$, the network training will be dominated by tasks with larger loss scales that backpropagate larger gradients. We will demonstrate that GradNorm overcomes this issue.

Consider $T$ regression tasks trained using standard squared loss onto the functions

$$f_i(\mathbf{x}) = \sigma_i \tanh((B + \epsilon_i)\mathbf{x}), \qquad (3)$$

where $\tanh(\cdot)$ acts element-wise. Inputs are dimension 250 and outputs dimension 100, while $B$ and $\epsilon_i$ are constant matrices with their elements generated IID from $\mathcal{N}(0, 10)$ and $\mathcal{N}(0, 3.5)$, respectively. Each task therefore shares information in $B$ but also contains task-specific information $\epsilon_i$. The $\sigma_i$ are the key parameters within this problem; they are fixed scalars which set the scales of the outputs $f_i$. A higher scale for $f_i$ induces a higher expected value of squared loss for that task. Such tasks are harder to learn due to the higher variances in their response values, but they also backpropagate larger gradients. This scenario generally leads to suboptimal training dynamics when the higher $\sigma_i$ tasks dominate the training across all tasks.

To train our toy models, we use a 4-layer fully-connected ReLU-activated network with 100 neurons per layer as a common trunk. A final affine transformation layer gives $T$ final predictions (corresponding to $T$ different tasks). To ensure valid analysis, we only compare models initialized to the same random values and fed data generated from the same fixed random seed. The asymmetry $\alpha$ is set low to 0.12 for these experiments, as the output functions $f_i$ are all of the same functional form and thus we expect the asymmetry between tasks to be minimal.

Algorithm 1 Training with GradNorm
  Initialize $w_i(0) = 1 \; \forall i$
  Initialize network weights $\mathcal{W}$
  Pick value for $\alpha > 0$ and pick the weights $W$ (usually the final layer of weights which are shared between tasks)
  for $t = 0$ to max train steps do
    Input batch $x_i$ to compute $L_i(t) \; \forall i$ and $L(t) = \sum_i w_i(t) L_i(t)$ [standard forward pass]
    Compute $G_W^{(i)}(t)$ and $r_i(t) \; \forall i$
    Compute $\overline{G}_W(t)$ by averaging the $G_W^{(i)}(t)$
    Compute $L_{grad} = \sum_i |G_W^{(i)}(t) - \overline{G}_W(t) \times [r_i(t)]^{\alpha}|_1$
    Compute GradNorm gradients $\nabla_{w_i} L_{grad}$, keeping targets $\overline{G}_W(t) \times [r_i(t)]^{\alpha}$ constant
    Compute standard gradients $\nabla_{\mathcal{W}} L(t)$
    Update $w_i(t) \mapsto w_i(t+1)$ using $\nabla_{w_i} L_{grad}$
    Update $\mathcal{W}(t) \mapsto \mathcal{W}(t+1)$ using $\nabla_{\mathcal{W}} L(t)$ [standard backward pass]
    Renormalize $w_i(t+1)$ so that $\sum_i w_i(t+1) = T$
  end for

In these toy problems, we measure the task-normalized test-time loss to judge test-time performance, which is the sum of the test loss ratios for each task, $\sum_i L_i(t)/L_i(0)$. We do this because a simple sum of losses is an inadequate performance metric for multitask networks when different loss scales exist; higher loss scale tasks will factor disproportionately highly in the loss. There unfortunately exists no general single scalar which gives a meaningful measure of multitask performance in all scenarios, but our toy problem was specifically designed with tasks which are statistically identical except for their loss scales $\sigma_i$. There is therefore a clear measure of overall network performance, which is the sum of losses normalized by each task's variance $\sigma_i^2$, itself equivalent (up to a scaling factor) to the sum of loss ratios.

For $T = 2$, we choose the values $(\sigma_0, \sigma_1) = (1.0, 100.0)$ and show the results of training in the top panels of Figure 2. If we train with equal weights $w_i = 1$, task 1 suppresses task 0 from learning due to task 1's higher loss scale. However, gradient normalization increases $w_0(t)$ to counteract the larger gradients coming from task 1, and the improved task balance results in better test-time performance.

The possible benefits of gradient normalization become even clearer when the number of tasks increases. For $T = 10$, we sample the $\sigma_i$ from a wide normal distribution and plot the results in the bottom panels of Figure 2. GradNorm significantly improves test-time performance over naïvely weighting each task the same. Similarly to the $T = 2$ case, for $T = 10$ the $w_i(t)$ grow larger for smaller $\sigma_i$ tasks.

For both $T = 2$ and $T = 10$, GradNorm is more stable and outperforms the uncertainty weighting proposed by (Kendall et al., 2017). Uncertainty weighting, which enforces that $w_i(t) \sim 1/L_i(t)$, tends to grow the weights $w_i(t)$ too large and too quickly as the loss for each task drops. Although such networks train quickly at the onset, the training soon deteriorates. This issue is largely caused by the fact that uncertainty weighting allows $w_i(t)$ to change without constraint (compared to GradNorm which ensures $\sum_i w_i(t) = T$ always), which pushes the global learning rate up rapidly as the network trains.

The traces for each $w_i(t)$ during a single GradNorm run are observed to be stable and convergent. In Section 5.3 we will see how the time-averaged weights $E_t[w_i(t)]$ lie close to the optimal static weights, suggesting GradNorm can greatly simplify the tedious grid search procedure.

5. Application to a Large Real-World Dataset

We use two variants of NYUv2 (Nathan Silberman & Fergus, 2012) as our main datasets. Please refer to the Supplementary Materials for additional results on a 9-task facial landmark dataset found in (Zhang et al., 2014). The standard NYUv2 dataset carries depth, surface normals, and semantic

segmentation labels (clustered into 13 distinct classes) for a variety of indoor scenes in different room types (bathrooms, living rooms, studies, etc.). NYUv2 is relatively small (795 training, 654 test images), but contains both regression and classification labels, making it a good choice to test the robustness of GradNorm across various tasks.

Figure 2. Gradient Normalization on a toy 2-task (top) and 10-task (bottom) system. Diagrams of the network structure with loss scales are on the left, traces of $w_i(t)$ during training in the middle, and task-normalized test loss curves on the right. $\alpha = 0.12$ for all runs.

We augment the standard NYUv2 depth dataset with flips and additional frames from each video, resulting in 90,000 images complete with pixel-wise depth, surface normals, and room keypoint labels (segmentation labels are, unfortunately, not available for these additional frames). Keypoint labels are professionally annotated by humans, while surface normals are generated algorithmically. The full dataset is then split by scene for a 90/10 train/test split. See Figure 6 for examples. We will generally refer to these two datasets as NYUv2+seg and NYUv2+kpts, respectively.

All inputs are downsampled to 320 x 320 pixels and outputs to 80 x 80 pixels. We use these resolutions following (Lee et al., 2017), which represents the state-of-the-art in room keypoint prediction and from which we also derive our VGG-style model architecture. These resolutions also allow us to keep models relatively slim while not compromising semantic complexity in the ground truth output maps.

5.1. Model and General Training Characteristics

We try two different models: (1) a SegNet (Badrinarayanan et al., 2015; Lee et al., 2017) network with a symmetric VGG16 (Simonyan & Zisserman, 2014) encoder/decoder, and (2) an FCN (Long et al., 2015) network with a modified ResNet-50 (He et al., 2016) encoder and shallow ResNet decoder. The VGG SegNet reuses maxpool indices to perform upsampling, while the ResNet FCN learns all upsampling filters. The ResNet architecture is further thinned (both in its filters and activations) to contrast with the heavier, more complex VGG SegNet: stride-2 layers are moved earlier and all 2048-filter layers are replaced by 1024-filter layers. Ultimately, the VGG SegNet has 29M parameters versus 15M for the thin ResNet. All model parameters are shared amongst all tasks until the final layer. Although we will focus on the VGG SegNet in our more in-depth analysis, by designing and testing on two extremely different network topologies we will further demonstrate GradNorm's robustness to the choice of base architecture.

We use standard pixel-wise loss functions for each task: cross entropy for segmentation, squared loss for depth, and cosine similarity for normals. As in (Lee et al., 2017), for room layout we generate Gaussian heatmaps for each of 48 room keypoint types and predict these heatmaps with a pixel-wise squared loss. Note that all regression tasks are quadratic losses (our surface normal prediction uses a cosine loss which is quadratic to leading order), allowing us to use $r_i(t)$ for each task $i$ as a direct proxy for each task's relative inverse training rate.

All runs are trained at a batch size of 24 across 4 Titan X GTX 12GB GPUs and run at 30fps on a single GPU at inference. All NYUv2 runs begin with a learning rate of 2e-5. NYUv2+kpts runs last 80000 steps with a learning rate

decay of 0.2 every 25000 steps. NYUv2+seg runs last 20000 steps with a learning rate decay of 0.2 every 6000 steps. Updating $w_i(t)$ is performed at a learning rate of 0.025 for both GradNorm and the uncertainty weighting (Kendall et al., 2017) baseline. All optimizers are Adam, although we find that GradNorm is insensitive to the optimizer chosen. We implement GradNorm using TensorFlow v1.2.1.

Table 1. Test error, NYUv2+seg for GradNorm and various baselines. Lower values are better. Best performance for each task is bolded, with second-best underlined.

Model and Weighting Method   Depth RMS Err. (m)   Seg. Err. (100-IoU)   Normals Err. (1-|cos|)
VGG Backbone
Depth Only                   1.038                -                     -
Seg. Only                    -                    70.0                  -
Normals Only                 -                    -                     0.169
Equal Weights                0.944                70.1                  0.192
GradNorm Static              0.939                67.5                  0.171
GradNorm α = 1.5             0.925                67.8                  0.174
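The weight update described in Section 3.2 and Algorithm 1 can be sketched without any deep learning framework, assuming the caller has already obtained, for each task, the norm $\|\nabla_W L_i(t)\|_2$ of the unweighted task gradient at the shared layer $W$ (any autodiff framework can supply this with one backward pass per task). Because $w_i$ does not depend on $W$, $G_W^{(i)}(t) = w_i(t)\,\|\nabla_W L_i(t)\|_2$, so with the Eq. (1) targets held constant the derivative of Eq. (2) is $\partial L_{grad}/\partial w_i = \mathrm{sign}(G_W^{(i)} - \text{target}_i)\,\|\nabla_W L_i\|_2$. The name `gradnorm_step` is hypothetical; the sketch uses plain gradient descent at the 0.025 weight learning rate from Section 5.1 rather than Adam, and the small positive floor on $w_i$ is an added safeguard that Algorithm 1 does not specify.

```python
from typing import List, Sequence, Tuple


def _sign(x: float) -> float:
    """Sign of x as a float; 0.0 at x == 0 (a valid subgradient of |x|)."""
    return float((x > 0) - (x < 0))


def gradnorm_step(
    weights: Sequence[float],
    losses: Sequence[float],
    initial_losses: Sequence[float],
    unweighted_grad_norms: Sequence[float],
    alpha: float = 1.5,
    lr: float = 0.025,
    min_weight: float = 1e-8,
) -> Tuple[List[float], float]:
    """Hypothetical sketch of one GradNorm update of the task weights w_i(t).

    unweighted_grad_norms[i] is assumed to be ||grad_W L_i(t)||_2, computed
    elsewhere at the chosen shared layer W. Returns (new weights, L_grad).
    """
    num_tasks = len(weights)

    # G_W^(i)(t) = ||grad_W (w_i L_i)||_2 = w_i * ||grad_W L_i||_2 for w_i > 0.
    grad_norms = [w * g for w, g in zip(weights, unweighted_grad_norms)]
    mean_grad_norm = sum(grad_norms) / num_tasks  # G-bar_W(t)

    # Loss ratios L~_i(t) = L_i(t) / L_i(0) and relative inverse rates r_i(t).
    loss_ratios = [l / l0 for l, l0 in zip(losses, initial_losses)]
    mean_ratio = sum(loss_ratios) / num_tasks
    inverse_rates = [ratio / mean_ratio for ratio in loss_ratios]

    # Eq. (1) targets, treated as constants when differentiating L_grad.
    targets = [mean_grad_norm * r ** alpha for r in inverse_rates]

    # Eq. (2): L_grad = sum_i |G_W^(i) - target_i|.
    l_grad = sum(abs(g - t) for g, t in zip(grad_norms, targets))

    # dL_grad/dw_i = sign(G_W^(i) - target_i) * ||grad_W L_i||_2.
    weight_grads = [
        _sign(g - t) * g_raw
        for g, t, g_raw in zip(grad_norms, targets, unweighted_grad_norms)
    ]

    # Plain gradient-descent step (stand-in for Adam); floor keeps weights positive.
    updated = [max(w - lr * dw, min_weight) for w, dw in zip(weights, weight_grads)]

    # Renormalize so that sum_i w_i(t+1) = T, as in the last step of Algorithm 1.
    scale = num_tasks / sum(updated)
    return [w * scale for w in updated], l_grad


if __name__ == "__main__":
    # Task 1 backpropagates 10x larger gradients and has the higher loss ratio.
    new_weights, l_grad = gradnorm_step([1.0, 1.0], [0.5, 0.9], [1.0, 1.0], [1.0, 10.0])
    print(new_weights, l_grad)
```

In this example the first task's actual gradient norm sits below its target while the second task's sits above it, so one step raises $w_0$, lowers $w_1$, and the renormalization keeps their sum at $T = 2$.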

5.2. Main Results on NYUv2


In Table 1 we display the performance of GradNorm on the NYUv2+seg dataset. We see that GradNorm $\alpha = 1.5$ improves the performance of all three tasks with respect to the equal-weights baseline (where $w_i(t) = 1$ for all $t, i$), and either surpasses or matches (within statistical noise) the best performance of single networks for each task. The GradNorm Static network uses static weights derived from a GradNorm network by calculating the time-averaged weights $E_t[w_i(t)]$ for each task during a GradNorm training run, and retraining a network with weights fixed to those values. GradNorm thus can also be used to extract good values for static weights. We pursue this idea further in Section 5.3 and show that these weights lie very close to the optimal weights extracted from exhaustive grid search.

To show how GradNorm can perform in the presence of a larger dataset, we also perform extensive experiments on the NYUv2+kpts dataset, which is augmented to a factor of 50x more data. The results are shown in Table 2. As with the NYUv2+seg runs, GradNorm networks outperform other multitask methods, and either match (within noise) or surpass the performance of single-task networks.

Figure 3 shows test and training loss curves for GradNorm ($\alpha = 1.5$) and baselines on the larger NYUv2+kpts dataset for our VGG SegNet models. GradNorm improves test-time depth error by ∼5%, despite converging to a much higher training loss. GradNorm achieves this by aggressively rate balancing the network (enforced by a high asymmetry $\alpha = 1.5$), and ultimately suppresses the depth weight $w_{depth}(t)$ to lower than 0.10 (see Section 5.4 for more details). The same trend exists for keypoint regression, and is a clear signal of network regularization. In contrast, uncertainty weighting (Kendall et al., 2017) always moves test and training error in the same direction, and thus is not a good regularizer. Only results for the VGG SegNet are shown here, but the Thin ResNet FCN produces consistent results.

Figure 3. Test and training loss curves for all tasks in NYUv2+kpts, VGG16 backbone. GradNorm versus an equal weights baseline and uncertainty weighting (Kendall et al., 2017).

5.3. Gradient Normalization Finds Optimal Grid-Search Weights in One Pass

For our VGG SegNet, we train 100 networks from scratch with random task weights on NYUv2+kpts. Weights are sampled from a uniform distribution and renormalized to sum to $T = 3$. For computational efficiency, we only train for 15000 iterations out of the normal 80000, and then compare the performance of that network to our GradNorm

$\alpha = 1.5$ VGG SegNet network at the same 15000 steps. The results are shown in Figure 4.

Even after 100 networks trained, grid search still falls short of our GradNorm network. Even more remarkably, there is a strong, negative correlation between network performance and task weight distance to our time-averaged GradNorm weights $E_t[w_i(t)]$. At an $L_2$ distance of ∼3, grid search networks on average have almost double the errors per task compared to our GradNorm network. GradNorm has therefore found the optimal grid search weights in one single training run.

Table 2. Test error, NYUv2+kpts for GradNorm and various baselines. Lower values are better. Best performance for each task is bolded, with second-best underlined.

Model and Weighting Method   Depth RMS Err. (m)   Kpt. Err. (%)         Normals Err. (1-|cos|)
ResNet Backbone
Depth Only                   0.725                -                     -
Kpt Only                     -                    7.90                  -
Normals Only                 -                    -                     0.155
Equal Weights                0.697                7.80                  0.172
(Kendall et al., 2017)       0.702                7.96                  0.182
GradNorm Static              0.695                7.63                  0.156
GradNorm α = 1.5             0.663                7.32                  0.155
VGG Backbone
Depth Only                   0.689                -                     -
Keypoint Only                -                    8.39                  -
Normals Only                 -                    -                     0.142
Equal Weights                0.658                8.39                  0.155
(Kendall et al., 2017)       0.649                8.00                  0.158
GradNorm Static              0.638                7.69                  0.137
GradNorm α = 1.5             0.629                7.73                  0.139

Figure 4. Gridsearch performance for random task weights vs GradNorm, NYUv2+kpts. Average change in performance across three tasks for a static multitask network with weights $w_i^{static}$, plotted against the $L_2$ distance between $w_i^{static}$ and a set of static weights derived from a GradNorm network, $E_t[w_i(t)]$. A reference line at zero performance change is provided for convenience. All comparisons are made at 15000 steps of training.

5.4. Effects of tuning the asymmetry α

The only hyperparameter in our algorithm is the asymmetry $\alpha$. The optimal value of $\alpha$ for NYUv2 lies near $\alpha = 1.5$, while in the highly symmetric toy example in Section 4 we used $\alpha = 0.12$. This observation reinforces our characterization of $\alpha$ as an asymmetry parameter.

Tuning $\alpha$ leads to performance gains, but we found that for NYUv2, almost any value of $0 < \alpha < 3$ will improve network performance over an equal weights baseline (see Supplementary for details). Figure 5 shows that higher values of $\alpha$ tend to push the weights $w_i(t)$ further apart, which more aggressively reduces the influence of tasks which overfit or learn too quickly (in our case, depth). Remarkably, at $\alpha = 1.75$ (not shown) $w_{depth}(t)$ is suppressed to below 0.02 at no detriment to network performance on the depth task.

Figure 5. Weights $w_i(t)$ during training, NYUv2+kpts. Traces of how the task weights $w_i(t)$ change during training for two different values of $\alpha$. A larger value of $\alpha$ pushes weights farther apart, leading to less symmetry between tasks.

5.5. Qualitative Results

Figure 6 shows visualizations of the VGG SegNet outputs on test set images along with the ground truth, for both the NYUv2+seg and NYUv2+kpts datasets. Ground truth labels are juxtaposed with outputs from the equal weights network, 3 single networks, and our best GradNorm network. Some

improvements are incremental, but GradNorm produces superior visual results in tasks for which there are significant quantitative improvements in Tables 1 and 2.

Figure 6. Visualizations at inference time. NYUv2+kpts outputs are shown on the left, while NYUv2+seg outputs are shown on the right. Visualizations shown were generated from random test set images. Some improvements are incremental, but red frames are drawn around predictions that are visually more clearly improved by GradNorm. For NYUv2+kpts outputs GradNorm shows improvement over the equal weights network in normals prediction and over single networks in keypoint prediction. For NYUv2+seg there is an improvement over single networks in depth and segmentation accuracy. These are consistent with the numbers reported in Tables 1 and 2.

6. Conclusions

We introduced GradNorm, an efficient algorithm for tuning loss weights in a multi-task learning setting based on balancing the training rates of different tasks. We demonstrated on both synthetic and real datasets that GradNorm improves multitask test-time performance in a variety of scenarios, and can accommodate various levels of asymmetry amongst the different tasks through the hyperparameter $\alpha$. Our empirical results indicate that GradNorm offers superior performance over state-of-the-art multitask adaptive weighting methods and can match or surpass the performance of exhaustive grid search while being significantly less time-intensive.

Looking ahead, algorithms such as GradNorm may have applications beyond multitask learning. We hope to extend the GradNorm approach to work with class-balancing and sequence-to-sequence models, all situations where problems with conflicting gradient signals can degrade model performance. We thus believe that our work not only provides a robust new algorithm for multitask learning, but also reinforces the powerful idea that gradient tuning is fundamental for training large, effective models on complex tasks.
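The toy setup of Section 4 can be sketched as follows: each task $i$ regresses $f_i(\mathbf{x}) = \sigma_i \tanh((B + \epsilon_i)\mathbf{x})$ (Eq. 3) with a shared matrix $B$ and a task-specific perturbation $\epsilon_i$, all drawn from one fixed random seed, and test performance is summarized by the task-normalized loss $\sum_i L_i(t)/L_i(0)$. This is a dependency-free sketch; the helper names are hypothetical, and reading the second parameter of $\mathcal{N}(0, 10)$ and $\mathcal{N}(0, 3.5)$ as a standard deviation rather than a variance is our assumption.

```python
import math
import random
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def make_toy_tasks(
    sigmas: Sequence[float],
    in_dim: int = 250,
    out_dim: int = 100,
    b_std: float = 10.0,   # assumption: std of B's entries
    eps_std: float = 3.5,  # assumption: std of eps_i's entries
    seed: int = 0,
) -> List[Tuple[float, Matrix]]:
    """Build (sigma_i, B + eps_i) pairs for Eq. (3) from one fixed random seed."""
    rng = random.Random(seed)
    shared = [[rng.gauss(0.0, b_std) for _ in range(in_dim)] for _ in range(out_dim)]
    tasks = []
    for sigma in sigmas:
        # Task-specific perturbation eps_i added element-wise to the shared B.
        mixed = [[b + rng.gauss(0.0, eps_std) for b in row] for row in shared]
        tasks.append((sigma, mixed))
    return tasks


def toy_target(task: Tuple[float, Matrix], x: Sequence[float]) -> List[float]:
    """f_i(x) = sigma_i * tanh((B + eps_i) x), with tanh applied element-wise."""
    sigma, matrix = task
    return [sigma * math.tanh(sum(m * xj for m, xj in zip(row, x))) for row in matrix]


def task_normalized_loss(losses: Sequence[float], initial_losses: Sequence[float]) -> float:
    """Sum of loss ratios L_i(t) / L_i(0), the toy problem's test metric."""
    return sum(l / l0 for l, l0 in zip(losses, initial_losses))


if __name__ == "__main__":
    # Small dimensions keep the demo fast; Section 4 uses 250 inputs, 100 outputs.
    tasks = make_toy_tasks([1.0, 100.0], in_dim=5, out_dim=3)
    x = [0.1, -0.2, 0.3, 0.0, 0.5]
    print([max(abs(v) for v in toy_target(t, x)) for t in tasks])
    print(task_normalized_loss([0.5, 50.0], [1.0, 100.0]))
```

Because $|\tanh| \le 1$, each $\sigma_i$ bounds the scale of its task's targets, which is exactly the loss-scale asymmetry that GradNorm is meant to correct; the loss-ratio metric removes that scale before tasks are summed.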

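The random-weight baseline of Section 5.3 can be sketched as follows: draw task weights uniformly, renormalize them to sum to $T$, and measure their $L_2$ distance to the time-averaged GradNorm weights $E_t[w_i(t)]$, the quantity on the horizontal axis of Figure 4. The helper names are hypothetical, the uniform range $[0, 1)$ is our assumption (the text only says "a uniform distribution"), the weight trace in the demo is made up for illustration and is not data from the paper, and training each sampled network for 15000 steps is outside the sketch.

```python
import math
import random
from typing import List, Sequence


def sample_static_weights(num_tasks: int, rng: random.Random) -> List[float]:
    """Uniform random task weights renormalized so that they sum to num_tasks."""
    raw = [rng.random() for _ in range(num_tasks)]  # assumption: U[0, 1)
    total = sum(raw)
    return [num_tasks * w / total for w in raw]


def time_averaged_weights(trace: Sequence[Sequence[float]]) -> List[float]:
    """E_t[w_i(t)]: mean of each task's weight over the recorded training steps."""
    steps = len(trace)
    return [sum(step[i] for step in trace) / steps for i in range(len(trace[0]))]


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two weight vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical GradNorm weight trace for T = 3 tasks (illustrative only).
    gradnorm_avg = time_averaged_weights([[1.2, 0.9, 0.9], [1.4, 0.8, 0.8]])
    candidates = [sample_static_weights(3, rng) for _ in range(100)]
    distances = [l2_distance(w, gradnorm_avg) for w in candidates]
    print(gradnorm_avg, min(distances), max(distances))
```

Drawing 100 candidates mirrors the 100 randomly weighted networks of Section 5.3; each candidate would then be trained as a static-weight network and its error compared against the GradNorm run at the same step count.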
References

Badrinarayanan, V., Kendall, A., and Cipolla, R. SegNet: A deep convolutional encoder-decoder architecture for image segmentation. arXiv preprint arXiv:1511.00561, 2015.
Bakker, B. and Heskes, T. Task clustering and gating for Bayesian multitask learning. Journal of Machine Learning Research, 4(May):83–99, 2003.
Bilen, H. and Vedaldi, A. Universal representations: The missing link between faces, text, planktons, and cat breeds. arXiv preprint arXiv:1701.07275, 2017.
Caruana, R. Multitask learning. In Learning to Learn, pp. 95–133. Springer, 1998.
Collobert, R. and Weston, J. A unified architecture for natural language processing: Deep neural networks with multitask learning. In Proceedings of the 25th International Conference on Machine Learning, pp. 160–167. ACM, 2008.
Eigen, D. and Fergus, R. Predicting depth, surface normals and semantic labels with a common multi-scale convolutional architecture. In Proceedings of the IEEE International Conference on Computer Vision, pp. 2650–2658, 2015.
Graves, A., Bellemare, M. G., Menick, J., Munos, R., and Kavukcuoglu, K. Automated curriculum learning for neural networks. arXiv preprint arXiv:1704.03003, 2017.
Hashimoto, K., Xiong, C., Tsuruoka, Y., and Socher, R. A joint many-task model: Growing a neural network for multiple NLP tasks. arXiv preprint arXiv:1611.01587, 2016.
He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
He, K., Gkioxari, G., Dollár, P., and Girshick, R. Mask R-CNN. arXiv preprint arXiv:1703.06870, 2017.
Huang, W., Song, G., Hong, H., and Xie, K. Deep architecture for traffic flow prediction: Deep belief networks with multitask learning. IEEE Transactions on Intelligent Transportation Systems, 15(5):2191–2201, 2014.
Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pp. 448–456, 2015.
Jacob, L., Vert, J.-P., and Bach, F. R. Clustered multi-task learning: A convex formulation. In Advances in Neural Information Processing Systems, pp. 745–752, 2009.
Kang, Z., Grauman, K., and Sha, F. Learning with whom to share in multi-task feature learning. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 521–528, 2011.
Kendall, A., Gal, Y., and Cipolla, R. Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. arXiv preprint arXiv:1705.07115, 2017.
Kokkinos, I. UberNet: Training a universal convolutional neural network for low-, mid-, and high-level vision using diverse datasets and limited memory. arXiv preprint arXiv:1609.02132, 2016.
Lee, C.-Y., Badrinarayanan, V., Malisiewicz, T., and Rabinovich, A. RoomNet: End-to-end room layout estimation. arXiv preprint arXiv:1703.06241, 2017.
Long, J., Shelhamer, E., and Darrell, T. Fully convolutional networks for semantic segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3431–3440, 2015.
Long, M. and Wang, J. Learning multiple tasks with deep relationship networks. arXiv preprint arXiv:1506.02117, 2015.
Lu, Y., Kumar, A., Zhai, S., Cheng, Y., Javidi, T., and Feris, R. Fully-adaptive feature sharing in multi-task networks with applications in person attribute classification. arXiv preprint arXiv:1611.05377, 2016.
Misra, I., Shrivastava, A., Gupta, A., and Hebert, M. Cross-stitch networks for multi-task learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3994–4003, 2016.
Redmon, J. and Farhadi, A. YOLO9000: Better, faster, stronger. arXiv preprint arXiv:1612.08242, 2016.
Seltzer, M. L. and Droppo, J. Multi-task learning in deep neural networks for improved phoneme recognition. In Acoustics, Speech and Signal Processing (ICASSP), 2013 IEEE International Conference on, pp. 6965–6969. IEEE, 2013.
Silberman, N., Hoiem, D., Kohli, P., and Fergus, R. Indoor segmentation and support inference from RGBD images. In ECCV, 2012.
Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
Søgaard, A. and Goldberg, Y. Deep multi-task learning with low level tasks supervised at lower layers. In Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics, volume 2, pp. 231–235, 2016.

Teichmann, M., Weber, M., Zoellner, M., Cipolla, R., and Urtasun, R. MultiNet: Real-time joint semantic reasoning for autonomous driving. arXiv preprint arXiv:1612.07695, 2016.
Warde-Farley, D., Rabinovich, A., and Anguelov, D. Self-informed neural network structure learning. arXiv preprint arXiv:1412.6563, 2014.
Wu, Z., Valentini-Botinhao, C., Watts, O., and King, S. Deep neural networks employing multi-task learning and stacked bottleneck features for speech synthesis. In Acoustics, Speech and Signal Processing (ICASSP), 2015 IEEE International Conference on, pp. 4460–4464. IEEE, 2015.
Zhang, Z., Luo, P., Loy, C. C., and Tang, X. Facial landmark detection by deep multi-task learning. In European Conference on Computer Vision, pp. 94–108. Springer, 2014.

7. GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks: Supplementary Materials

7.1. Performance Gains Versus α

The α asymmetry hyperparameter, we argued, allows us to accommodate various priors on the symmetry between tasks. A low value of α results in gradient norms of similar magnitude across tasks, ensuring that each task has approximately equal impact on the training dynamics throughout training. A high value of α penalizes tasks whose losses drop too quickly, placing more weight instead on tasks whose losses are dropping more slowly.

For our NYUv2 experiments, we chose α = 1.5 as the optimal value, and in Section 5.4 we touched upon how increasing α pushes the task weights w_i(t) farther apart. It is interesting to note, however, that we achieve overall gains in performance for almost all positive values of α for which GradNorm is numerically stable³. These results are summarized in Figure 7.

Figure 7. Performance gains on NYUv2+kpts for various settings of α. For various values of α, we plot the average performance gain (defined as the mean of the percent change in the test loss compared to the equal weights baseline across all tasks) on NYUv2+kpts. We show results for both the VGG16 backbone (solid line) and the ResNet50 backbone (dotted line). We show performance gains at all values of α tested, although gains appear to peak around α = 1.5. No points with α > 2 are shown for the VGG16 backbone, as GradNorm weights are unstable past this point for this particular architectural backbone.

We see from Figure 7 that we achieve performance gains at almost all values of α. However, for NYUv2+kpts in particular, these performance gains seem to peak at α = 1.5 for both backbone architectures. Moreover, the ResNet architecture seems more robust to α than the VGG architecture, although both architectures offer a similar level of gains with the proper setting of α. Most importantly, the consistently positive performance gains across all tested values of α suggest that any kind of gradient balancing (even in suboptimal regimes) is healthy for multitask network training.

³At large positive values of α, which in the NYUv2 case corresponded to α ≥ 3, some weights were pushed too close to zero and GradNorm updates became unstable.

7.2. Additional Experiments on a Multitask Facial Landmark Dataset

We perform additional experiments on the Multitask Facial Landmark (MTFL) dataset (Zhang et al., 2014). This dataset contains approximately 13k images of faces, split into a training set of 10k and a test set of 3k. Each image is labeled with the (x, y) coordinates of five facial landmarks (left eye, right eye, nose, left lip, and right lip), along with four class labels (gender, smiling, glasses, and pose). Examples of images and labels from the dataset are given in Figure 8.

Figure 8. Examples from the Multi-Task Facial Landmark (MTFL) dataset.

The MTFL dataset provides a good opportunity to test GradNorm, as it is a rich mixture of classification and regression tasks. We perform experiments at two input resolutions: 40x40 and 160x160. For our 40x40 experiments we use the same architecture as in (Zhang et al., 2014) to ensure a fair comparison, while for our 160x160 experiments we use a deeper version of that architecture. The deeper model's layer stack is

[CONV-5-16][POOL-2][CONV-3-32]^2[POOL-2][CONV-3-64]^2[POOL-2][[CONV-3-128]^2[POOL-2]]^2[CONV-3-128]^2[FC-100][FC-18],

where CONV-X-F denotes a convolution with filter size X and F output filters, POOL-2 denotes a 2x2 pooling layer with stride 2, FC-X is a dense layer with X outputs, and the superscript ^2 marks a block repeated twice. All networks output 18 values: 10 coordinates for the five facial landmarks, and 4 pairs of 2 softmax scores, one pair for each classifier.

The results on the MTFL dataset are shown in Table 3. Keypoint error is the mean L2 distance error over all five facial landmarks, normalized to the inter-ocular distance, while failure rate is the percentage of images for which keypoint error is over 10%. For both resolutions, GradNorm outperforms other methods on all tasks (save for glasses

Table 3. Test error on the Multi-Task Facial Landmark (MTFL) dataset for GradNorm and various baselines. Lower values are
better and best performance for each task is bolded. Experiments are performed for two different input resolutions, 40x40 and 160x160. In
all cases, GradNorm shows superior performance, especially on gender and smiles classification. GradNorm also matches the performance
of (Zhang et al., 2014) on keypoint prediction at 40x40 resolution, even though the latter only tries to optimize keypoint accuracy
(sacrificing classification accuracy in the process).
Method                  Input Res.  Keypoint Err. (%)  Failure Rate (%)  Gender Err. (%)  Smiles Err. (%)  Glasses Err. (%)  Pose Err. (%)
Equal Weights           40x40       8.3                27.4              20.3             19.2             8.1               38.9
(Zhang et al., 2014)    40x40       8.2                25.0              -                -                -                 -
(Kendall et al., 2017)  40x40       8.3                27.2              20.7             18.5             8.1               38.9
GradNorm α = 0.3        40x40       8.0                25.0              17.3             16.9             8.1               38.9
Equal Weights           160x160     6.8                15.2              18.6             17.4             8.1               38.9
(Kendall et al., 2017)  160x160     7.2                18.3              38.1             18.4             8.1               38.9
GradNorm α = 0.2        160x160     6.5                14.3              14.4             15.4             8.1               38.9

and pose prediction, both of which always quickly converge to the majority classifier and refuse to train further). GradNorm also matches the performance of (Zhang et al., 2014) on keypoints, even though the latter did not try to optimize for classifier performance and only stressed keypoint accuracy. It should be noted that the keypoint prediction and failure rate improvements are likely within error bars. A 1% absolute improvement in keypoint error represents a very fine sub-pixel improvement, and thus may not represent a statistically significant gain. Ultimately, we interpret these results as showing that GradNorm significantly improves classification accuracy on gender and smiles, while at least matching all other methods on all other tasks.

We reiterate that both glasses and pose classification always converge to the majority classifier. Such tasks, which become "stuck" during training, pose a problem for GradNorm, as the GradNorm algorithm would tend to continuously increase the loss weights for these tasks. For future work, we are looking into ways to alleviate this issue by detecting pathological tasks online and removing them from the GradNorm update equation.

Despite such obstacles, GradNorm still provides superior performance on this dataset, and it is instructive to examine why. After all loss weights are initialized to w_i(0) = 1, we find that (Kendall et al., 2017) tends to increase the loss weight for keypoints relative to that of the classifier losses, while GradNorm aggressively decreases the relative keypoint loss weights. For GradNorm training runs, we often find that w_kpt(t) converges to a value ≤ 0.01. This shows that even with gradients that are two orders of magnitude smaller than under (Kendall et al., 2017) or the equal weights method, the keypoint task trains properly with no attenuation of accuracy.

To summarize, GradNorm is the only method that correctly identifies that the classification tasks in the MTFL dataset are relatively undertrained and need to be boosted. In contrast, (Kendall et al., 2017) makes the inverse decision by placing more relative focus on keypoint regression, and often performs quite poorly on classification (especially for higher resolution inputs). These experiments thus highlight GradNorm's ability to identify and benefit tasks which require more attention during training.
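A minimal NumPy sketch of one GradNorm-style weight update can illustrate the weight dynamics described above, in which the keypoint weight shrinks while the classifier weights grow. It assumes the update as we read it from the main text:

- Each task's weighted gradient norm G_i(t) = w_i(t)·||∇_W L_i(t)|| is pulled toward the target Ḡ_W(t)·r_i(t)^α.
- Here r_i(t) is the task's loss ratio L_i(t)/L_i(0) divided by the mean ratio over tasks.
- The pull uses an L1 penalty, with the target held constant.
- The weights are then renormalized to sum to the number of tasks.

The function name `gradnorm_step`, the toy gradient norms and losses, the learning rate, and the positivity floor are hypothetical choices made for this illustration. This is not the authors' implementation.

```python
import numpy as np


def gradnorm_step(weights, grad_norms, losses, initial_losses, alpha, lr):
    """Return updated task loss weights after one GradNorm-style step (sketch).

    weights:        current loss weights w_i(t), shape (T,)
    grad_norms:     L2 norms of the gradients of the *unweighted* task losses
                    L_i(t) with respect to the chosen shared weights W
    losses:         current task losses L_i(t)
    initial_losses: task losses at the first step, L_i(0)
    alpha:          asymmetry hyperparameter
    lr:             learning rate for the weight update (hypothetical value)
    """
    weights = np.asarray(weights, dtype=float)
    grad_norms = np.asarray(grad_norms, dtype=float)
    num_tasks = weights.size

    # Weighted gradient norms G_i(t) = w_i(t) * ||grad_W L_i(t)||_2 and their mean.
    g = weights * grad_norms
    g_mean = g.mean()

    # Loss ratios L_i(t) / L_i(0) act as inverse training rates; r_i(t) is each
    # task's ratio relative to the average ratio over all tasks.
    loss_ratio = np.asarray(losses, dtype=float) / np.asarray(initial_losses, dtype=float)
    r = loss_ratio / loss_ratio.mean()

    # Target norms, treated as constants (no gradient flows through them).
    target = g_mean * r ** alpha

    # L_grad = sum_i |G_i - target_i|. Because G_i is linear in w_i, its
    # derivative with respect to w_i is sign(G_i - target_i) * ||grad_W L_i||.
    grad_w = np.sign(g - target) * grad_norms
    new_weights = weights - lr * grad_w

    # Floor added only in this sketch to keep weights positive (an assumption).
    new_weights = np.maximum(new_weights, 1e-6)

    # Renormalize so the weights again sum to the number of tasks T.
    return new_weights * num_tasks / new_weights.sum()


# Toy example with made-up numbers: task 0 is a regression task with large raw
# gradients whose loss has already fallen to 20% of its initial value; tasks 1
# and 2 are classifiers that have barely trained.
new_w = gradnorm_step(
    weights=[1.0, 1.0, 1.0],
    grad_norms=[10.0, 1.0, 1.0],
    losses=[0.2, 0.9, 0.8],
    initial_losses=[1.0, 1.0, 1.0],
    alpha=0.3,
    lr=0.01,
)
print(new_w)  # regression weight drops below 1; both classifier weights rise above 1
```

The same arithmetic also shows why "stuck" tasks are problematic. A task whose loss ratio stays near 1 while the others fall keeps r_i(t) above 1, so its target norm stays high and its weight keeps being pushed upward. In a full training loop, the per-task norms ||∇_W L_i|| would come from automatic differentiation rather than being supplied by hand.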
