An Interactive Introduction to Model-Agnostic Meta-Learning

Exploring the world of model-agnostic meta-learning and its variants.

This page is part of a multi-part series on Model-Agnostic Meta-Learning. If you are already familiar with the topic, use the menu on the right side to jump straight to the part that interests you. Otherwise, we suggest you start at the beginning.

How MAML works

Model-agnostic meta-learning (MAML) is a meta-learning approach that can be applied to a range of tasks, from simple regression to reinforcement learning, and in particular to few-shot learning. To learn more about it, let us build an example from the ground up and then try to apply MAML. We will do this by alternating between mathematical walk-throughs, interactive figures, and coding examples.

If you have applied machine learning before, you have probably already solved or attempted to solve a problem like the following: training a model to solve one specific task, for example, to distinguish cats from dogs or to teach an agent to find its way through a maze. In these settings, if we are able to define a loss \(\mathcal{L}_\tau:\ \Phi \equiv \mathbb R^d \rightarrow \mathbb R\) for our task \(\tau\), which depends on the parameters \(\phi \in \Phi\) of a model, we can express our learning objective as \[ \underset{\phi}{\text{min}} \, \mathcal{L}_\tau (\phi) .\] We usually search for the optimal \(\phi\) by repeatedly stepping in the opposite direction of the gradient of \(\mathcal{L}_\tau\) with respect to \(\phi\), i.e., \[ \phi \leftarrow \phi - \alpha \nabla_\phi \mathcal{L}_\tau (\phi) ,\] also known as gradient descent. \(\mathcal{L}_\tau\) usually also depends on some data, and \(\alpha\) is the learning rate, controlling the size of the steps we want to take.
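
To make the update rule concrete, here is a minimal sketch in plain Python. The quadratic loss, its optimum `target`, and the helper names are assumptions chosen purely for illustration; they are not part of the article's setup.

```python
def loss(phi, target=3.0):
    """Hypothetical task loss: half the squared distance to the optimum `target`."""
    return 0.5 * (phi - target) ** 2


def grad(phi, target=3.0):
    """Analytic gradient of `loss` with respect to phi."""
    return phi - target


def gradient_descent(phi, alpha=0.1, steps=100):
    """Repeatedly apply phi <- phi - alpha * grad(phi)."""
    for _ in range(steps):
        phi = phi - alpha * grad(phi)
    return phi


phi_start = 0.0
phi_final = gradient_descent(phi_start)
print(f"loss before: {loss(phi_start):.4f}, loss after: {loss(phi_final):.2e}")
```

With a learning rate of 0.1, each step shrinks the distance to the optimum by a factor of 0.9 in this toy problem, so the loss decreases monotonically.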

Unfortunately, when applied to regression or to a few-shot task (i.e., a task with a very small dataset), the above method is known to perform poorly with, e.g., neural networks, since there is simply too little data for too many parameters, leading to overfitting. The key idea of MAML is to mitigate this problem by learning not only from the data regarding exactly our task but also from data of similar tasks. To incorporate this, we make an additional assumption, namely that \(\tau\) comes from some distribution of tasks \(p(\tau)\) and that we can sample freely from this distribution. Eventually, we want to use the data available from the other tasks in the distribution to be able to converge on a specific task \(\tau_i \sim p(\tau)\), which we can express in terms of an expectation over the distribution: \[ \underset{\phi}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (\phi_\tau) ] .\] Here, \(\tau\) is now a random variable and \(\phi_\tau\) is a set of parameters for task \(\tau\). We may use different parameters for each task, use the same parameters for every task, or do something in between.

Additionally, we will not simply use the data from other tasks to find parameters that are optimal for all tasks, but keep the option to fine-tune our model, i.e., take additional optimizer steps on data from the new task \(\tau_i\). After fine-tuning, we want the model to have converged on \(\tau_i\), and we want to reuse the same pre-fine-tuning version of the model as the starting point for each new task. Thus, we can express our optimization objective as \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U_\tau(\theta)) ] ,\] where \(U_\tau: \Phi \rightarrow \Phi \) is an optimization algorithm that maps \(\theta\) to a new parameter vector \( U_\tau(\theta) \), the result of fine-tuning \(\theta\) on data from task \(\tau\). For the rest of this article, we assume \(U_\tau\) corresponds to performing gradient descent with a variable number of steps, but don't let this limit your imagination of what algorithm \(U_\tau\) could be.
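
As a rough illustration of this objective, the following sketch treats \(U_\tau\) as a fixed number of gradient-descent steps and estimates the expectation by averaging over a small, hand-picked set of tasks. The one-dimensional quadratic tasks and all function names (`make_task`, `fine_tune`, `meta_objective`) are hypothetical and serve only to make the notation tangible.

```python
def make_task(center, curvature):
    """Hypothetical task: L_tau(phi) = 0.5 * curvature * (phi - center)^2."""

    def loss(phi):
        return 0.5 * curvature * (phi - center) ** 2

    def grad(phi):
        return curvature * (phi - center)

    return loss, grad


def fine_tune(theta, grad, alpha=0.1, steps=1):
    """U_tau: run `steps` gradient-descent steps on the task, starting at theta."""
    phi = theta
    for _ in range(steps):
        phi = phi - alpha * grad(phi)
    return phi


def meta_objective(theta, tasks, alpha=0.1, steps=1):
    """Empirical estimate of E_tau[ L_tau(U_tau(theta)) ] over the given tasks."""
    total = 0.0
    for loss, grad in tasks:
        total += loss(fine_tune(theta, grad, alpha, steps))
    return total / len(tasks)


tasks = [make_task(-1.0, 1.0), make_task(2.0, 4.0)]
before = meta_objective(0.0, tasks, steps=0)  # no fine-tuning at all
after = meta_objective(0.0, tasks, steps=5)   # five fine-tuning steps per task
print(f"objective without fine-tuning: {before:.3f}, with fine-tuning: {after:.3f}")
```

Note that `theta` is shared by all tasks, while `fine_tune` produces a separate, temporary parameter for each task.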

A word on terminology: In conventional machine learning settings, we consider trainable parameters that are tied to our task. However, the \(\theta\) in the above objective is learned with respect to a variety of tasks. This, together with the fact that it can further be regarded as the initialization of the optimizer \(U_\tau\), lets us interpret \(\theta\) as living above the task level, which gives it the status of a meta-parameter. Consequently, optimizing such a meta-parameter corresponds to meta-learning.

Having set the above objective, we are already halfway there. The only thing that is left is to find a feasible optimizer for \(\theta\). Before we jump into how MAML solves this problem, we are going to take a look at a simple baseline, which will help us to digest the setting a bit better and which leads us directly to MAML.

Part 1: A simple baseline

Recalling our optimization objective \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U_\tau(\theta)) ] ,\] the following approach avoids dealing with \(U_\tau\), mostly by ignoring that it exists, which makes the objective collapse to \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (\theta) ] ,\] i.e., the standard machine learning setting that we have already talked about. However, we are not operating on a few samples of one task this time but have a whole distribution of tasks at our disposal. Hence, we can reliably solve the simplified objective with gradient descent.

Omitting the update procedure \(U_\tau\), we expect the final \(\theta\) to be chosen such that fine-tuning the model on only a few samples of some task \(\tau_i\) from the distribution makes the model parameters close to optimal. This hope might seem naive, considering that we did not reason about why \(U_\tau\) might be disregarded, but simply disregarded it. But on the other hand, this part is not called "a simple baseline" for no reason.

Expectations are commonly approximated by an empirical mean over samples from the respective distribution, an approach also known as Empirical Risk Minimization (ERM). If we apply this here, the resulting gradient is also the empirical mean of the task gradients (\(n\) is the number of sampled tasks): \[ \theta \leftarrow \theta - \frac{\alpha}{n} \nabla_\theta \sum_i \mathcal{L}_{\tau_i} (\theta) .\]
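
The update above can be sketched in a few lines of plain Python. The tasks here are hypothetical one-dimensional quadratics, each described by an assumed `(center, curvature)` pair, so that the task gradients are available in closed form.

```python
def task_grad(theta, center, curvature):
    """Gradient of the hypothetical task loss 0.5 * curvature * (theta - center)^2."""
    return curvature * (theta - center)


def pretrain_step(theta, tasks, alpha):
    """theta <- theta - (alpha / n) * sum of the task gradients."""
    n = len(tasks)
    mean_grad = sum(task_grad(theta, c, h) for c, h in tasks) / n
    return theta - alpha * mean_grad


tasks = [(-1.0, 1.0), (2.0, 4.0)]  # (center, curvature) of two assumed tasks
theta = 0.0
for _ in range(200):
    theta = pretrain_step(theta, tasks, alpha=0.1)
print(f"pretrained theta: {theta:.4f}")
```

In this toy setting, the iteration settles at the curvature-weighted average of the task optima, which is a compromise between the tasks rather than a good solution for either of them.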

Finn et al., the authors of MAML, call this type of baseline the pretrained model: we can simply pretrain over all available data and defer the problem of dealing with \(U_\tau\). Now, we can make use of the pretrained model by simply fine-tuning the final \(\theta\), the result of our pretraining, on a new task, which is exactly what we will do a bit further down!

Moving on, we will take a little detour and talk about some implementation aspects of the pretrained baseline. It will also serve us as a starting point to later implement MAML. Afterward, you can watch the pretrained model fail in a small experiment we prepared.

Implementing the Pretrained Model

If the above got all too theoretical for you, take a look at the following gist. It contains a simplistic implementation of an update step for the pretrained model in TensorFlow, a popular machine learning library. The implementation tries to emphasize that even if we differentiate between tasks when sampling the batch, the actual optimizer treats each sample the same.

The implementation is agnostic to the choice of the optimizer. We use the Adam optimizer to be congruent with the original paper.
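
The following is a framework-free sketch of such an update step, not the TensorFlow implementation referred to above. It assumes a hypothetical one-parameter model \(\hat{y} = w x\), tasks that differ only in their slope, and plain gradient descent in place of Adam. The point it tries to illustrate is that the task structure is used only for sampling, while the optimizer works on one pooled batch.

```python
import random


def sample_batch(rng, slopes, k=10):
    """Draw k samples (x, y) for each hypothetical task y = slope * x."""
    batch = {}
    for slope in slopes:
        xs = [rng.uniform(-1.0, 1.0) for _ in range(k)]
        batch[slope] = [(x, slope * x) for x in xs]
    return batch


def mse_grad(w, samples):
    """Gradient of the mean squared error of the model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in samples) / len(samples)


def pretrain_step(w, batch_by_task, alpha=0.1):
    # Task identity only matters while sampling; the optimizer sees one pooled batch.
    pooled = [s for samples in batch_by_task.values() for s in samples]
    return w - alpha * mse_grad(w, pooled)


rng = random.Random(0)
slopes = [-2.0, 1.0, 3.0]  # three assumed tasks
w = 0.0
for _ in range(500):
    w = pretrain_step(w, sample_batch(rng, slopes))
print(f"pretrained w = {w:.3f}, mean slope = {sum(slopes) / len(slopes):.3f}")
```

Because every task contributes the same number of samples, averaging the per-task gradients and taking the gradient of the pooled batch give the same update in this sketch.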

Pretrained Model on a Sinusoid Problem (Regression)

In the following figure, you can experiment with a pretrained model trained on a collection of sinusoid regression tasks. The task distribution works as follows: Each task is represented by an amplitude \(A\) and a phase \(\varphi\) and requires the prediction of a sinusoid \(f\): \[ f(x) := A \sin(x + \varphi),\] where \(A, \varphi\) are sampled uniformly from some predefined range. Different parameters yield different functions, \(f_1\) and \(f_2\), with possibly completely different function values and gradients. Take, for example, the following two tasks: Tasks \(\tau_1, \tau_2\) are both regression tasks on sinusoids \(f_1(x) := \sin (x - \frac{\pi}{2})\) and \(f_2(x) := \sin (x + \frac{\pi}{2})\), respectively. The function values of these two tasks give completely contradictory information, as \[ f_1(x) = - f_2(x). \]
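
A small sketch of this task distribution in plain Python follows. The sampling ranges for the amplitude and the phase are assumptions made for this example, since the article only speaks of "some predefined range"; `sample_task` is a hypothetical helper name.

```python
import math
import random


def sample_task(rng, amp_range=(0.1, 5.0), phase_range=(0.0, math.pi)):
    """Draw one task f(x) = A * sin(x + phase); the ranges are assumed values."""
    amplitude = rng.uniform(*amp_range)
    phase = rng.uniform(*phase_range)

    def f(x):
        return amplitude * math.sin(x + phase)

    return f


def f1(x):
    return math.sin(x - math.pi / 2)


def f2(x):
    return math.sin(x + math.pi / 2)


xs = [0.5 * i for i in range(-10, 11)]
max_gap = max(abs(f1(x) + f2(x)) for x in xs)      # zero up to rounding if f1 = -f2
mean_target = [0.5 * (f1(x) + f2(x)) for x in xs]  # pointwise average of both tasks

task = sample_task(random.Random(0))
print(max_gap, task(0.0))
```

For a squared-error loss, the prediction that minimizes the summed loss of both tasks at a given \(x\) is the average of the two targets, which is zero everywhere for this pair. A model that tries to fit both tasks at once is therefore pushed toward the flat zero function.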

Before fitting the model, what do you expect to happen based on the position and the number of samples provided? Feel free to also experiment with the different settings: distributing the samples equispaced or squeezing all of them to a small range of the x-axis.

Ouch! That does not seem to work that well. Maybe you have already guessed that it would have been too easy. Remember our interpretation of what happens when omitting \(U_\tau\)? We said that we expect the \(\theta\) that minimizes the simplified objective to be easy to fine-tune to any task from the distribution. But it seems that such a \(\theta\) is either impossible or at the very least incredibly difficult to find. Next, we will try to gain some insight into the difficulties of the problem.

Let us recall the original optimization objective, i.e., \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U_\tau(\theta)) ] .\] We can augment this notation by giving \(U_\tau\) a superscript, i.e., write \(U^{(m)}_\tau\), indicating that we perform \(m\) steps of gradient descent. Then we recover the simplified objective of the pretrained model by setting \(m = 0\), as \[ U^{(0)}_\tau(\theta) = \theta .\] It should be emphasized that \(m = 0\) is not simply one of many special cases of our few-shot learning objective but rather indicates no fine-tuning, which is clearly not what we intended. We have already seen that for \(m = 0\) the loss space with respect to some task samples becomes \[ \sum_i \mathcal{L}_{\tau_i} (\theta) ,\] i.e., a simple sum of loss spaces. From now on, we have to carefully distinguish between the task loss spaces, which are defined by the individual \(\mathcal{L}_{\tau_1}\), ..., \(\mathcal{L}_{\tau_n}\) for tasks \(\tau_1\), ..., \(\tau_n\), and the accumulated loss space, defined by \[ \sum_i \mathcal{L}_{\tau_i} (U^{(m)}_{\tau_i}(\theta)) .\]

The following figure explores this representation visually by letting you control \(m\) to see how the resulting loss space changes. As we increase \(m\), the loss space and with it the position of the minimum change, indicating that the simplified objective of the pretrained model can be a poor approximation to the few-shot-meta-learning objective we would like to solve.
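
The same effect can be reproduced numerically in a toy setting. The sketch below assumes two hypothetical one-dimensional quadratic tasks, given as `(center, curvature)` pairs, and locates the minimum of the accumulated loss on a grid for several values of \(m\). It is an illustration under these assumptions, not a reproduction of the figure.

```python
def adapted_loss(theta, center, curvature, alpha, m):
    """L_tau(U^(m)_tau(theta)) for the assumed loss 0.5 * curvature * (phi - center)^2."""
    phi = theta
    for _ in range(m):
        phi = phi - alpha * curvature * (phi - center)  # one gradient-descent step
    return 0.5 * curvature * (phi - center) ** 2


def accumulated_loss(theta, tasks, alpha, m):
    """Sum of the post-fine-tuning losses over all tasks."""
    return sum(adapted_loss(theta, c, h, alpha, m) for c, h in tasks)


def argmin_on_grid(tasks, alpha, m, lo=-2.0, hi=3.0, n=5001):
    """Brute-force minimizer of the accumulated loss on an evenly spaced grid."""
    grid = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    return min(grid, key=lambda t: accumulated_loss(t, tasks, alpha, m))


tasks = [(-1.0, 1.0), (2.0, 4.0)]  # (center, curvature) of two assumed tasks
minima = {m: argmin_on_grid(tasks, alpha=0.2, m=m) for m in (0, 1, 3)}
for m, theta_star in minima.items():
    print(f"m = {m}: minimum of the accumulated loss near theta = {theta_star:.3f}")
```

In this example, the minimum moves from roughly 1.4 for \(m = 0\) to roughly -0.4 for \(m = 1\) and close to -1 for \(m = 3\). The steep task is fixed quickly by fine-tuning, so the best initialization drifts toward the optimum of the flat task, which gradient descent adapts to more slowly.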

Having established the problems arising from omitting the fine-tuning function \(U_\tau\) from the optimization objective, we will finally turn to MAML. MAML does not disregard \(U_\tau\) but rather optimizes through it. In the next part, we will see how that works.

Part 2: Model-Agnostic Meta-Learning

We will now study MAML in detail, trying to optimize the previously established few-shot learning objective \[ \underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (U^{(m)}_\tau(\theta)) ] ,\] for \(m > 0\). In short, MAML optimizes the same \(\theta\) as the pretrained model, but in its optimization strategy, it acknowledges the effect of the fine-tuning function \(U_\tau\) on the accumulated loss space.

Outline of the Algorithm

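Let us jump right in and take a look at the three main steps of the method, given a (current) meta-parameter \(\theta\): (1) sample a number of tasks \(\tau_i\) from \(p(\tau)\); (2) for each task, obtain adapted parameters \(\phi_i\) by minimizing \(\mathcal{L}_{\tau_i, \text{train}}\), starting from \(\theta\); (3) update \(\theta\) by gradient descent on the sum of the test losses \(\mathcal{L}_{\tau_i, \text{test}}(\phi_i)\).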

Note that \(\mathcal{L}_{\tau_i, \text{train}}\) and \(\mathcal{L}_{\tau_i, \text{test}}\) are two instances of the same loss function on the same task \( \tau_i \) and correspond to training or test data from this task (though \( \tau_i \) changes while iterating over \(i \)). The easiest way to obtain \(\phi_i\) is to do a single step of gradient descent (\( \phi_i \) will not be optimal but most likely better than \( \theta \)): \[ \phi_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\tau_i, \text{train}}(\theta).\] Further, updating \(\theta\) requires us to evaluate the gradient of the individual task losses on a set of test data. We obtain the gradient of the overall loss as follows: \[ \nabla_\theta \mathcal{L}(\theta) = \sum_{i} \nabla_\theta \mathcal{L}_{\tau_i, \text{test}}(\phi_i) .\] Note that \(\phi_i = U_{\tau_i}(\theta)\) depends on \(\theta\), which means that we have to take a gradient through the optimizer \(U\). We can then update \(\theta\) via gradient descent, using a new learning rate \(\beta\): \[ \theta \leftarrow \theta - \beta \nabla_\theta \mathcal{L}(\theta).\] And that 🥁... is more or less everything that comprises the original MAML algorithm.
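
The formulas above can be traced in a small sketch: sample tasks, take one inner gradient step per task on training data, and update \(\theta\) with the gradient of the summed test losses. To keep the gradient through the inner step computable by hand, the sketch assumes a hypothetical one-parameter model \(\hat{y} = w x\) with a squared-error loss, for which the derivative of the inner update is \(1 - \alpha \mathcal{L}''_{\text{train}}\). The slope-based task distribution and all function names are assumptions for illustration.

```python
import random


def loss(w, samples):
    """Mean squared error of the hypothetical model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in samples) / len(samples)


def loss_grad(w, samples):
    """First derivative of `loss` with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in samples) / len(samples)


def loss_hessian(samples):
    """Second derivative of `loss`; constant in w for this model."""
    return sum(2.0 * x * x for x, _ in samples) / len(samples)


def draw_samples(rng, slope, k):
    xs = [rng.uniform(-1.0, 1.0) for _ in range(k)]
    return [(x, slope * x) for x in xs]


def sample_tasks(rng, n_tasks=4, k=5):
    """Sample tasks (assumed: random slopes), each with train and test data."""
    tasks = []
    for _ in range(n_tasks):
        slope = rng.uniform(-3.0, 3.0)
        tasks.append((draw_samples(rng, slope, k), draw_samples(rng, slope, k)))
    return tasks


def meta_loss(theta, tasks, alpha):
    """Sum of L_test(phi_i) with phi_i = theta - alpha * grad L_train(theta)."""
    return sum(loss(theta - alpha * loss_grad(theta, train), test)
               for train, test in tasks)


def meta_gradient(theta, tasks, alpha):
    total = 0.0
    for train, test in tasks:
        phi = theta - alpha * loss_grad(theta, train)    # inner update
        dphi_dtheta = 1.0 - alpha * loss_hessian(train)  # derivative of the inner update
        total += dphi_dtheta * loss_grad(phi, test)      # chain rule through U
    return total


rng = random.Random(0)
alpha, beta = 0.1, 0.05
theta = 5.0
for _ in range(200):
    tasks = sample_tasks(rng)
    theta = theta - beta * meta_gradient(theta, tasks, alpha)  # meta-update
print(f"meta-parameter after training: {theta:.3f}")
```

The factor `dphi_dtheta` is what distinguishes this meta-gradient from simply evaluating the test gradient at \(\phi_i\); for larger models it becomes a matrix involving the Hessian of the training loss.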

Implementing the Algorithm

However, a machine learning algorithm is not very useful unless we can execute it on a computer. While implementing the pretrained model was more or less straightforward, implementing MAML requires some more thought. Firstly, computing \(\phi_i\) is still straightforward; simply call the optimization algorithm of your choice (as long as it is gradient-based). However, how do we then compute the gradient through that optimization algorithm? It is actually not that complicated. Almost every modern machine learning framework (e.g., TensorFlow) can differentiate through nearly arbitrary Python code. Hence, if we can express our optimizer as a Python function, TensorFlow can differentiate through it.
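
To give a feeling for what "differentiating through Python code" means, here is a deliberately tiny stand-in that does not use TensorFlow at all. It implements forward-mode differentiation with dual numbers, which differs from the reverse-mode approach that deep learning frameworks typically use, but it shows the same idea: the optimizer is an ordinary Python loop, and the derivative is propagated through every step of it. The `Dual` class, the quartic task loss, and all names are assumptions made for this sketch.

```python
class Dual:
    """Number val + der * eps with eps^2 = 0; `der` tracks d(value)/d(theta)."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule for the derivative part.
        return Dual(self.val * other.val, self.val * other.der + self.der * other.val)

    __rmul__ = __mul__


def task_loss(phi, center=2.0):
    """Hypothetical non-quadratic task loss 0.25 * (phi - center)^4."""
    d = phi - center
    return 0.25 * (d * d * d * d)


def task_grad(phi, center=2.0):
    """Gradient of `task_loss` with respect to phi."""
    d = phi - center
    return d * d * d


def fine_tune(theta, alpha=0.1, steps=3):
    """The optimizer U as plain Python: a few gradient-descent steps."""
    phi = theta
    for _ in range(steps):
        phi = phi - alpha * task_grad(phi)
    return phi


def post_update_loss(theta):
    return task_loss(fine_tune(theta))


theta = 0.5
autodiff = post_update_loss(Dual(theta, 1.0)).der  # derivative through the loop
eps = 1e-6
numeric = (post_update_loss(theta + eps) - post_update_loss(theta - eps)) / (2 * eps)
print(f"through the optimizer: {autodiff:.6f}, finite differences: {numeric:.6f}")
```

The same functions run unchanged on plain floats and on `Dual` values, which is why a finite-difference estimate can serve as a sanity check for the propagated derivative.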

Below you find a gist that implements a simplistic version of the MAML update step. The optimizer is encoded within the function fastWeights. Still, the function also directly applies the optimized weights to an input tensor. (The name fastWeights is adapted from the original implementation of MAML, a suitable name given that we obtain the adapted (temporary) model parameters \(\phi_i\) as well as the corresponding test loss in one go.) We did this mainly for simplicity, but if you are interested in thorough reasoning about this design choice, you can read more about it in the comments under the gist.

Will it really work?

Before we study the MAML model on the sinusoid task distribution, let us spend some time trying to see MAML in action. Consider the problem in the figure below. As already established, our few-shot optimization objective was to find an optimal meta-parameter \(\theta\), which we can easily fine-tune on any task with only a few respective samples. The figure shows a task distribution of two different tasks and lets you move \(\theta\) around to make sure, in the spirit of MAML, that a single-step-fine-tuning can result in nearly optimal parameters for each task, respectively. Optimizing \(\theta\) not directly on the tasks, as the pretrained model would, but respecting \(\theta\)'s role as an initialization to the fine-tuning algorithm is what makes MAML both elegant and effective.

Returning to Sinusoids

After studying the math behind the MAML objective, as well as its intuition and implementation, it is time to evaluate it on the sinusoid example. Hopefully, MAML will produce better results than the pretrained model. You will have the opportunity to repeat the above experiments on a model that has been trained with MAML in this figure. Try to compare the optimization behavior of both the pretrained model and MAML and evaluate for yourself whether you think the MAML-trained model has found a good meta-initialization parameter \(\theta\).

So as you were hopefully able to verify, MAML produces results that are way closer to the actual sinusoid, despite being exposed to at most five samples.

The rest of this article is dedicated to introducing interesting variants of MAML. The next page starts with a general discussion about the difficulty of obtaining the MAML-meta-gradient, which leads directly to FOMAML, a simple first-order version of MAML. A slightly different first-order approach, but still in the spirit of MAML, is Reptile, which obtains meta-knowledge without an explicit meta-gradient. Lastly, iMAML approximates the meta-update by creating a dependency between task-loss and meta-parameter \(\theta\) and thereby bypasses some of the computationally more expensive parts of the original MAML.

Author Contributions

Luis Müller created the visualizations of MAML, FOMAML, Reptile and the Comparison. Max Ploner created the visualization of iMAML and the Svelte elements and components. Both wrote the introduction together and contributed most of the text of the other parts. Thomas Goerttler came up with the idea and sketched out the project. He also wrote parts of the manuscript and helped with finalizing the document. Klaus Obermayer provided feedback on the project.

† equal contributors