2019 - End-To-End Multi-Task Learning With Attention - Liu Et Al
Abstract

We propose a novel multi-task learning architecture, which allows learning of task-specific feature-level attention. Our design, the Multi-Task Attention Network (MTAN), consists of a single shared network containing a global feature pool, together with a soft-attention module for each task. These modules allow for learning of task-specific features from the global features, whilst simultaneously allowing for features to be shared across different tasks. The architecture can be trained end-to-end and can be built upon any feed-forward neural network, is simple to implement, and is parameter efficient. We evaluate our approach on a variety of datasets, across both image-to-image predictions and image classification tasks. We show that our architecture is state-of-the-art in multi-task learning compared to existing methods, and is also less sensitive to various weighting schemes in the multi-task loss function. Code is available at https://github.com/lorenmt/mtan.

Figure 1: Overview of our proposed MTAN. The shared network takes input data and learns task-shared features, whilst each attention network learns task-specific features, by applying attention modules to the shared network.

i) Network Architecture (how to share): A multi-task learning architecture should express both task-shared and task-specific features. In this way, the network is encouraged to learn a generalisable representation (to avoid over-fitting), whilst also providing the ability to learn features tailored to each task (to avoid under-fitting).
rather than learning directly from the shared feature pool, a soft attention mask is applied at each convolution block in the shared network. In this way, each attention mask automatically determines the importance of the shared features for the respective task, allowing learning of both task-shared and task-specific features in a self-supervised, end-to-end manner. This flexibility enables much more expressive combinations of features to be learned for generalisation across tasks, whilst still allowing for discriminative features to be tailored to each individual task. Furthermore, automatically choosing which features to share and which to be task-specific allows for a highly efficient architecture with far fewer parameters than multi-task architectures which have explicit separation of tasks [26, 20].

MTAN can be built on any feed-forward neural network, depending on the type of tasks. We first evaluate MTAN with SegNet [1], an encoder-decoder network, on the tasks of semantic segmentation and depth estimation on the outdoor CityScapes dataset [4], and then with an additional task of surface normal prediction on the more challenging indoor dataset NYUv2 [21]. We also test our approach with a different backbone architecture, Wide Residual Network [31], on the recently proposed Visual Decathlon Challenge [23], to solve 10 individual image classification tasks. Results show that MTAN outperforms several baselines and is competitive with the state-of-the-art for multi-task learning, whilst being more parameter efficient and therefore scaling more gracefully with the number of tasks. Furthermore, our method shows greater robustness to the choice of weighting scheme in the loss function compared to baselines. As part of our evaluation of this robustness, we also propose a novel weighting scheme, Dynamic Weight Average (DWA), which adapts the task weighting over time by considering the rate of change of the loss for each task.

2. Related Work

The term Multi-Task Learning (MTL) has been broadly used in machine learning [2, 8, 6, 17], with similarities to transfer learning [22, 18] and continual learning [29]. In computer vision, multi-task learning has been used for learning similar tasks such as image classification in multiple domains [23], pose estimation and action recognition [9], and dense prediction of depth, surface normals, and semantic classes [20, 7]. In this paper, we consider two important aspects of multi-task learning: how can a good multi-task network architecture be designed, and how can feature sharing be balanced across all tasks?

Most multi-task learning network architectures for computer vision are designed based on existing CNN architectures. For example, Cross-Stitch Networks [20] contain one standard feed-forward network per task, with cross-stitch units to allow features to be shared across tasks. The self-supervised approach of [6], based on the ResNet-101 architecture [30], learns a regularised combination of features from different layers of a single shared network. UberNet [16] proposes an image pyramid approach to process images across multiple resolutions, where for each resolution, additional task-specific layers are formed on top of the shared VGG-Net [27]. Progressive Networks [26] use a sequence of incrementally-trained networks to transfer knowledge between tasks. However, architectures such as Cross-Stitch Networks and Progressive Networks require a large number of network parameters, and scale linearly with the number of tasks. In contrast, our model requires only a roughly 10% increase in parameters per learning task.

On the balancing of feature sharing in multi-task learning, there is extensive experimental analysis in [20, 14], with both papers arguing that different amounts of sharing and weighting tend to work best for different tasks. One example of weighting tasks appropriately is the use of weight uncertainty [14], which modifies the loss functions in multi-task learning using task uncertainty. Another method is GradNorm [3], which manipulates gradient norms over time to control the training dynamics. As an alternative to using task losses to determine task difficulties, Dynamic Task Prioritisation [10] encourages prioritisation of difficult tasks directly using performance metrics such as accuracy and precision.

3. Multi-Task Attention Network

We now introduce our novel multi-task learning architecture, the Multi-Task Attention Network (MTAN). Whilst the architecture can be incorporated into any feed-forward network, in the following section we demonstrate how to build MTAN upon an encoder-decoder network, SegNet [1]. This example configuration allows for image-to-image dense pixel-level prediction, such as semantic segmentation, depth estimation, and surface normal prediction.

3.1. Architecture Design

MTAN consists of two components: a single shared network, and K task-specific attention networks. The shared network can be designed based on the particular task, whilst each task-specific network consists of a set of attention modules, which link with the shared network. Each attention module applies a soft attention mask to a particular layer of the shared network, to learn task-specific features. As such, the attention masks can be considered as feature selectors from the shared network, which are automatically learned in an end-to-end manner, whilst the shared network learns a compact global feature pool across all tasks.

Figure 2 shows a detailed visualisation of our network based on VGG-16 [27], illustrating the encoder half of SegNet. The decoder half of SegNet is then symmetric to VGG-16. As shown, each attention module learns a soft attention mask, which itself is dependent on the features in the shared
network at the corresponding layer. Therefore, the features in the shared network, and the soft attention masks, can be learned jointly to maximise the generalisation of the shared features across multiple tasks, whilst simultaneously maximising the task-specific performance due to the attention masks.

[Figure 2 diagram: shared VGG-16 encoder (conv/pool blocks) with per-task attention modules for Task 1 and Task 2; bottom panels show the attention module structure for the encoder and for the decoder.]

Figure 2: Visualisation of MTAN based on VGG-16, showing the encoder half of SegNet (with the decoder half being symmetrical to the encoder). Task one (green) and task two (blue) have their own set of attention modules, which link with the shared network (grey). The middle attention module has its structure exposed for visualisation, which is further expanded in the bottom section of the figure, showing both the encoder and decoder versions of the module. All attention modules have the same design, although their weights are individually learned.

3.2. Task Specific Attention Module

The attention module is designed to allow the task-specific network to learn task-related features, by applying a soft attention mask to the features in the shared network, with one attention mask per task per feature channel. We denote the shared features in the $j$-th block of the shared network as $p^{(j)}$, and the learned attention mask in this layer for task $i$ as $a_i^{(j)}$. The task-specific features $\hat{a}_i^{(j)}$ in this layer are then computed by element-wise multiplication of the attention masks with the shared features:

$$\hat{a}_i^{(j)} = a_i^{(j)} \odot p^{(j)}, \qquad (1)$$

where $\odot$ denotes element-wise multiplication.

As shown in Figure 2, the first attention module in the encoder takes as input only features in the shared network. But for subsequent attention modules in block $j$, the input is formed by a concatenation of the shared features $u^{(j)}$ and the task-specific features from the previous layer, $\hat{a}_i^{(j-1)}$:

$$a_i^{(j)} = h_i^{(j)}\!\left(g_i^{(j)}\!\left(\left[u^{(j)};\, f^{(j)}\!\left(\hat{a}_i^{(j-1)}\right)\right]\right)\right), \quad j \ge 2. \qquad (2)$$
Here, $f^{(j)}$, $g_i^{(j)}$, and $h_i^{(j)}$ are convolutional layers with batch normalisation, each followed by a non-linear activation. Both $g_i^{(j)}$ and $h_i^{(j)}$ are composed of $[1 \times 1]$ kernels, representing the $i$-th task-specific attention mask in block $j$. $f^{(j)}$ is composed of $[3 \times 3]$ kernels, representing a shared feature extractor for passing features to the next attention module, followed by a pooling or sampling layer to match the corresponding resolution. The attention mask, following a sigmoid activation to ensure $a_i^{(j)} \in [0, 1]$, is learned in a self-supervised fashion with back-propagation. If $a_i^{(j)} \to 1$ such that the mask becomes an identity map, the attended feature maps are equivalent to the global feature maps and the tasks share all the features. Therefore, we expect the performance to be no worse than that of a shared multi-task network which splits into individual tasks only at the end of the network, and we show results demonstrating this in Section 4.
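To make the structure above concrete, the following is a minimal PyTorch-style sketch of a single encoder-side attention module implementing Eqs. (1)-(2). The class and argument names (AttentionModule, shared_ch, prev_ch) are illustrative rather than the authors' exact implementation, and the channel sizes are assumed to match the corresponding SegNet block.

```python
import torch
import torch.nn as nn

def conv_bn_relu(in_ch, out_ch, k):
    # A [k x k] convolution with batch normalisation and a non-linear activation.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=k // 2),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True))

class AttentionModule(nn.Module):
    """Sketch of one task-specific attention module (encoder version), Eqs. (1)-(2).

    g, h are [1x1] conv blocks producing the mask a_i^(j) (sigmoid keeps it in [0, 1]);
    f is a [3x3] conv block plus pooling, applied to the previous module's output.
    """
    def __init__(self, shared_ch, prev_ch=0):
        super().__init__()
        self.first = (prev_ch == 0)
        if not self.first:
            # f^(j): shared feature extractor + pooling to match block j's resolution.
            self.f = nn.Sequential(conv_bn_relu(prev_ch, shared_ch, 3), nn.MaxPool2d(2))
        self.g = conv_bn_relu(shared_ch if self.first else 2 * shared_ch, shared_ch, 1)
        self.h = nn.Sequential(
            nn.Conv2d(shared_ch, shared_ch, kernel_size=1),
            nn.BatchNorm2d(shared_ch),
            nn.Sigmoid())

    def forward(self, u_j, p_j, a_hat_prev=None):
        # Input: shared features u^(j), optionally concatenated with f^(j)(a_hat^(j-1)).
        x = u_j if self.first else torch.cat([u_j, self.f(a_hat_prev)], dim=1)
        a = self.h(self.g(x))   # attention mask a_i^(j), Eq. (2)
        return a * p_j          # task-specific features, Eq. (1)
```

In a full model, one such module would be attached per task at each shared block, with each module's output fed through the next block's f into the following attention module.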
3.3. The Model Objective

For general multi-task learning with $K$ tasks, input $X$, and task-specific labels $Y_i$, $i = 1, 2, \cdots, K$, the loss function is defined as

$$\mathcal{L}_{tot}(X, Y_{1:K}) = \sum_{i=1}^{K} \lambda_i \mathcal{L}_i(X, Y_i). \qquad (3)$$

This is a linear combination of task-specific losses $\mathcal{L}_i$ with task weightings $\lambda_i$. In our experiments, we study the effect of different weighting schemes on various multi-task learning approaches.

For image-to-image prediction tasks, we consider each mapping from input data $X$ to a set of labels $Y_i$ as one task, with a total of three tasks for evaluation. In each loss function, $\hat{Y}$ represents the network's prediction, and $Y$ represents the ground-truth label.

• For semantic segmentation, we apply a pixel-wise cross-entropy loss for each predicted class label from a depth-softmax classifier:

$$\mathcal{L}_1(X, Y_1) = -\frac{1}{pq} \sum_{p,q} Y_1(p, q) \log \hat{Y}_1(p, q). \qquad (4)$$

• For depth estimation, we apply an L1 norm comparing the predicted and ground-truth depth. We use true depth for the NYUv2 indoor scene dataset, and inverse depth for the CityScapes outdoor scene dataset as standard, since inverse depth can more easily represent points at infinite distances, such as the sky:

$$\mathcal{L}_2(X, Y_2) = \frac{1}{pq} \sum_{p,q} \left| Y_2(p, q) - \hat{Y}_2(p, q) \right|. \qquad (5)$$

• For surface normals (only available in NYUv2), we apply an element-wise dot product at each normalised pixel with the ground-truth map:

$$\mathcal{L}_3(X, Y_3) = -\frac{1}{pq} \sum_{p,q} Y_3(p, q) \cdot \hat{Y}_3(p, q). \qquad (6)$$

For image classification tasks, we consider each dataset as one task, where each dataset represents an individual classification task for one domain. We apply a standard cross-entropy loss for all classification tasks.
or 19 classes (excluding the void group in 7 and 19 classes).
Labels for the 19 classes and the coarser 7 categories are
• For depth estimation, we apply an L1 norm comparing
defined as in the original CityScapes dataset. We then fur-
the predicted and ground-truth depth. We use true depth
ther create a 2-class dataset with only background and fore-
for the NYUv2 indoor scene dataset, and inverse depth in
ground objects. The details of these segmentation classes
CityScapes outdoor scene dataset as standard, which can
are presented in Table 1. We perform multi-task learning
more easily represent points at infinite distances, such as
for 7-class CityScapes dataset in Section 4.1.4. We compare
the sky:
the 2/7/19-class results in Section 4.1.5, with visualisation
1 X of these attention maps in Section 4.1.6.
L2 (X, Y2 ) = |Y2 (p, q) − Ŷ2 (p, q)|. (5) NYUv2. The NYUv2 dataset [21] is consisted with
pq p,q
RGB-D indoor scene images. We evaluate performances
on three learning tasks: 13-class semantic segmentation as defined in [5], true depth data recorded by depth cameras from Microsoft Kinect, and surface normals as provided in [7]. To speed up training, all training and validation images were resized to [288 × 384] resolution.

Compared to CityScapes, NYUv2 contains images of indoor scenes, which are much more complex, since the viewpoints can vary significantly, changeable lighting conditions are present, and the appearance of each object class shifts widely in texture and shape. We evaluate performance on different datasets, together with different numbers of tasks, and further with different class complexities, in order to attain a comprehensive understanding of how our proposed method behaves and scales under a range of scenarios.

2-class      7-class        19-class
void         void           void
background   flat           road, sidewalk
background   construction   building, wall, fence
background   object         pole, traffic light, traffic sign
background   nature         vegetation, terrain
background   sky            sky
foreground   human          person, rider
foreground   vehicle        car, truck, bus, caravan, trailer, train, motorcycle

Table 1: Three levels of semantic classes for the CityScapes data used in our experiments.

4.1.2 Baselines

Most image-to-image multi-task learning architectures are designed based on specific feed-forward neural networks, or implemented on varying network architectures, and thus they are typically not directly comparable based on published results. Our method is general and can be applied to any feed-forward neural network, and so for a fair comparison, we implemented 5 different network architectures (2 single-task + 3 multi-task) based on SegNet [1], which we consider as baselines:

• Single-Task, One Task: The vanilla SegNet for single-task learning.
• Single-Task, STAN: A Single-Task Attention Network, where we directly apply our proposed MTAN whilst only performing a single task.
• Multi-Task, Split (Wide, Deep): The standard multi-task learning approach, which splits at the last layer into the final prediction for each specific task. We introduce two versions of Split: Wide, where we adjusted the number of convolutional filters, and Deep, where we adjusted the number of convolutional layers, until Split had at least as many parameters as MTAN.
• Multi-Task, Dense: A shared network together with task-specific networks, where each task-specific network receives all features from the shared network, without any attention modules.
• Multi-Task, Cross-Stitch: The Cross-Stitch Network [20], a previously proposed adaptive multi-task learning approach, which we implemented on SegNet.

Note that all the baselines were designed to have at least as many parameters as our proposed MTAN, and were tested to validate that our proposed method's better performance is due to the attention modules, rather than simply due to the increase in network parameters.

4.1.3 Dynamic Weight Average

For most multi-task learning networks, training multiple tasks is difficult without finding the correct balance between those tasks, and recent approaches have attempted to address this issue [3, 14]. To test our method across a range of weighting schemes, we propose a simple yet effective adaptive weighting method, named Dynamic Weight Average (DWA). Inspired by GradNorm [3], it learns to average task weighting over time by considering the rate of change of the loss for each task. But whilst GradNorm requires access to the network's internal gradients, our DWA proposal only requires the numerical task loss, and therefore its implementation is far simpler.

With DWA, we define the weighting $\lambda_k$ for task $k$ as:

$$\lambda_k(t) := \frac{K \exp\left(w_k(t-1)/T\right)}{\sum_i \exp\left(w_i(t-1)/T\right)}, \qquad w_k(t-1) = \frac{\mathcal{L}_k(t-1)}{\mathcal{L}_k(t-2)}. \qquad (7)$$

Here, $w_k(\cdot)$ calculates the relative descending rate in the range $(0, +\infty)$, $t$ is an iteration index, and $T$ represents a temperature which controls the softness of the task weighting, similar to [12]. A large $T$ results in a more even distribution between different tasks; if $T$ is large enough, we have $\lambda_i \approx 1$, and tasks are weighted equally. Finally, the softmax operator, which is multiplied by $K$, ensures that $\sum_i \lambda_i(t) = K$.

In our implementation, the loss value $\mathcal{L}_k(t)$ is calculated as the average loss over several iterations in each epoch. Doing so reduces the uncertainty arising from stochastic gradient descent and random training data selection. For $t = 1, 2$, we initialise $w_k(t) = 1$, but any non-balanced initialisation based on prior knowledge could also be introduced.
4.1.4 Results on Image-to-Image Predictions

We now evaluate the performance of our proposed MTAN method in image-to-image multi-task learning, based on the SegNet architecture. Using the 7-class version of the
CityScapes dataset and the 13-class version of the NYUv2 dataset, we compare all the baselines introduced in Section 4.1.2.

Training. For each network architecture, we ran experiments with three types of weighting methods: equal weighting, weight uncertainty [14], and our proposed DWA (with hyper-parameter temperature T = 2, found empirically to be optimal across all architectures). We did not include GradNorm [3] because it requires a manual choice of a subset of network weights across all baselines, based on their specific architectures, which distracts from a fair evaluation of the architectures themselves. We trained all the models with the ADAM optimiser [15] using a learning rate of $10^{-4}$, with a batch size of 2 for the NYUv2 dataset and 8 for the CityScapes dataset. During training, we halve the learning rate at 40k iterations, for a total of 80k iterations.
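The optimisation settings above could be expressed as follows in PyTorch; this is a sketch under assumed names (model, train_loader, multi_task_loss are hypothetical and defined elsewhere), not the authors' exact training script.

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate at 40k iterations, for a total of 80k iterations.
scheduler = MultiStepLR(optimizer, milestones=[40_000], gamma=0.5)

step = 0
while step < 80_000:
    for batch in train_loader:              # batch size 2 (NYUv2) or 8 (CityScapes)
        loss = multi_task_loss(model, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()                    # per-iteration learning rate schedule
        step += 1
        if step >= 80_000:
            break
```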
Results. Tables 2 and 3 show the performance on the CityScapes and NYUv2 datasets across all architectures, and across all loss function weighting schemes. Results also show the number of network parameters for each architecture. Our MTAN method performs similarly to our baseline Dense on the CityScapes dataset, whilst only having less than half the number of parameters, and outperforms all other baselines. For the more challenging NYUv2 dataset, our method outperforms all baselines across all weighting methods and all learning tasks.

#P.   Architecture       Weighting             Segmentation (Higher Better)   Depth (Lower Better)
                                                mIoU     Pix Acc               Abs Err   Rel Err
2     One Task           n.a.                   51.09    90.69                 0.0158    34.17
3.04  STAN               n.a.                   51.90    90.87                 0.0145    27.46
1.75  Split, Wide        Equal Weights          50.17    90.63                 0.0167    44.73
                         Uncert. Weights [14]   51.21    90.72                 0.0158    44.01
                         DWA, T = 2             50.39    90.45                 0.0164    43.93
2     Split, Deep        Equal Weights          49.85    88.69                 0.0180    43.86
                         Uncert. Weights [14]   48.12    88.68                 0.0169    39.73
                         DWA, T = 2             49.67    88.81                 0.0182    46.63
3.63  Dense              Equal Weights          51.91    90.89                 0.0138    27.21
                         Uncert. Weights [14]   51.89    91.22                 0.0134    25.36
                         DWA, T = 2             51.78    90.88                 0.0137    26.67
≈2    Cross-Stitch [20]  Equal Weights          50.08    90.33                 0.0154    34.49
                         Uncert. Weights [14]   50.31    90.43                 0.0152    31.36
                         DWA, T = 2             50.33    90.55                 0.0153    33.37
1.65  MTAN (Ours)        Equal Weights          53.04    91.11                 0.0144    33.63
                         Uncert. Weights [14]   53.86    91.10                 0.0144    35.72
                         DWA, T = 2             53.29    91.09                 0.0144    34.14

Table 2: 7-class semantic segmentation and depth estimation results on the CityScapes validation dataset. #P shows the number of network parameters, and the best performing combination of multi-task architecture and weighting is highlighted in bold. The top validation scores for each task are annotated with boxes.

In particular, our method has two key advantages. First, due to the efficiency of having a single shared feature pool, with attention masks automatically learning which features to share, our method outperforms other methods without requiring extra parameters (column #P), and even with significantly fewer parameters in some cases.

Second, our method maintains high performance across different loss function weighting schemes, and is more robust to the choice of weighting scheme than other methods, avoiding the need for cumbersome tweaking of loss weights. We illustrate the robustness of our method to the weighting schemes with a comparison to the Cross-Stitch Network [20], by plotting learning curves in Figure 3 with respect to the performance of the three learning tasks on the NYUv2 dataset. We can clearly see that our network follows similar learning trends across the various weighting schemes, compared to the Cross-Stitch Network, which produces notably different behaviour across the different schemes.

[Figure 3 plots: validation curves over 200 epochs for Semantic (Pix. Acc.), Depth (Abs. Err.), and Normal (1 + Cos.), under Equal Weights, DWA, and Weight Uncertainty, for the Cross-Stitch Network and our Multi-Task Attention Network.]

Figure 3: Validation performance curves on the NYUv2 dataset, across all three tasks (semantics, depth, normals, from left to right), showing robustness to loss function weighting schemes on the Cross-Stitch Network [20] (top) and our Multi-task Attention Network (bottom).

Figure 4 then shows qualitative results on the CityScapes validation dataset. We can see the advantage of our multi-task learning approach over vanilla single-task learning, where the edges of objects are clearly more pronounced.

4.1.5 Effect of Task Complexity

For further introspection into the benefits of multi-task learning, we evaluated our implementations on CityScapes across different numbers of semantic classes, with the depth labels kept the same across all experiments. We trained the networks with the same settings as in Section 4.1.4, with an additional multi-task baseline Split (the standard version), which we found to perform better than the other modified versions. All networks are trained with equal weighting.

Table 4 (left) shows the validation performance improvement across all multi-task implementations and the single-task STAN implementation, plotted relative to the performance of vanilla single-task learning on the CityScapes dataset. Interestingly, for only a 2-class setup, the single-task attention network (STAN) performs better than all
Table 3: 13-class semantic segmentation, depth estimation, and surface normal prediction results on the NYUv2 validation
dataset. #P shows the number of network parameters, and the best performing combination of multi-task architecture and
weighting is highlighted in bold. The top validation scores for each task are annotated with boxes.
[Figure 4 rows: Input Image; Ground Truth (Semantic); Vanilla Single-Task Learning; Multi-Task Attention Network; Ground Truth (Depth); Vanilla Single-Task Learning; Multi-Task Attention Network.]
Figure 4: CityScapes validation results on 7-class semantic labelling and depth estimation, trained with equal weighting. The original images are cropped to avoid invalid points for better visualisation. The red boxes highlight regions of interest, comparing the results of our method with those of the single-task method.
multi-task methods since it is able to fully utilise network parameters in a simple manner for the simple task. However, for greater task complexity, the multi-task methods encourage the sharing of features for a more efficient use of available network parameters, which then leads to better results. We also observe that, whilst the relative performance gain increases for all implementations as the task complexity increases, our MTAN method increases at a greater rate.
[Table 4, left panel: bar plot of performance gain (mIoU) over the vanilla single-task method for the 2-class, 7-class, and 19-class setups, comparing Single-Task STAN, Multi-Task Split, Multi-Task Dense, Multi-Task Cross-Stitch, and Multi-Task MTAN (ours).]

Method            #P.   ImNet. Airc.  C100   DPed   DTD    GTSR   Flwr   Oglt   SVHN   UCF    Mean   Score
Scratch [23]      10    59.87  57.10  75.73  91.20  37.77  96.55  56.30  88.74  96.63  43.27  70.32  1625
Finetune [23]     10    59.87  60.34  82.12  92.82  55.53  97.53  81.41  87.69  96.55  51.20  76.51  2500
Feature [23]      1     59.67  23.31  63.11  80.33  45.37  68.16  73.69  58.79  43.54  26.80  54.28  544
Res. Adapt. [23]  2     59.67  56.68  81.20  93.88  50.85  97.05  66.24  89.62  96.13  47.45  73.88  2118
DAN [25]          2.17  57.74  64.12  80.07  91.30  56.54  98.46  86.05  89.67  96.77  49.38  77.01  2851
Piggyback [19]    1.28  57.69  65.29  79.87  96.99  57.45  97.27  79.09  87.63  97.24  47.48  76.60  2838
Parallel SVD [24] 1.5   60.32  66.04  81.86  94.23  57.82  99.24  85.74  89.25  96.62  52.50  78.36  3398
MTAN (Ours)       1.74  63.90  61.81  81.59  91.63  56.44  98.80  81.04  89.83  96.88  50.63  77.25  2941
Table 4: Left: CityScapes performance gain in percentage for all implementations, compared with the vanilla single-task method. Right: Top-1 classification accuracy on the Visual Decathlon Challenge online test set. #P is the number of parameters as a factor of a single-task implementation. The upper part of the table presents results from single-task learning baselines; the lower part presents results from multi-task learning baselines.
4.1.6 Attention Masks as Feature Selectors

To understand the role of the proposed attention modules, in Figure 5 we visualise the first-layer attention masks learned by our network on the CityScapes dataset. We can see a clear difference in attention masks between the two tasks, with each mask working as a feature selector to mask out uninformative parts of the shared features, and to focus on parts which are useful for each task. Notably, the depth masks have a much higher contrast than the semantic masks, suggesting that whilst all shared features are generally useful for the semantic task, the depth task benefits more from extraction of task-specific features.

[Figure 5 panels, shown for two examples: Input Image, Semantic Mask, Semantic Features; Shared Features, Depth Mask, Depth Features.]

Figure 5: Visualisation of the first-layer 7-class semantic and depth attention features of our proposed network. The colours for each image are rescaled to fit the data.

4.2. Visual Decathlon Challenge (Many-to-Many)

Finally, we evaluate our approach on the recently introduced Visual Decathlon Challenge, consisting of 10 individual image classification tasks (many-to-many predictions). Evaluation on this challenge reports per-task accuracies, and assigns a cumulative score with a maximum value of 10,000 (1,000 per task) based on these accuracies. The complete details about the challenge settings, evaluation, and datasets used can be found at http://www.robots.ox.ac.uk/~vgg/decathlon/.

Table 4 (right) shows results for the online test set of the challenge. Consistent with prior works, we apply MTAN built on a Wide Residual Network [31] with a depth of 28, a widening factor of 4, and a stride of 2 in the first convolutional layer of each block. We train our model using a batch size of 100, a learning rate of 0.1 with SGD, and a weight decay of $5 \cdot 10^{-5}$ for all 10 classification tasks. We halve the learning rate every 50 epochs, for a total of 300 epochs. Then, we fine-tune the 9 classification tasks (all except ImageNet) with a learning rate of 0.01 until convergence. The results show that our approach surpasses most of the baselines and is competitive with the current state-of-the-art, without the need for complicated regularisation strategies such as applying Dropout [28], regrouping datasets by size, or adaptive weight decay for each dataset, as required by other approaches.
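For reference, the optimisation settings described above might be configured as in the following sketch; the backbone here is only a placeholder standing in for the Wide Residual Network (depth 28, widening factor 4), and all names are illustrative rather than the authors' released code.

```python
import torch
import torch.nn as nn

# Placeholder backbone standing in for WRN-28-4 with stride 2 in the first conv of each block.
model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), nn.ReLU())

# Joint training over all 10 classification tasks: SGD, lr 0.1, weight decay 5e-5,
# batch size 100, halving the learning rate every 50 epochs for 300 epochs in total.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# Fine-tuning stage for the 9 non-ImageNet tasks, with a lower learning rate.
finetune_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-5)
```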
5. Conclusions

In this work, we have presented a new method for multi-task learning, the Multi-Task Attention Network (MTAN). The network architecture consists of a global feature pool, together with task-specific attention modules for each task, which allows for automatic learning of both task-shared and task-specific features in an end-to-end manner. Experiments on the NYUv2 and CityScapes datasets with multiple dense-prediction tasks, and on the Visual Decathlon Challenge with multiple image classification tasks, show that our method outperforms or is competitive with other methods, whilst also showing robustness to the particular task weighting schemes used in the loss function. Due to our method's ability to share weights through attention masks, our method achieves this state-of-the-art performance whilst also being highly parameter efficient.
References

[1] Vijay Badrinarayanan, Alex Kendall, and Roberto Cipolla. SegNet: A deep convolutional encoder-decoder architecture for image segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 39(12):2481–2495, 2017.
[2] Rich Caruana. Multitask learning. In Learning to Learn, pages 95–133. Springer, 1998.
[3] Zhao Chen, Vijay Badrinarayanan, Chen-Yu Lee, and Andrew Rabinovich. GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks. In International Conference on Machine Learning, pages 793–802, 2018.
[4] Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, and Bernt Schiele. The Cityscapes dataset for semantic urban scene understanding. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3213–3223, 2016.
[5] Camille Couprie, Clément Farabet, Laurent Najman, and Yann LeCun. Indoor semantic segmentation using depth information. In International Conference on Learning Representations (ICLR), 2013.
[6] Carl Doersch and Andrew Zisserman. Multi-task self-supervised visual learning. In The IEEE International Conference on Computer Vision (ICCV), October 2017.
[7] David Eigen and Rob Fergus. Predicting depth, surface normals and semantic labels with a common multi-scale convolutional architecture. In Proceedings of the IEEE International Conference on Computer Vision, pages 2650–2658, 2015.
[8] Theodoros Evgeniou and Massimiliano Pontil. Regularized multi-task learning. In Proceedings of the Tenth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 109–117. ACM, 2004.
[9] Georgia Gkioxari, Bharath Hariharan, Ross Girshick, and Jitendra Malik. R-CNNs for pose estimation and action detection. arXiv preprint arXiv:1406.5212, 2014.
[10] Michelle Guo, Albert Haque, De-An Huang, Serena Yeung, and Li Fei-Fei. Dynamic task prioritization for multi-task learning. In European Conference on Computer Vision, pages 282–299. Springer, 2018.
[11] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[12] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
[13] Justin Johnson, Alexandre Alahi, and Li Fei-Fei. Perceptual losses for real-time style transfer and super-resolution. In European Conference on Computer Vision, pages 694–711. Springer, 2016.
[14] Alex Kendall, Yarin Gal, and Roberto Cipolla. Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 7482–7491, 2018.
[15] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
[16] Iasonas Kokkinos. UberNet: Training a universal convolutional neural network for low-, mid-, and high-level vision using diverse datasets and limited memory. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), July 2017.
[17] Abhishek Kumar and Hal Daumé III. Learning task grouping and overlap in multi-task learning. In Proceedings of the 29th International Conference on Machine Learning, pages 1723–1730. Omnipress, 2012.
[18] Mingsheng Long, Jianmin Wang, Guiguang Ding, Jiaguang Sun, and Philip S. Yu. Transfer feature learning with joint distribution adaptation. In Proceedings of the IEEE International Conference on Computer Vision, pages 2200–2207, 2013.
[19] Arun Mallya, Dillon Davis, and Svetlana Lazebnik. Piggyback: Adapting a single network to multiple tasks by learning to mask weights. In Proceedings of the European Conference on Computer Vision (ECCV), pages 67–82, 2018.
[20] Ishan Misra, Abhinav Shrivastava, Abhinav Gupta, and Martial Hebert. Cross-stitch networks for multi-task learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3994–4003, 2016.
[21] Nathan Silberman, Derek Hoiem, Pushmeet Kohli, and Rob Fergus. Indoor segmentation and support inference from RGBD images. In ECCV, 2012.
[22] Sinno Jialin Pan and Qiang Yang. A survey on transfer learning. IEEE Transactions on Knowledge and Data Engineering, 22(10):1345–1359, 2010.
[23] Sylvestre-Alvise Rebuffi, Hakan Bilen, and Andrea Vedaldi. Learning multiple visual domains with residual adapters. In Advances in Neural Information Processing Systems, pages 506–516, 2017.
[24] Sylvestre-Alvise Rebuffi, Hakan Bilen, and Andrea Vedaldi. Efficient parametrization of multi-domain deep neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 8119–8127, 2018.
[25] Amir Rosenfeld and John K. Tsotsos. Incremental learning through deep adaptation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018.
[26] Andrei A. Rusu, Neil C. Rabinowitz, Guillaume Desjardins, Hubert Soyer, James Kirkpatrick, Koray Kavukcuoglu, Razvan Pascanu, and Raia Hadsell. Progressive neural networks. arXiv preprint arXiv:1606.04671, 2016.
[27] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
[28] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
[29] Sebastian Thrun and Lorien Pratt. Learning to Learn. Springer Science & Business Media, 2012.
[30] Fei Wang, Mengqing Jiang, Chen Qian, Shuo Yang, Cheng Li, Honggang Zhang, Xiaogang Wang, and Xiaoou Tang. Residual attention network for image classification. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3156–3164, 2017.
[31] Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In Edwin R. Hancock, Richard C. Wilson, and William A. P. Smith, editors, Proceedings of the British Machine Vision Conference (BMVC), pages 87.1–87.12. BMVA Press, September 2016.