One-Shot Learning With Memory-Augmented Neural Networks
Adam Santoro, Google DeepMind
Sergey Bartunov, Google DeepMind; National Research University Higher School of Economics (HSE), sbos@sbos.in
Matthew Botvinick, Google DeepMind
Daan Wierstra, Google DeepMind
Timothy Lillicrap, Google DeepMind
Abstract
Despite recent breakthroughs in the applications
of deep neural networks, one setting that presents
a persistent challenge is that of one-shot learning. Traditional gradient-based networks require
a lot of data to learn, often through extensive iterative training. When new data is encountered,
the models must inefficiently relearn their parameters to adequately incorporate the new information without catastrophic interference. Architectures with augmented memory capacities, such as
Neural Turing Machines (NTMs), offer the ability to quickly encode and retrieve new information, and hence can potentially obviate the downsides of conventional models. Here, we demonstrate the ability of a memory-augmented neural network to rapidly assimilate new data, and
leverage this data to make accurate predictions
after only a few samples. We also introduce a
new method for accessing an external memory
that focuses on memory content, unlike previous
methods that additionally use memory location-based focusing mechanisms.
1. Introduction
The current success of deep learning hinges on the ability to apply gradient-based optimization to high-capacity
models. This approach has achieved impressive results on
many large-scale supervised tasks with raw sensory input,
such as image classification (He et al., 2015), speech recognition (Yu & Deng, 2012), and games (Mnih et al., 2015;
Silver et al., 2016). Notably, performance in such tasks is
typically evaluated after extensive, incremental training on
large data sets. In contrast, many problems of interest require rapid inference from small quantities of data.
It has been proposed that neural networks with memory capacities could prove quite capable of meta-learning
(Hochreiter et al., 2001). These networks shift their bias
through weight updates, but also modulate their output by
learning to rapidly cache representations in memory stores
(Hochreiter & Schmidhuber, 1997). For example, LSTMs
trained to meta-learn can quickly learn never-before-seen
quadratic functions with a low number of data samples
(Hochreiter et al., 2001).
Neural networks with a memory capacity provide a promising approach to meta-learning in deep networks. However,
the specific strategy of using the memory inherent in unstructured recurrent architectures is unlikely to extend to
settings where each new task requires significant amounts
of new information to be rapidly encoded. A scalable solution has a few necessary requirements: First, information
must be stored in memory in a representation that is both
stable (so that it can be reliably accessed when needed) and
element-wise addressable (so that relevant pieces of information can be accessed selectively). Second, the number
of parameters should not be tied to the size of the memory. These two characteristics do not arise naturally within
standard memory architectures, such as LSTMs. However, recent architectures, such as Neural Turing Machines
(NTMs) (Graves et al., 2014) and memory networks (Weston et al., 2014), meet the requisite criteria. And so, in this
paper we revisit the meta-learning problem and setup from
the perspective of a highly capable memory-augmented
neural network (MANN) (note: from here on, the term MANN will refer to the class of external-memory-equipped networks, and not to other internal memory-based architectures, such as LSTMs).
We demonstrate that MANNs are capable of meta-learning
in tasks that carry significant short- and long-term memory demands. This manifests as successful classification
of never-before-seen Omniglot classes at human-like accuracy after only a few presentations, and principled function
estimation based on a small number of samples. Additionally, we outline a memory access module that emphasizes
memory access by content, and not additionally on memory location, as in original implementations of the NTM
(Graves et al., 2014). Our approach combines the best of
two worlds: the ability to slowly learn an abstract method
for obtaining useful representations of raw data, via gradient descent, and the ability to rapidly bind never-before-seen information after a single presentation, via an external memory module. The combination supports robust meta-learning, extending the range of problems to which deep
learning can be effectively applied.
2. Meta-Learning Task Methodology

Usually, we try to choose parameters $\theta$ to minimize a learning cost $\mathcal{L}$ across some dataset $D$. However, for meta-learning, we choose parameters to reduce the expected learning cost across a distribution of datasets $p(D)$:

$$\theta^{*} = \operatorname*{argmin}_{\theta}\, \mathbb{E}_{D \sim p(D)}\big[\mathcal{L}(D; \theta)\big]. \qquad (1)$$
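As a concrete reading of equation (1), here is a minimal sketch that estimates the expected learning cost by sampling datasets; `sample_dataset` and `learning_cost` are hypothetical helpers, not names from this paper:

```python
import numpy as np

def expected_learning_cost(theta, sample_dataset, learning_cost, n_datasets=32):
    """Monte-Carlo estimate of E_{D ~ p(D)}[ L(D; theta) ]:
    draw datasets from p(D) and average the per-dataset learning cost."""
    costs = [learning_cost(theta, sample_dataset()) for _ in range(n_datasets)]
    return float(np.mean(costs))
```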
To accomplish this, proper task setup is critical (Hochreiter et al., 2001). In our setup, a task, or episode, involves the presentation of some dataset $D = \{d_t\}_{t=1}^{T} = \{(x_t, y_t)\}_{t=1}^{T}$. For classification, $y_t$ is the class label for an image $x_t$, and for regression, $y_t$ is the value of a hidden function for a vector with real-valued elements $x_t$, or simply a real-valued number $x_t$ (from here on, for consistency, $x_t$ will be used). In this setup, $y_t$ is both a target and an input presented along with $x_t$ in a temporally offset manner; that is, the network sees the input sequence $(x_1, \text{null}), (x_2, y_1), \ldots, (x_T, y_{T-1})$. And so, at time $t$ the correct label for the previous data sample ($y_{t-1}$) is provided as input along with a new query $x_t$ (see Figure 1(a)).
The network is tasked to output the appropriate label for
xt (i.e., yt ) at the given timestep. Importantly, labels are
shuffled from dataset-to-dataset. This prevents the network
from slowly learning sample-class bindings in its weights.
Instead, it must learn to hold data samples in memory until the appropriate labels are presented at the next timestep, after which sample-class information can be bound
and stored for later use (see Figure 1 (b)). Thus, for a given
episode, ideal performance involves a random guess for the
first presentation of a class (since the appropriate label can
not be inferred from previous episodes, due to label shuffling), and the use of memory to achieve perfect accuracy
thereafter. Ultimately, the system aims at modelling the predictive distribution $p(y_t \mid x_t, D_{1:t-1}; \theta)$, inducing a corresponding loss at each time step.
This task structure incorporates exploitable meta-knowledge: a model that meta-learns would learn to bind
data representations to their appropriate labels regardless
of the actual content of the data representation or label,
and would employ a general scheme to map these bound
representations to appropriate classes or function values
for prediction.
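To make the episode structure concrete, here is a minimal sketch of episode construction with shuffled labels and time-offset targets; `class_to_images`, a mapping from class name to a list of image arrays, is a hypothetical helper, not from the paper:

```python
import random

def make_episode(class_to_images, n_classes=5, samples_per_class=10):
    """Build one episode: sample classes, shuffle their labels, and emit
    the time-offset (x_t, y_{t-1}) input stream described above."""
    classes = random.sample(list(class_to_images), n_classes)
    labels = list(range(n_classes))
    random.shuffle(labels)                      # episode-specific label binding
    stream = [(x, labels[i])
              for i, c in enumerate(classes)
              for x in random.sample(class_to_images[c], samples_per_class)]
    random.shuffle(stream)                      # interleave classes over time
    xs, ys = zip(*stream)
    inputs = list(zip(xs, (None,) + ys[:-1]))   # pair x_t with previous target
    targets = list(ys)
    return inputs, targets
```

Because labels are rebound on every call, the network cannot benefit from memorizing sample-class pairs in its weights; only within-episode memory helps.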
3. Memory-Augmented Model
3.1. Neural Turing Machines
The Neural Turing Machine is a fully differentiable implementation of a MANN. It consists of a controller, such as
a feed-forward network or LSTM, which interacts with an
external memory module using a number of read and write
heads (Graves et al., 2014). Memory encoding and retrieval
in a NTM external memory module is rapid, with vector representations being placed into or taken out of memory potentially every time-step.
Figure 1. Task structure. (a) Omniglot images (or x-values for regression), $x_t$, are presented with time-offset labels (or function values), $y_{t-1}$, to prevent the network from simply mapping the class labels to the output. From episode to episode, the classes to be presented in the episode, their associated labels, and the specific samples are all shuffled. (b) A successful strategy would involve the use of an external memory to store bound sample representation-class label information, which can then be retrieved at a later point for successful classification when a sample from an already-seen class is presented. Specifically, sample data $x_t$ from a particular time step should be bound to the appropriate class label $y_t$, which is presented in the subsequent time step. Later, when a sample from this same class is seen, it should retrieve this bound information from the external memory to make a prediction. Backpropagated error signals from this prediction step will then shape the weight updates from the earlier steps in order to promote this binding strategy.
When an input $x_t$ is presented, the controller produces a key, $k_t$, which is either stored in a row of a memory matrix $M_t$ or used to retrieve a particular memory. When retrieving a memory, $M_t$ is addressed using the cosine similarity measure,

$$K(k_t, M_t(i)) = \frac{k_t \cdot M_t(i)}{\lVert k_t \rVert \, \lVert M_t(i) \rVert}, \qquad (2)$$

which is used to produce a read-weight vector, $w_t^r$, with elements computed according to a softmax:

$$w_t^r(i) = \frac{\exp\big(K(k_t, M_t(i))\big)}{\sum_j \exp\big(K(k_t, M_t(j))\big)}. \qquad (3)$$

A memory, $r_t$, is then retrieved using this weight vector:

$$r_t = \sum_i w_t^r(i)\, M_t(i). \qquad (4)$$
This memory is used by the controller as the input to a classifier, such as a softmax output layer, and as an additional
input for the next controller state.
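To make the read operation concrete, here is a minimal numpy sketch of equations (2)-(4); the function and argument names are ours, not from the paper:

```python
import numpy as np

def content_read(key, memory, eps=1e-8):
    """Content-based read (eqs. 2-4): cosine similarity between the key
    k_t and each memory row M_t(i), a softmax over similarities, then a
    weighted sum of memory rows."""
    sims = memory @ key / (np.linalg.norm(memory, axis=1)
                           * np.linalg.norm(key) + eps)   # K(k_t, M_t(i)), eq. (2)
    w_r = np.exp(sims) / np.exp(sims).sum()               # read weights, eq. (3)
    return w_r @ memory, w_r                              # r_t, eq. (4)
```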
3.2. Least Recently Used Access
In previous instantiations of the NTM (Graves et al., 2014),
memories were addressed by both content and location.
Location-based addressing was used to promote iterative
steps, akin to running along a tape, as well as long-distance
jumps across memory. This method was advantageous for
sequence-based prediction tasks. However, this type of access is not optimal for tasks that emphasize a conjunctive
coding of information independent of sequence. As such,
writing to memory in our model involves the use of a newly
designed access module called the Least Recently Used
Access (LRUA) module.
The LRUA module is a pure content-based memory writer
that writes memories to either the least used memory location or the most recently used memory location. This
module emphasizes accurate encoding of relevant (i.e., recent) information, and pure content-based retrieval. New
information is written into rarely-used locations, preserving recently encoded information, or it is written to the
last used location, which can function as an update of the
memory with newer, possibly more relevant information.
The distinction between these two options is accomplished
with an interpolation between the previous read weights
and weights scaled according to usage weights $w_t^u$. These
usage weights are updated at each time-step by decaying
the previous usage weights and adding the current read and
write weights:
$$w_t^u \leftarrow \gamma\, w_{t-1}^u + w_t^r + w_t^w. \qquad (5)$$

The least-used weights, $w_t^{lu}$, are then computed from $w_t^u$. Writing $m(v, n)$ for the $n$th smallest element of a vector $v$, elements of $w_t^{lu}$ are set according to

$$w_t^{lu}(i) = \begin{cases} 0 & \text{if } w_t^u(i) > m(w_t^u, n) \\ 1 & \text{if } w_t^u(i) \le m(w_t^u, n), \end{cases} \qquad (6)$$

where $n$ is set to equal the number of reads to memory. Write weights are obtained by interpolating, via a learnable sigmoid gate parameter $\alpha$, between the previous read weights and the previous least-used weights:

$$w_t^w \leftarrow \sigma(\alpha)\, w_{t-1}^r + (1 - \sigma(\alpha))\, w_{t-1}^{lu}. \qquad (7)$$

Prior to writing, the least used memory location is set to zero, and memory is then written using the computed write weights:

$$M_t(i) \leftarrow M_{t-1}(i) + w_t^w(i)\, k_t. \qquad (8)$$
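For concreteness, here is a minimal numpy sketch of the LRUA bookkeeping in equations (5)-(8); the function and argument names are ours, and the sketch simplifies the write path (see the comment on zeroing):

```python
import numpy as np

def lrua_usage(w_u_prev, w_r, w_w, n_reads=1, gamma=0.95):
    """Usage bookkeeping, eqs. (5)-(6): decay old usage, add this step's
    read/write weights, then flag the n least-used slots."""
    w_u = gamma * w_u_prev + w_r + w_w                  # eq. (5)
    nth_smallest = np.sort(w_u)[n_reads - 1]            # m(w_u, n)
    w_lu = (w_u <= nth_smallest).astype(float)          # eq. (6)
    return w_u, w_lu

def lrua_write(memory, key, w_r_prev, w_lu_prev, alpha):
    """Write step, eqs. (7)-(8): a learnable sigmoid gate alpha blends the
    previous read weights with the previous least-used mask. (The paper
    also zeroes the least-used slot before writing; omitted here.)"""
    gate = 1.0 / (1.0 + np.exp(-alpha))                 # sigma(alpha)
    w_w = gate * w_r_prev + (1.0 - gate) * w_lu_prev    # eq. (7)
    return memory + np.outer(w_w, key), w_w             # eq. (8)
```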
4. Experimental Results
4.1. Data
Two sources of data were used: Omniglot, for classification, and sampled functions from a Gaussian process (GP)
with fixed hyperparameters, for regression. The Omniglot
dataset consists of over 1600 separate classes with only a
few examples per class, which has aptly led to it being called the "transpose" of MNIST (Lake et al., 2015). To reduce the
risk of overfitting, we performed data augmentation by randomly translating and rotating character images. We also
created new classes through 90°, 180° and 270° rotations
of existing data. The training of all models was performed
on the data of 1200 original classes (plus augmentations),
with the rest of the 423 classes (plus augmentations) being
used for test experiments. In order to reduce the computational time of our experiments we downscaled the images
to 20 × 20.
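As an illustration of the class-augmentation scheme just described, here is a short sketch, assuming images are 2-D numpy arrays and `class_to_images` is a hypothetical class-to-samples mapping:

```python
import numpy as np

def augment_classes(class_to_images):
    """Create new classes via 90/180/270-degree rotations of existing
    classes, as described above; each rotation becomes its own class."""
    augmented = dict(class_to_images)
    for name, images in class_to_images.items():
        for k in (1, 2, 3):                               # 90, 180, 270 degrees
            augmented[f"{name}_rot{90 * k}"] = [np.rot90(im, k) for im in images]
    return augmented
```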
4.2. Omniglot Classification
We performed a number of iterations of the basic task described in Section 2. First, our MANN was trained using one-hot vector representations as class labels (Figure 2).
Figure 2. Omniglot classification. The network was given either five (a-b) or up to fifteen (c-d) random classes per episode, which were
of length 50 or 100 respectively. Labels were one-hot vectors in (a-b), and five-character strings in (c-d). In (b), first instance accuracy is
above chance, indicating that the MANN is performing educated guesses for new classes based on the classes it has already seen and
stored in memory. In (c-d), first instance accuracy is poor, as is expected, since the network must guess from among 3125 ($5^5$) random five-character strings. Second
instance accuracy, however, approaches 80% during training for the MANN (d). At the 100,000 episode mark the network was tested,
without further learning, on distinct classes withheld from the training set, and exhibited comparable performance.
INSTANCE (% CORRECT)
MODEL       | 1ST  | 2ND  | 3RD  | 4TH  | 5TH  | 10TH
HUMAN       | 34.5 | 57.3 | 70.1 | 71.8 | 81.4 | 92.4
FEEDFORWARD | 24.4 | 19.6 | 21.1 | 19.9 | 22.8 | 19.5
LSTM        | 24.4 | 49.5 | 55.3 | 61.0 | 63.6 | 62.5
MANN        | 36.4 | 82.8 | 91.0 | 92.6 | 94.9 | 98.1
INSTANCE (% CORRECT)
MODEL               | CONTROLLER  | # OF CLASSES | 1ST | 2ND  | 3RD  | 4TH  | 5TH  | 10TH
KNN (RAW PIXELS)    | –           | 5            | 4.0 | 36.7 | 41.9 | 45.7 | 48.1 | 57.0
KNN (DEEP FEATURES) | –           | 5            | 4.0 | 51.9 | 61.0 | 66.3 | 69.3 | 77.5
FEEDFORWARD         | –           | 5            | 0.0 | 0.2  | 0.0  | 0.2  | 0.0  | 0.0
LSTM                | –           | 5            | 0.0 | 9.0  | 14.2 | 16.9 | 21.8 | 25.5
MANN                | FEEDFORWARD | 5            | 0.0 | 8.0  | 16.2 | 25.2 | 30.9 | 46.8
MANN                | LSTM        | 5            | 0.0 | 69.5 | 80.4 | 87.9 | 88.4 | 93.1
KNN (RAW PIXELS)    | –           | 15           | 0.5 | 18.7 | 23.3 | 26.5 | 29.1 | 37.0
KNN (DEEP FEATURES) | –           | 15           | 0.4 | 32.7 | 41.2 | 47.1 | 50.6 | 60.0
FEEDFORWARD         | –           | 15           | 0.0 | 0.1  | 0.0  | 0.0  | 0.0  | 0.0
LSTM                | –           | 15           | 0.0 | 2.2  | 2.9  | 4.3  | 5.6  | 12.7
MANN (LRUA)         | FEEDFORWARD | 15           | 0.1 | 12.8 | 22.3 | 28.8 | 32.2 | 43.4
MANN (LRUA)         | LSTM        | 15           | 0.1 | 62.6 | 79.3 | 86.6 | 88.7 | 95.3
MANN (NTM)          | LSTM        | 15           | 0.0 | 35.4 | 61.2 | 71.7 | 77.7 | 88.4
6. Acknowledgements
The authors would like to thank Ivo Danihelka and Greg
Wayne for helpful discussions and prior work on the NTM
and LRU Access architectures, as well as Yori Zwols,
and many others at Google DeepMind for reviewing the
manuscript.
References
Braun, Daniel A, Aertsen, Ad, Wolpert, Daniel M, and Mehring, Carsten. Motor task variation induces structural learning. Current Biology, 19(4):352–357, 2009.

Brazdil, Pavel B, Soares, Carlos, and Da Costa, Joaquim Pinto. Ranking learning algorithms: Using IBL and meta-learning on accuracy and time results. Machine Learning, 50(3):251–277, 2003.

Caruana, Rich. Multitask learning. Machine Learning, 28(1):41–75, 1997.

Cowan, Nelson. The magical mystery four: How is working memory capacity limited, and why? Current Directions in Psychological Science, 19(1):51–57, 2010.

Giraud-Carrier, Christophe, Vilalta, Ricardo, and Brazdil, Pavel. Introduction to the special issue on meta-learning. Machine Learning, 54(3):187–193, 2004.

Graves, Alex, Wayne, Greg, and Danihelka, Ivo. Neural Turing machines. arXiv preprint arXiv:1410.5401, 2014.

He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, and Sun, Jian. Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. arXiv preprint arXiv:1502.01852, 2015.

Hochreiter, Sepp and Schmidhuber, Jürgen. Long short-term memory. Neural Computation, 9(8):1735–1780, 1997.

Hochreiter, Sepp, Younger, A Steven, and Conwell, Peter R. Learning to learn using gradient descent. In Artificial Neural Networks: ICANN 2001, pp. 87–94. Springer, 2001.

Jankowski, Norbert, Duch, Włodzisław, and Grąbczewski, Krzysztof. Meta-Learning in Computational Intelligence, volume 358. Springer Science & Business Media, 2011.

Lake, Brenden M, Salakhutdinov, Ruslan, and Tenenbaum, Joshua B. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.

Mnih, Volodymyr, Kavukcuoglu, Koray, Silver, David, Rusu, Andrei A, Veness, Joel, Bellemare, Marc G, Graves, Alex, Riedmiller, Martin, Fidjeland, Andreas K, Ostrovski, Georg, et al. Human-level control through deep reinforcement learning. Nature, 518(7540):529–533, 2015.

Rendell, Larry A, Sheshu, Raj, and Tcheng, David K. Layered concept-learning and dynamically variable bias management. In IJCAI, pp. 308–314. Citeseer, 1987.
Supplementary Information
Figure 7. MANN Architecture.
The controllers in our experiments are feed-forward networks or Long Short-Term Memories (LSTMs). For the
best performing networks, the controller is an LSTM with 200 hidden units. The controller receives some concatenated input $(x_t, y_{t-1})$ (see Section 7 for details) and updates its state according to:
$$(\hat{g}^f, \hat{g}^i, \hat{g}^o, \hat{u}) = W^{xh}(x_t, y_{t-1}) + W^{hh} h_{t-1} + b^h, \qquad (9)$$
$$g^f = \sigma(\hat{g}^f), \qquad (10)$$
$$g^i = \sigma(\hat{g}^i), \qquad (11)$$
$$g^o = \sigma(\hat{g}^o), \qquad (12)$$
$$u = \tanh(\hat{u}), \qquad (13)$$
$$c_t = g^f \odot c_{t-1} + g^i \odot u, \qquad (14)$$
$$h_t = g^o \odot \tanh(c_t), \qquad (15)$$
$$o_t = (h_t, r_t), \qquad (16)$$
where $g^f$, $g^o$, and $g^i$ are the forget gates, output gates, and input gates, respectively; $b^h$ are the hidden-state biases; $c_t$ is the cell state; $h_t$ is the hidden state; $r_t$ is the vector read from memory; $o_t$ is the concatenated output of the controller; $\odot$ represents element-wise multiplication; and $(\cdot, \cdot)$ represents vector concatenation. $W^{xh}$ are the weights from the input $(x_t, y_{t-1})$ to the hidden state, and $W^{hh}$ are the weights between hidden states connected through time.

The read vector $r_t$ is computed using content-based addressing, via the cosine similarity measure

$$K(k_t, M_t(i)) = \frac{k_t \cdot M_t(i)}{\lVert k_t \rVert \, \lVert M_t(i) \rVert}. \qquad (17)$$

Next, these similarity measures are used to produce a read-weight vector $w_t^r$, with elements computed according to a softmax:

$$w_t^r(i) = \frac{\exp\big(K(k_t, M_t(i))\big)}{\sum_j \exp\big(K(k_t, M_t(j))\big)}. \qquad (18)$$

A memory, $r_t$, is then retrieved using these read-weights:

$$r_t = \sum_i w_t^r(i)\, M_t(i). \qquad (19)$$

Writing to memory follows the LRUA mechanism of Section 3.2. Usage weights are decayed and accumulated,

$$w_t^u \leftarrow \gamma\, w_{t-1}^u + w_t^r + w_t^w, \qquad (20)$$

least-used weights $w_t^{lu}$ are computed from $w_t^u$ using $m(v, n)$, the $n$th smallest element of $v$,

$$w_t^{lu}(i) = \begin{cases} 0 & \text{if } w_t^u(i) > m(w_t^u, n) \\ 1 & \text{if } w_t^u(i) \le m(w_t^u, n), \end{cases} \qquad (22)$$

write weights are interpolated with a learnable sigmoid gate,

$$w_t^w \leftarrow \sigma(\alpha)\, w_{t-1}^r + (1 - \sigma(\alpha))\, w_{t-1}^{lu}, \qquad (23)$$

and memory is then written according to

$$M_t(i) \leftarrow M_{t-1}(i) + w_t^w(i)\, k_t. \qquad (24)$$
The network is trained to minimize the episode loss, summed over time steps (and, for string labels, over label characters $c$), where $p_t$ is the network's predicted class distribution at time $t$:

$$\mathcal{L}(\theta) = -\sum_t \sum_c \left(y_t^c\right)^{\!\top} \log p_t^c. \qquad (26)$$
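Putting the controller equations (9)-(16) together, here is a minimal numpy sketch of one controller step; the weight shapes and the externally supplied read vector are assumptions of this sketch, not specifications from the paper:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def controller_step(x, y_prev, r, h_prev, c_prev, Wxh, Whh, bh):
    """One LSTM controller update following eqs. (9)-(16). Assumed shapes:
    Wxh is (4H, len(x)+len(y_prev)), Whh is (4H, H), bh is (4H,). The read
    vector r is assumed computed separately via eqs. (17)-(19)."""
    inp = np.concatenate([x, y_prev])                   # (x_t, y_{t-1})
    pre = Wxh @ inp + Whh @ h_prev + bh                 # eq. (9): pre-activations
    gf, gi, go, u = np.split(pre, 4)
    gf, gi, go = sigmoid(gf), sigmoid(gi), sigmoid(go)  # eqs. (10)-(12)
    u = np.tanh(u)                                      # eq. (13)
    c = gf * c_prev + gi * u                            # eq. (14): cell state
    h = go * np.tanh(c)                                 # eq. (15): hidden state
    o = np.concatenate([h, r])                          # eq. (16): output
    return o, h, c
```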
8. Task
Either 5, 10, or 15 unique classes are chosen per episode.
Episode lengths are ten times the number of unique classes
(i.e., 50, 100, or 150 respectively), unless explicitly mentioned otherwise. Training occurs for 100 000 episodes.
At the 100 000 episode mark, the task continues; however,
data are pulled from a disjoint test set (i.e., samples from classes 1201–1623 in the Omniglot dataset), and weight updates are ceased. This is deemed the test phase.
For curriculum training, the maximum number of unique
classes per episode increments by 1 every 10 000 training
episodes. Accordingly, the episode length increases to 10
times this new maximum.
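A small sketch of the curriculum schedule described above; the starting class count and the cap are assumed values, not given in the text:

```python
def curriculum_max_classes(episode, start=5, step_every=10_000, cap=15):
    """Curriculum sketch: the maximum number of unique classes per episode
    increments by 1 every 10,000 training episodes (start/cap assumed);
    the episode length is ten times that maximum."""
    n = min(start + episode // step_every, cap)
    return n, 10 * n
```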
9. Parameters
9.0.1. Optimization
RMSprop was used with a learning rate of 1e-4 and a max learning rate of 5e-1, a decay of 0.95, and a momentum of 0.9.
9.0.2. Free parameter grid search
A grid search was performed over a number of parameters, with the values used shown in parentheses: memory slots (128), memory size (40), and controller size (200 hidden units).