This page is part of a multi-part series on Model-Agnostic Meta-Learning.
If you are already familiar with the topic, use the menu on the right
side to jump straight to the part that interests you. Otherwise,
we suggest you start at the beginning.
How MAML works
Model-agnostic meta-learning (MAML) is a meta-learning approach for solving a variety of tasks, ranging from simple regression to reinforcement learning, with a particular focus on few-shot learning.
To learn more about it, let us build an example from the ground up and then apply MAML to it.
We will do this by alternating mathematical walk-throughs with interactive figures and coding examples.
If you have applied machine learning before, you have probably already solved or attempted to solve a problem like the following: training a model on one specific task, for example, classifying cats versus dogs or teaching an agent to find its way through a maze. In these settings, if we are able to define a loss $\mathcal{L}_{\mathcal{T}}$ for our task $\mathcal{T}$, which depends on the parameters $\theta$ of a model, we can express our learning objective as

$$\min_{\theta} \; \mathcal{L}_{\mathcal{T}}(\theta).$$

We usually find the optimal $\theta$ by progressively walking in the opposite direction of the gradient of $\mathcal{L}_{\mathcal{T}}$ with respect to $\theta$, i.e.,

$$\theta \leftarrow \theta - \alpha \nabla_{\theta}\, \mathcal{L}_{\mathcal{T}}(\theta),$$

also known as gradient descent. $\mathcal{L}_{\mathcal{T}}$ usually also depends on some data, and $\alpha$ is the learning rate, controlling the size of the steps we want to take.
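As a quick refresher, one such gradient descent step might look as follows in TensorFlow; the model, loss, and function name here are illustrative assumptions, not code from this article's experiments.

```python
import tensorflow as tf

# A minimal sketch of one gradient descent step on a single task,
# assuming a Keras regression model; all names are illustrative.
def gradient_descent_step(model, x, y, alpha=0.01):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))  # L_T(theta)
    grads = tape.gradient(loss, model.trainable_variables)
    for var, grad in zip(model.trainable_variables, grads):
        var.assign_sub(alpha * grad)  # theta <- theta - alpha * grad
    return loss
```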
Unfortunately, when applied to regression or a few-shot task (i.e., a task with a very small dataset), the above method is known to perform poorly for models such as neural networks, since there is simply too little data for too many parameters, leading to overfitting.
The key idea of MAML is to mitigate this problem by learning not only from the data of our task at hand but also from data of similar tasks. To incorporate this, we make an additional assumption, namely that $\mathcal{T}$ comes from some distribution of tasks $p(\mathcal{T})$ and that we can sample freely from this distribution. Eventually, we want to use the data available from the other tasks in the distribution to be able to converge to a specific task $\mathcal{T}$, which we can express in terms of an expectation over the distribution:

$$\min_{\theta_{\mathcal{T}}} \; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\left[\mathcal{L}_{\mathcal{T}}(\theta_{\mathcal{T}})\right]$$

$\mathcal{T}$ is now a random variable and $\theta_{\mathcal{T}}$ is a set of parameters for task $\mathcal{T}$. We may use different parameters for each task, use the same parameters for every task, or do something in between.
Additionally, we will not simply use the data from other tasks to find parameters that are optimal for all tasks, but keep the option to fine-tune our model, i.e., take additional optimizer steps on data from the new task $\mathcal{T}$. Afterward, we want to have converged to $\theta_{\mathcal{T}}$, while reusing the pre-fine-tuning version $\theta$ of the model for each new task. Thus, we can express our optimization objective as

$$\min_{\theta} \; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\left[\mathcal{L}_{\mathcal{T}}\big(U_{\mathcal{T}}(\theta)\big)\right]$$

where $U_{\mathcal{T}}$ is an optimization algorithm that maps $\theta$ to a new parameter vector $\theta_{\mathcal{T}} = U_{\mathcal{T}}(\theta)$, the result of fine-tuning $\theta$ on data from task $\mathcal{T}$ using optimizer $U$. For the rest of this article, we assume $U$ corresponds to performing gradient descent with a variable number of steps, but don't let this limit your imagination of what algorithm $U$ could be.
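To make the role of $U$ concrete, here is a minimal sketch of $U$ as plain gradient descent; the quadratic task loss in the usage example is a hypothetical stand-in, not one of our tasks.

```python
import numpy as np

# A minimal sketch of the fine-tuning operator U: k steps of gradient
# descent on a task loss, given a function returning its gradient.
def U(theta, grad_loss, alpha=0.01, k=1):
    for _ in range(k):
        theta = theta - alpha * grad_loss(theta)
    return theta

# Hypothetical usage: task loss L(theta) = ||theta - t||^2 with optimum t.
t = np.array([1.0, -2.0])
theta_task = U(np.zeros(2), lambda th: 2.0 * (th - t), alpha=0.1, k=10)
```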
A word on terminology: In conventional machine learning settings, we consider trainable parameters $\theta$ that are tied to our task. However, the $\theta$ in the above objective is learned with respect to a variety of tasks. This, together with the fact that it can further be regarded as the initialization of the optimizer $U$, lets us interpret $\theta$ as sitting above the task level, so it acquires the status of a meta-parameter. Consequently, optimizing such a meta-parameter corresponds to meta-learning.
Having set the above objective, we are already halfway there. The only thing that is left is to find a feasible optimizer for $\theta$. Before we jump into how MAML solves this problem, we will take a look at a simple baseline, which will help us digest the setting a bit better and which leads us directly to MAML.
Part 1: A simple baseline
Recalling our optimization objective

$$\min_{\theta} \; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\left[\mathcal{L}_{\mathcal{T}}\big(U_{\mathcal{T}}(\theta)\big)\right],$$

the following approach mitigates dealing with $U_{\mathcal{T}}$, mostly by ignoring that it exists, which makes the objective collapse to

$$\min_{\theta} \; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\left[\mathcal{L}_{\mathcal{T}}(\theta)\right],$$

i.e., the standard machine learning setting that we have already talked about. However, this time we are not operating on a few samples of one task but have a whole distribution $p(\mathcal{T})$ of tasks at our disposal. Hence, we can reliably solve the simplified objective with gradient descent.
Omitting the update procedure $U_{\mathcal{T}}$, we expect the final $\theta$ to be chosen such that fine-tuning the model on only a few samples of some task from the distribution brings the model parameters close to optimal. This hope might seem naive, considering that we did not reason about why $U_{\mathcal{T}}$ might be disregarded, but simply disregarded it. But on the other hand, this part is not called "a simple baseline" for no reason.
Expectations are commonly approximated by an empirical mean over samples from the respective distribution, a practice known as empirical risk minimization (ERM). If we apply this here, the resulting gradient is also the empirical mean of the task gradients ($n$ is the number of sampled tasks):

$$\nabla_{\theta}\, \frac{1}{n} \sum_{i=1}^{n} \mathcal{L}_{\mathcal{T}_i}(\theta) = \frac{1}{n} \sum_{i=1}^{n} \nabla_{\theta}\, \mathcal{L}_{\mathcal{T}_i}(\theta)$$
Finn et al., the authors of MAML, call this type of baseline the pretrained model: we can simply pretrain over all available data and defer the problem of dealing with $U_{\mathcal{T}}$. Now, we can make use of the pretrained model by simply fine-tuning the final $\theta$, the result of our pretraining, on a new task, which is exactly what we will do a bit further down!
Moving on, we will take a little detour and talk about some implementation aspects of the pretrained baseline. It will also serve as a starting point for implementing MAML later. Afterward, you can watch the pretrained model fail in a small experiment we prepared.
Implementing the Pretrained Model
If the above got a bit too theoretical for you, take a look at the following gist. It contains a simplistic implementation of an update step for the pretrained model in TensorFlow, a popular machine learning library. The implementation tries to emphasize that even if we differentiate between tasks when sampling the batch, the actual optimizer treats each sample the same.
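A minimal sketch of such an update step, assuming a Keras regression model, might look as follows; the names are illustrative rather than the gist's original code.

```python
import tensorflow as tf

# Adam, to be congruent with the original paper.
optimizer = tf.keras.optimizers.Adam()

# Illustrative sketch of the pretrained update step: samples from all
# tasks are merged into one batch, so the optimizer treats each sample
# the same, regardless of which task it came from.
def pretrained_update_step(model, task_batches):
    # task_batches: list of (x, y) pairs, one per sampled task
    x = tf.concat([x_i for x_i, _ in task_batches], axis=0)
    y = tf.concat([y_i for _, y_i in task_batches], axis=0)
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```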
The implementation is agnostic to the choice of the optimizer. We use the Adam optimizer to be
congruent
with the original paper.
Pretrained Model on a Sinusoid Problem (Regression)
In the following figure, you can experiment with a pretrained model trained on a collection of sinusoid regression tasks. The task distribution works as follows: Each task is represented by an amplitude $a$ and a phase $b$ and requires the prediction of the sinusoid

$$f_{a,b}(x) = a \sin(x + b),$$

where $a$ and $b$ are sampled uniformly from some predefined range. Different parameters yield different functions $f_{a,b}$ and $f_{a',b'}$, with possibly completely different function values and gradients. Take, for example, two tasks whose sinusoids have equal amplitude but phases offset by $\pi$: their function values give completely contradicting information, as $a \sin(x + b) = -a \sin(x + b + \pi)$ for every $x$.
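A toy sketch of how such a task distribution could be sampled; the ranges below follow the sinusoid setup of Finn et al., but treat them as assumptions rather than the exact values behind our figures.

```python
import numpy as np

# Illustrative sketch of the sinusoid task distribution p(T).
def sample_task(rng):
    a = rng.uniform(0.1, 5.0)      # amplitude (assumed range)
    b = rng.uniform(0.0, np.pi)    # phase (assumed range)
    return lambda x: a * np.sin(x + b)

def sample_batch(task, rng, k=5):
    # k-shot regression: k inputs with their sinusoid targets
    x = rng.uniform(-5.0, 5.0, size=(k, 1))
    return x, task(x)

rng = np.random.default_rng(0)
task = sample_task(rng)
x_train, y_train = sample_batch(task, rng)
```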
Before fitting the model, what do you expect to happen based on the position and the number of samples provided? Feel free to also experiment with the different settings: distributing the samples equispaced or squeezing all of them into a small range of the x-axis.
Experiment with a pretrained model (blue) by (a) changing the task on which it is evaluated (red) via the sliders for the amplitude and phase of the sinusoid and (b) setting up to 5 samples for prediction, either manually by clicking on the plot or by sampling them uniformly with a click on the 🎲.
Ouch! That does not seem to work all that well. Maybe you have already guessed that it would have been too easy. Remember our interpretation of what happens when omitting $U_{\mathcal{T}}$? We said that we expect the $\theta$ that minimizes the simplified objective to be easily fine-tunable to any task from the distribution. But it seems such a $\theta$ is either impossible or at the very least incredibly difficult to find. Subsequently, we will try to gain some insight into the difficulties of the problem.
Let us recall the original optimization objective, i.e.,

$$\min_{\theta} \; \mathbb{E}_{\mathcal{T} \sim p(\mathcal{T})}\left[\mathcal{L}_{\mathcal{T}}\big(U_{\mathcal{T}}(\theta)\big)\right].$$

We can augment this notation by giving $U_{\mathcal{T}}$ a superscript, i.e., write $U^k_{\mathcal{T}}$, indicating that we perform $k$ steps of gradient descent. Then we recover the simplified objective of the pretrained model by setting $k = 0$, as

$$U^0_{\mathcal{T}}(\theta) = \theta.$$

It should be emphasized that $k = 0$ is not simply one of many special cases of our few-shot learning objective but rather indicates no fine-tuning at all, which is clearly not what we intended.
We have already seen that for $k = 0$ the loss space with respect to some task samples $\mathcal{T}_1, \ldots, \mathcal{T}_n$ becomes

$$\sum_{i=1}^{n} \mathcal{L}_{\mathcal{T}_i}(\theta),$$

i.e., a simple sum of loss spaces. From now on, we have to carefully distinguish between the task loss spaces, which are defined by the individual $\mathcal{L}_{\mathcal{T}_1}, \ldots, \mathcal{L}_{\mathcal{T}_n}$ for tasks $\mathcal{T}_1, \ldots, \mathcal{T}_n$, and the accumulated loss space, defined by

$$\sum_{i=1}^{n} \mathcal{L}_{\mathcal{T}_i}\big(U^k_{\mathcal{T}_i}(\theta)\big).$$
The following figure explores this representation visually by letting you control $k$ to see how the resulting loss space changes. As we increase $k$, the loss space, and with it the position of the minimum, changes, indicating that the simplified objective of the pretrained model can be a poor approximation to the few-shot meta-learning objective we would like to solve.
Accumulated Meta-Loss
The task loss spaces displayed below stem from a curve-fitting (toy) problem, where we are interested in two parameters $a$ and $b$, with each task requiring the prediction of a curve parameterized by $a$ and $b$, to which a non-linearity is applied to make the resulting loss spaces a bit more interesting. The plot above displays the accumulated loss space of the two tasks with respect to the number of fine-tuning steps $k$, which represents the empirical version of our few-shot learning optimization objective, i.e.,

$$\sum_{i=1}^{2} \mathcal{L}_{\mathcal{T}_i}\big(U^k_{\mathcal{T}_i}(\theta)\big).$$
Task 1 (a = 0.3, b = 0.05)
Task 2 (a = 0.83, b = 0.75)
Notice how the optimum of the accumulated loss space changes as you increase $k$. It should become more than obvious that assuming $k = 0$ and training on the resulting simplified objective function, as the pretrained model does, might result in a completely wrong meta-parameter.
Having established the problems arising from omitting the fine-tuning function $U_{\mathcal{T}}$ from the optimization objective, we will finally turn to MAML. MAML does not disregard $U_{\mathcal{T}}$ but rather optimizes through it. In the next part, we will see how it goes about this.
Part 2: Model-Agnostic Meta-Learning
We will now study MAML in detail, trying to optimize the previously established few-shot learning objective for $k \geq 1$. In short, MAML optimizes the same objective as the pretrained model, but its optimization strategy acknowledges the effect of the fine-tuning function $U^k_{\mathcal{T}}$ on the accumulated loss space.
Outline of the Algorithm
Let us jump right in and take a look at the three main steps of the method, given a (current) meta-parameter $\theta$:
1. Sample a number of tasks $\mathcal{T}_1, \ldots, \mathcal{T}_n$ from $p(\mathcal{T})$.
2. For each task, obtain $\phi_i = U_{\mathcal{T}_i}(\theta)$ by minimizing $\mathcal{L}^{\text{train}}_{\mathcal{T}_i}$ on a few training samples.
3. Update $\theta$ by gradient descent such that it minimizes $\sum_i \mathcal{L}^{\text{test}}_{\mathcal{T}_i}(\phi_i)$ on a few test samples.
Note that $\mathcal{L}^{\text{train}}_{\mathcal{T}_i}$ and $\mathcal{L}^{\text{test}}_{\mathcal{T}_i}$ are two instances of the same loss function on the same task and correspond to training and test data from this task (though the task $\mathcal{T}_i$ changes as we iterate over $i$). The easiest way to obtain $\phi_i$ is to do a single step of gradient descent ($\phi_i$ will not be optimal, but most likely better than $\theta$):

$$\phi_i = \theta - \alpha \nabla_{\theta}\, \mathcal{L}^{\text{train}}_{\mathcal{T}_i}(\theta)$$

Further, updating $\theta$ requires us to evaluate the gradient of the individual task losses on a set of test data. We obtain the gradient of the overall loss as follows:

$$\nabla_{\theta} \sum_{i=1}^{n} \mathcal{L}^{\text{test}}_{\mathcal{T}_i}(\phi_i) = \sum_{i=1}^{n} \nabla_{\theta}\, \mathcal{L}^{\text{test}}_{\mathcal{T}_i}\big(U_{\mathcal{T}_i}(\theta)\big)$$

Note that $\phi_i$ depends on $\theta$, which means that we have to take a gradient through the optimizer $U_{\mathcal{T}_i}$. We can then update $\theta$ via gradient descent, using a new learning rate $\beta$:

$$\theta \leftarrow \theta - \beta \nabla_{\theta} \sum_{i=1}^{n} \mathcal{L}^{\text{test}}_{\mathcal{T}_i}(\phi_i)$$
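To see what taking a gradient through the optimizer entails, consider the single-step case. Applying the chain rule to $\phi_i = \theta - \alpha \nabla_{\theta}\, \mathcal{L}^{\text{train}}_{\mathcal{T}_i}(\theta)$ yields

$$\nabla_{\theta}\, \mathcal{L}^{\text{test}}_{\mathcal{T}_i}(\phi_i) = \big(I - \alpha \nabla^2_{\theta}\, \mathcal{L}^{\text{train}}_{\mathcal{T}_i}(\theta)\big)\, \nabla_{\phi_i}\, \mathcal{L}^{\text{test}}_{\mathcal{T}_i}(\phi_i),$$

so the meta-gradient contains second derivatives of the training loss. Keep this in mind: it is exactly the expensive part that the first-order variants discussed later avoid.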
And that 🥁... is more or less everything that comprises the original MAML algorithm.
Implementing the Algorithm
However, a machine learning algorithm is not very useful unless we can execute it on a computer. While implementing the pretrained model was more or less straightforward, implementing MAML requires some more thought. Firstly, computing $\phi_i = U_{\mathcal{T}_i}(\theta)$ is still straightforward; simply call the optimization algorithm of your choice (as long as it is gradient-based). However, how do we then compute the gradient through that optimization algorithm? It is actually not that complicated: almost every modern machine learning framework (e.g., TensorFlow) can differentiate through nearly arbitrary Python code. Hence, if we can express our optimizer as a Python function, TensorFlow can differentiate through it.
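As a tiny, self-contained illustration of this (with arbitrary toy losses, not our regression setup), nested gradient tapes let TensorFlow differentiate through a function that itself performs a gradient step:

```python
import tensorflow as tf

# Toy demonstration: differentiate through a Python function that
# itself performs a gradient step. All values are arbitrary.
theta = tf.Variable(2.0)

def inner_step(theta, alpha=0.1):
    with tf.GradientTape() as tape:
        train_loss = (theta - 1.0) ** 2
    g = tape.gradient(train_loss, theta)
    return theta - alpha * g  # one step of gradient descent

with tf.GradientTape() as outer:
    phi = inner_step(theta)
    test_loss = (phi - 3.0) ** 2
print(outer.gradient(test_loss, theta))  # d test_loss / d theta, through U
```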
Below you find a gist that implements a simplistic version of the MAML update step. The optimizer $U$ is encoded within the function fastWeights. Still, the function also directly applies an input tensor to the optimized weights. We did this mainly for simplicity, but if you are interested in a thorough reasoning about this design choice, you can read more about it in the comments under the gist.
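A minimal sketch of what such a MAML update step could look like, with the model written functionally so the fast weights can be plugged back in; the code below is an illustration under these assumptions, not the gist itself.

```python
import tensorflow as tf

# Illustrative sketch, not the original gist. The model is written
# functionally so that the fast weights phi can be plugged back in,
# letting TensorFlow differentiate through the inner update.

def init_weights():
    # A tiny two-layer regression network; the sizes are arbitrary.
    return [tf.Variable(tf.random.normal([1, 40])),
            tf.Variable(tf.zeros([40])),
            tf.Variable(tf.random.normal([40, 1])),
            tf.Variable(tf.zeros([1]))]

def forward(weights, x):
    w1, b1, w2, b2 = weights
    return tf.nn.relu(x @ w1 + b1) @ w2 + b2

def mse(weights, x, y):
    return tf.reduce_mean(tf.square(forward(weights, x) - y))

def fast_weights(weights, x_train, y_train, alpha):
    # Inner update: phi = theta - alpha * grad L_train(theta)
    with tf.GradientTape() as tape:
        train_loss = mse(weights, x_train, y_train)
    grads = tape.gradient(train_loss, weights)
    return [w - alpha * g for w, g in zip(weights, grads)]

meta_optimizer = tf.keras.optimizers.Adam()

def maml_update_step(weights, tasks, alpha=0.01):
    # tasks: list of ((x_train, y_train), (x_test, y_test)) tuples
    with tf.GradientTape() as outer_tape:
        test_loss = 0.0
        for (x_tr, y_tr), (x_te, y_te) in tasks:
            phi = fast_weights(weights, x_tr, y_tr, alpha)
            test_loss += mse(phi, x_te, y_te)  # evaluated at phi_i
    meta_grads = outer_tape.gradient(test_loss, weights)
    meta_optimizer.apply_gradients(zip(meta_grads, weights))
    return test_loss
```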
Before we study the MAML model on the sinusoid task distribution, let us spend some time trying to see MAML in action. Consider the problem in the figure below. As already established, our few-shot optimization objective is to find an optimal meta-parameter $\theta$ that we can easily fine-tune to any task with only a few respective samples. The figure shows a task distribution of two different tasks and lets you move $\theta$ around to make sure, in the spirit of MAML, that a single step of fine-tuning can result in nearly optimal parameters for each task, respectively. Optimizing $\theta$ not directly on the tasks, as the pretrained model would, but respecting $\theta$'s role as an initialization to the fine-tuning algorithm is what makes MAML both elegant and effective.
The two loss spaces represent two non-linear regression tasks, comprising a task distribution $p(\mathcal{T})$. A darker color corresponds to a lower loss. Scaled by the learning rate, the vectors represent the gradient descent directions of tasks $\mathcal{T}_1$ and $\mathcal{T}_2$ when starting at $\theta$ and fine-tuning on each task, respectively.
You can manipulate the slider for the learning rate $\alpha$, as well as move $\theta$ around, to see how the descent directions change in each plot. If MAML had control over moving $\theta$, it would optimize on the sum of the task losses; hence, the closer $\phi_1$ and $\phi_2$ point to the respective minima of their task loss spaces, the better through the eyes of MAML. Can you find a $\theta$ that MAML would regard as "optimal"?
Returning to Sinusoids
After studying the math behind the MAML objective, as well as its intuition and implementation, it is time to evaluate it on the sinusoid example. Hopefully, MAML will produce better results than the pretrained model. In the figure below, you have the opportunity to repeat the above experiments on a model that has been trained with MAML. Try to compare the optimization behavior of the pretrained model and MAML, and evaluate for yourself whether you think the MAML-trained model has found a good meta-initialization parameter $\theta$.
Experiment with a MAML model (blue) by (a) changing the task on which it is evaluated (red) via the sliders for the amplitude and phase of the sinusoid and (b) setting up to 5 samples for prediction, either manually by clicking on the plot or by sampling them uniformly with a click on the 🎲.
So, as you were hopefully able to verify, MAML produces results that are much closer to the actual sinusoid, despite being exposed to at most five samples.
The rest of this article is dedicated to introducing interesting variants of MAML. The next page starts with a general discussion of the difficulty of obtaining the MAML meta-gradient, which leads directly to FOMAML, a simple first-order version of MAML. A slightly different first-order approach, but still in the spirit of MAML, is Reptile, which obtains meta-knowledge without an explicit meta-gradient. Lastly, iMAML approximates the meta-update by creating a dependency between task loss and meta-parameter, thereby bypassing some of the computationally more expensive parts of the original MAML.
Footnotes
The name fastWeights is adapted from the
original implementation of MAML, a suitable name given that we obtain the adapted (temporary) model parameters
as well as the corresponding test loss in one go.[↩]
Author Contributions
Luis Müller implemented the visualization of MAML, FOMAML, Reptile, and the Comparison. Max Ploner created the visualization of iMAML as well as the Svelte elements and components. Both wrote the introduction together and contributed most of the text of the other parts. Thomas Goerttler came up with the idea and sketched out the project. He also wrote parts of the manuscript and helped with finalizing the document. Klaus Obermayer provided feedback on the project.