Meta-Learning in Neural Networks

By Timothy Hospedales, Principal Researcher, SAIC-Cambridge

Introduction

AI methods are advancing across a range of applications, from computer vision and natural language processing to autonomous control. There are many facets to AI’s capabilities that determine how useful it is in our lives. Besides the obvious metrics of peak accuracy or efficacy of an AI system at its task, other facets include: How effectively can it learn a new task from a small amount of data or experience? Can it perform, or even learn, within the limited hardware and battery power available on a handheld device? How reliably does it generalize to new situations that are out-of-distribution with respect to its training data? The conventional paradigm for improving AI technology relies on the ingenuity of human researchers and engineers to design learning algorithms and architectures that perform better across all these axes and more. In this post and the related paper [1], we introduce a sub-discipline of AI called meta-learning, which aims to accelerate AI progress by shifting the paradigm from hand-designing AI learners to one where machine learning is also applied to improve the learning algorithms themselves.

What is meta-learning?

Learning and Meta-Learning. Machine learning is the process of a program improving its performance at a task as it obtains increasing amounts of data or experience. Meta-learning, also known as learning-to-learn, is the process of a program repeatedly attempting to learn tasks and improving its learning strategy over several learning ‘episodes’. It falls under the broad umbrella of AutoML, which aims to reduce the labor involved in developing and deploying machine learning by automating processes that are typically manual.

A Schematic. As illustrated in this schematic example, the learner trains a model such as a neural network to solve a task such as visual object recognition. The meta-learner then wraps the learning process and trains the learner to better solve learning tasks. The learner typically minimizes a loss function that measures the difference between the true label of an input and the label predicted by a neural network model. The meta-learner minimizes a meta-loss function such as the generalization error of the trained neural network on novel data.

Context. The idea of meta-learning has been around for several decades. However, it has recently surged in popularity due to some notable successes and its potential to help alleviate bottlenecks in contemporary deep learning, such as the data and compute efficiency of learning and the robustness of the resulting models.

Prior to the widespread use of deep learning, researchers hand-designed data representations (such as SIFT in computer vision, or TF-IDF in language processing) for use with machine learning. Deep learning popularized the joint learning of representations and predictive models, showing that performance could be significantly improved compared to hand-designed representations. Nevertheless, the architecture and/or algorithm that defines the learner is still usually hand-designed. Deep meta-learning aims to further improve performance by learning the learning algorithm along with the model and representation.

A More Formal Definition

Conventional machine learning aims to produce a model \[f_θ\] that can be used to make predictions \[\hat{y} = f_θ(x)\] for data inputs \[x\], for example recognizing objects in images or transcribing speech in audio. This is achieved by applying a learning algorithm \[\mathcal{A}\] that inputs a training dataset \[D^{tr}= \left\{(x_i,y_i)\right\}\] and outputs the trained model \[θ^{*}\].

The algorithm is usually hand-designed and typically performs a minimization with respect to some loss function \[\mathcal{L}\], for example by iterative gradient descent with respect to \[θ\], \[θ^t=θ^{t-1}-α∇_θ \mathcal{L}(D^{tr},θ^{t-1})\]. The performance of the resulting model \[f_{θ^{*}}\] depends on many factors that go into the design of the learning algorithm, such as neural architecture and optimization strategy. These are made explicit in the quantity \[ϕ\] that parameterizes the learning algorithm \[\mathcal{A}_\phi\], so that \[θ^{*}=\mathcal{A}_\phi(D^{tr})\].
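
To make the role of \[ϕ\] concrete, here is a minimal sketch of a learner whose design choices are passed in explicitly. The one-parameter model \[y = θx\], the squared-error loss, and the dictionary keys `theta0`, `alpha` and `steps` are all illustrative assumptions rather than anything prescribed by the paper.

```python
def learner(train_set, phi):
    """Hypothetical learning algorithm A_phi: fits y = theta * x by gradient descent.

    phi bundles the learner's design choices (initialization, step size, budget).
    """
    theta = phi["theta0"]
    for _ in range(phi["steps"]):
        # Gradient of the mean squared error L(D_tr, theta) with respect to theta.
        grad = sum(2.0 * (theta * x - y) * x for x, y in train_set) / len(train_set)
        theta = theta - phi["alpha"] * grad  # theta^t = theta^{t-1} - alpha * grad
    return theta


train_set = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # toy data generated by y = 2x
phi = {"theta0": 0.0, "alpha": 0.05, "steps": 200}
theta_star = learner(train_set, phi)
print(round(theta_star, 4))
```

Changing any entry of `phi` changes which model comes out of training, and that dependence is what meta-learning sets out to optimize.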

Meta-learning aims to derive new learners that lead to better learning performance when used for training. To achieve this, we may exploit a whole set of learning tasks (e.g., prior objects to recognize in computer vision), each of which is divided into a training and validation set \[D=\left\{D_{j}^{tr},D_{j}^{val}\right\}_{j=1}^{M}\]. In this case the meta-training procedure should optimize the learner to perform well across this whole set of tasks.

Meta-learning therefore involves two layers of optimization: an inner optimization performed by the base learner \[\mathcal{A}_\phi\] updates the model \[θ\] to minimize \[\mathcal{L}\], and an outer optimization performed by the meta-learner \[\mathcal{B}\] updates the base learner \[ϕ\] to minimize \[\mathcal{L}^{meta}\]. Importantly, the outer optimization can search for the learning algorithm that leads to the best validation performance after learning, rather than merely the best training performance as in conventional learning. In this sense meta-learning can be seen as related to the classic process of complexity control and hyper-parameter optimization by cross-validation. However, as we will see, it can be significantly more powerful and flexible.
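
Written out with the symbols already defined, the two layers form a bilevel optimization problem. The following is a sketch of one common way to state it, assuming the learner’s dependence on \[ϕ\] is written as an extra argument of the inner loss:

\[ϕ^{*} = \arg\min_{ϕ} \sum_{j=1}^{M} \mathcal{L}^{meta}\left(D_{j}^{val}, θ_{j}^{*}(ϕ)\right)\]

\[\text{subject to} \quad θ_{j}^{*}(ϕ) = \arg\min_{θ} \mathcal{L}\left(D_{j}^{tr}, θ; ϕ\right)\]

The second line is the inner optimization, solved once per task \[j\] on its training set. The first line is the outer optimization, which scores each candidate \[ϕ\] by the validation loss of the models it produces.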

Meta-Learner Design

There are three key design axes that go into defining a particular instantiation of a meta-learning algorithm.

Meta-Representation. The meta-representation specifies the search space in which the meta-learner will search for an improved learner \[ϕ∈Φ\]. For example, neural architecture search (NAS) [2] searches the space of neural architectures; MAML [3] searches the space of initial conditions \[θ^0\] for the iterative optimization conducted by learner \[\mathcal{A}\]; and amortized meta-learners such as ProtoNets [4] search the space of feed-forward models that map a training set into a predictive model. While meta-learning can encompass conventional hyper-parameter optimization (HPO), e.g., of regularizer strength or learning rate, a key distinction is that efficient gradient-based meta-optimization strategies allow substantially higher-dimensional representations \[ϕ\] to be optimized, scaling to millions of parameters [7] compared with dozens in classic HPO.
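
As an illustration of treating the initial condition as the meta-representation, the sketch below meta-learns a scalar \[θ^0\] in the style of MAML [3]. It is a toy under strong assumptions: each task is a quadratic loss with its own target, the inner loop is a single gradient step, and each task’s validation loss reuses its training target, so the meta-gradient can be written by hand. The function names and task values are hypothetical.

```python
def inner_step(theta0, target, alpha):
    """One gradient step on the task loss 0.5 * (theta - target)**2."""
    return theta0 - alpha * (theta0 - target)


def meta_gradient(theta0, targets, alpha):
    """Exact gradient of the mean post-adaptation loss with respect to theta0."""
    total = 0.0
    for target in targets:
        theta_adapted = inner_step(theta0, target, alpha)
        # Chain rule: d(theta_adapted)/d(theta0) = 1 - alpha for this quadratic loss.
        total += (theta_adapted - target) * (1.0 - alpha)
    return total / len(targets)


def meta_train(targets, alpha=0.1, beta=0.5, meta_steps=200):
    """Outer loop: gradient descent on the initial condition theta0 (this is phi)."""
    theta0 = 0.0
    for _ in range(meta_steps):
        theta0 -= beta * meta_gradient(theta0, targets, alpha)
    return theta0


task_targets = [1.0, 2.0, 6.0]  # hypothetical task family
theta0 = meta_train(task_targets)
print(round(theta0, 3))
```

In this toy setting the learned initialization settles at the mean of the task targets, the point from which one adaptation step does best on average.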

Meta-Optimizer. The meta-optimizer defines the search strategy used by the meta-learner \[\mathcal{B}\] to search for an improved learner. Common approaches include gradient descent (e.g., used in [3,4]), reinforcement learning (RL) (e.g., used in [6]) and evolutionary optimization (e.g., used in [2]). Gradient-descent strategies that explicitly compute gradients \[d\mathcal{L}^{meta}/d\phi\] for meta-optimization tend to be fast and scale well to high-dimensional representations, but require that \[\mathcal{L}^{meta}\] and \[\mathcal{A}\] be differentiable and that the representation \[\phi\] be continuous. RL and evolutionary approaches impose fewer assumptions on differentiability but tend to be slower and scale poorly to high-dimensional search spaces \[Φ\].
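
For contrast with gradient-based meta-optimization, here is a minimal sketch of a mutate-and-select search that treats the learner as a black box. It assumes a deliberately simple setup: \[\phi\] is the base-10 logarithm of the learning rate, and the meta-loss is the loss reached after a fixed budget of gradient steps on \[θ^2\]. Population size, mutation scale and seed are arbitrary illustrative choices.

```python
import random


def meta_loss(log_alpha, steps=10):
    """Loss reached by the learner after a fixed budget of gradient steps."""
    alpha = 10.0 ** log_alpha
    theta = 5.0
    for _ in range(steps):
        theta -= alpha * 2.0 * theta  # gradient of theta**2 is 2 * theta
    return theta * theta


def evolve(generations=50, population=8, sigma=0.3, seed=0):
    """Mutate-and-select search over phi = log10(learning rate)."""
    rng = random.Random(seed)
    best = -3.0  # start from a deliberately small learning rate
    best_loss = meta_loss(best)
    for _ in range(generations):
        for _ in range(population):
            child = best + rng.gauss(0.0, sigma)
            child_loss = meta_loss(child)
            if child_loss < best_loss:  # keep only improving mutations
                best, best_loss = child, child_loss
    return best, best_loss


initial_loss = meta_loss(-3.0)
best_log_alpha, best_loss = evolve()
print(initial_loss, best_loss, 10.0 ** best_log_alpha)
```

The search never differentiates through training, so it would work unchanged if the inner loop were non-differentiable. The price is 400 complete training runs to tune a single number.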

Meta-Objective. The meta-objective defines the goal of meta-learning and will be customized for each use-case. It is primarily determined by the design of the loss and associated inputs for the outer optimization \[\mathcal{L}^{meta}\]. For example, in meta-learning for few-shot learning, we use small datasets \[D^{tr}\] in order to search for a learning algorithm that is data efficient [3,4,6]. Alternatively, to find a learning algorithm that is fast in terms of clock time or iterations, one could define a meta-objective that rewards faster progress in improving learning performance [5].

Meta-Learning Use Cases

The framework for meta-learning outlined above is quite general, and can be applied in many different ways to enhance AI pipelines, as reviewed comprehensively in [1].

Established Applications

Some famous examples include NAS, MAML, ProtoNets and Learning to Optimize (L2O), which are summarized in the schematic below in terms of the taxonomy above.

Emerging Applications

Beyond these most established directions for meta-learning, there are a wide variety of emerging application directions [1]. To name just a few:

Learnable Data Augmentation. Data augmentation improves the generalization performance of models by augmenting the training set with new instances that are class-preserving transformations of existing instances, for example by rotating and cropping images in computer vision object recognition. It is a crucial tool for reaching state-of-the-art performance. To further advance the state of the art, learnable data augmentation [8] automates the design process to produce augmentation engines superior to the best hand-designed alternatives. One searches for the data augmentation scheme which, when applied to learning the base model, leads to high validation performance.

Dataset Distillation. Dataset distillation aims to deal with slow neural network training on big datasets by discovering a small “distilled” training set that can be used to train similarly good models much more quickly than using the full training set [9,7]. This is achieved by allowing the meta-parameter \[\phi\] to define a synthetic training set that is much smaller than a typical real dataset. The synthetic training set is used to train the model in the inner loop, and the outer loop optimizes for the synthetic set that leads to maximum validation performance. In future, such distilled datasets may accelerate other tasks such as NAS by replacing the real dataset in the inner loop with a compact synthetic set.
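
The inner/outer structure of dataset distillation can be sketched on a toy problem. The assumptions here are heavy and purely illustrative: the model is \[y = wx\], the synthetic set is a single point at \[x=1\] whose label is the meta-parameter, the inner loop is five gradient steps, and the meta-gradient is approximated by central finite differences rather than by any of the methods in [9,7].

```python
def train_on_synthetic(y_syn, steps=5, alpha=0.1):
    """Inner loop: fit y = w * x to the single synthetic point (x=1, y=y_syn)."""
    w = 0.0
    for _ in range(steps):
        w -= alpha * 2.0 * (w - y_syn)  # gradient of (w * 1 - y_syn)**2
    return w


def real_loss(w, real_data):
    """Mean squared error of the trained model on the real dataset."""
    return sum((w * x - y) ** 2 for x, y in real_data) / len(real_data)


def distill(real_data, meta_steps=300, beta=0.005, eps=1e-4):
    """Outer loop: tune the synthetic label so that training on it fits real data."""
    y_syn = 0.0  # phi: the learnable synthetic label
    for _ in range(meta_steps):
        # Central finite difference stands in for an analytic meta-gradient.
        up = real_loss(train_on_synthetic(y_syn + eps), real_data)
        down = real_loss(train_on_synthetic(y_syn - eps), real_data)
        y_syn -= beta * (up - down) / (2.0 * eps)
    return y_syn


real_data = [(float(x), 3.0 * x) for x in range(1, 21)]  # 20 real points, y = 3x
y_syn = distill(real_data)
w = train_on_synthetic(y_syn)
print(round(y_syn, 3), round(w, 3))
```

The learned synthetic label ends up larger than the true slope of 3, because it has to compensate for the short five-step inner loop. A distilled set is tuned to the learner that will consume it and is not simply a subsample of the real data.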

Meta-Learning for Domain Generalisation. Distribution shift between training and testing data reduces model performance. This is a ubiquitous issue that is infamous for pushing model performance in practice below the level expected by developers. It arises, for example, from changes in imaging conditions in computer vision (e.g., from sunlit to cloudy) or in acoustic conditions in audio understanding (e.g., from reverberant to direct). Domain generalization techniques aim to train models with increased robustness to this practical challenge. Where multiple domains/datasets are available during training, meta-learning can optimize for cross-domain robustness by training a model on one dataset and validating it on a disjoint dataset. Any relevant hyper-parameters of the base algorithm, such as regularizers [10], are then chosen to maximize the reliability of model generalisation under such domain shift.
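
A minimal sketch of this train-on-one-domain, validate-on-another recipe follows. It assumes a one-dimensional ridge regression with a closed-form solution, three hypothetical domains whose slopes differ, and a small grid of candidate regularizer strengths. It illustrates the meta-objective only and is not the method of [10].

```python
def fit_ridge(data, lam):
    """Closed-form 1-D ridge regression for y = w * x with penalty lam * w**2."""
    sxy = sum(x * y for x, y in data)
    sxx = sum(x * x for x, _ in data)
    return sxy / (sxx + lam)


def mse(w, data):
    """Mean squared error of the model y = w * x on a dataset."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def cross_domain_score(domains, lam):
    """Meta-objective: train on each domain, validate on all the other domains."""
    total, count = 0.0, 0
    for i, train_domain in enumerate(domains):
        w = fit_ridge(train_domain, lam)
        for j, val_domain in enumerate(domains):
            if i != j:
                total += mse(w, val_domain)
                count += 1
    return total / count


# Hypothetical domains whose slopes differ, mimicking domain shift.
domains = [
    [(1.0, 2.6), (2.0, 5.2)],  # slope 2.6
    [(1.0, 1.6), (2.0, 3.2)],  # slope 1.6
    [(1.0, 2.1), (2.0, 4.2)],  # slope 2.1
]
candidates = [0.0, 0.1, 0.3, 1.0, 3.0]
best_lam = min(candidates, key=lambda lam: cross_domain_score(domains, lam))
print(best_lam)
```

Within any single domain the unregularized fit is perfect, so ordinary in-domain validation would pick no regularization. Scoring on held-out domains is what shows that some shrinkage transfers better.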

Some Applications in Samsung Research

We have worked on a variety of meta-learning algorithms and applications in Samsung Research. Two examples from Samsung AI Center Cambridge include few-shot object detection, and domain adaptation.

Few-Shot Object Detection

Object detection in computer vision is the task of returning a list of known objects and associated bounding boxes within an image. Training object detectors requires extensive annotation, as do other tasks in computer vision. Given the growing number of object categories of interest, as well as the greater manual effort required to annotate bounding boxes for object detection compared to object recognition, there is extensive interest in few-shot object detector learning.

In the CVPR’20 paper “Incremental Few-Shot Object Detection” [11], we introduce ONCE, a feed-forward (amortized) few-shot object detector. Here the meta-representation \[\phi\] is a neural network with two input branches that accept the training and testing data respectively. The training branch enrolls new objects to detect via a few training examples. The testing branch then inputs new images or video in which to search for further examples of the enrolled objects. Under the hood, the training branch synthesizes the parameters that the testing branch needs to detect the enrolled objects, and meta-learning teaches the training branch how to perform this synthesis. The parameter-synthesis process performed by the training branch can then replace the iterative process of backpropagation that is conventionally used to fit neural network parameters. This feed-forward parameter synthesis is not only more accurate than the standard approach, but also much more efficient, which makes it suitable for running on resource-constrained devices such as phones. Finally, a key unique property of ONCE is the ability to efficiently enroll new categories incrementally during operation, without needing to update old categories, retrain, or revisit existing category data.
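
The enroll-then-detect interface can be sketched in a few lines. This is not the ONCE architecture: the identity `embed` function, the feature averaging and the nearest-vector rule are hand-written stand-ins (closer in spirit to prototype-based classifiers [4]) for what ONCE obtains through a meta-learned network, and the class name and toy data are hypothetical. The sketch only shows how per-class parameters can be produced in one forward pass, and how new classes can be added without touching old ones.

```python
def embed(x):
    """Stand-in feature extractor; a real system would use a trained network."""
    return x  # identity, purely for illustration


class FeedForwardEnroller:
    """Synthesizes per-class parameters from a few examples, with no gradient steps."""

    def __init__(self):
        self.class_params = {}

    def enroll(self, label, examples):
        """'Training branch': average the example features into a class vector."""
        feats = [embed(x) for x in examples]
        dim = len(feats[0])
        self.class_params[label] = [
            sum(f[d] for f in feats) / len(feats) for d in range(dim)
        ]

    def predict(self, x):
        """'Testing branch': pick the class whose synthesized vector is closest."""
        feat = embed(x)

        def sq_dist(label):
            return sum((a - b) ** 2 for a, b in zip(feat, self.class_params[label]))

        return min(self.class_params, key=sq_dist)


model = FeedForwardEnroller()
model.enroll("cat", [[1.0, 0.9], [0.8, 1.1]])
model.enroll("dog", [[-1.0, -1.2], [-0.9, -0.8]])
cat_params_before = list(model.class_params["cat"])
model.enroll("bird", [[3.0, -3.0], [2.8, -3.2]])  # added later, incrementally
print(model.predict([2.9, -2.5]), model.predict([0.7, 1.0]))
```

Enrolling "bird" leaves the stored parameters for "cat" and "dog" untouched, which is the incremental property described above.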

Domain Adaptation

Domain adaptation is the challenge of taking a model trained on data from a reference source distribution and adapting it to another target distribution, usually without assuming access to labels in the target domain. This is a widely relevant problem because, in real applications, the data distribution encountered once models are deployed by users often differs from the lab distribution on which they were trained. Standard benchmarks for domain adaptive object recognition consider, for example, adapting a model trained on artistic depictions of objects to recognizing real-world photos.

Domain adaptation is widely studied with hundreds of domain adaptive learning algorithms available. Any given domain adaptive algorithm corresponds to a base learning algorithm \[\mathcal{A}\] in the meta-learning context. In the ECCV’20 paper “Online Meta-Learning for Multi-Source and Semi-Supervised Domain Adaptation” [12], we observed that the accuracy of a neural model for object recognition in the target domain after domain-adaptive learning depends on the initialization chosen prior to domain adaptive learning, with some initializations leading to much better solutions than others.

We therefore explored optimizing the initial condition of a domain adaptive learner, by treating initialization as the meta-representation \[\phi\] to meta-learn. We search the space of initial conditions, performing domain adaptation (e.g., from clipart to photos) from each initial condition, and score each initial condition by the resulting performance after adaptation. This initial condition search is driven by the post-adaptation performance in the target domain, and thus ultimately leads to the best performing model in the target domain of interest. This MetaDA framework is complementary to the wide array of existing base learning DA algorithms, which each define different strategies of how-to-adapt given a particular initial condition for a recognition model and unlabeled data from a target domain of interest. Indeed, the results showed a consistent performance improvement across a range of base DA algorithms (DANN, MCD, JiGen), CNN architectures (AlexNet, ResNet-18, ResNet-34), and datasets (DomainNet, Office-Home, PACS).

As the first approach to meta-learning for domain adaptation, our MetaDA explored optimizing domain adaptation with respect to initial conditions. However, this is only scratching the surface of the potential impact of meta-learning on domain adaptation, as many other aspects of the base learner can be improved. The impact of meta-learning can therefore be expected to grow as other meta-parameters of the base learning algorithms used in DA are optimized.

Open Research Challenges

There are several underpinning methodological challenges that remain to be solved for meta-learning to achieve its full potential.

Computation and Memory Cost. In the most common gradient-based meta-learning frameworks, a challenge is that computing the meta-gradient \[d\mathcal{L}^{meta}/d\phi\] (i.e., the gradient of the meta-loss with respect to the meta-parameters) required by the outer optimization is expensive in computation and/or memory. Thus, numerous algorithms (based on reverse-mode, forward-mode, and implicit-function differentiation) have been developed to accelerate this computation [1,7]. Nevertheless, there is more work to be done in developing efficient and scalable algorithms for optimizing learners.
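
To give a feel for one of these options, the sketch below computes a meta-gradient in forward mode for a single meta-parameter, the learning rate. It assumes a scalar quadratic inner loss and a meta-loss of \[0.5θ^2\] on the final parameter. All names are hypothetical, and the finite-difference routine is included only as a sanity check.

```python
def train_with_sensitivity(alpha, steps=20, theta0=1.0, curvature=3.0):
    """Run gradient descent on 0.5 * curvature * theta**2 and track d(theta)/d(alpha)."""
    theta, dtheta_dalpha = theta0, 0.0
    for _ in range(steps):
        grad = curvature * theta
        # Differentiate the update theta <- theta - alpha * grad with respect to alpha.
        dtheta_dalpha = dtheta_dalpha - grad - alpha * curvature * dtheta_dalpha
        theta = theta - alpha * grad
    return theta, dtheta_dalpha


def meta_loss(theta):
    """Meta-loss on the final parameter: 0.5 * theta**2."""
    return 0.5 * theta ** 2


def forward_mode_meta_gradient(alpha):
    """Chain rule: dL_meta/dalpha = dL_meta/dtheta * dtheta/dalpha."""
    theta, dtheta_dalpha = train_with_sensitivity(alpha)
    return theta * dtheta_dalpha


def finite_difference_meta_gradient(alpha, eps=1e-6):
    """Numerical check that retrains twice with perturbed learning rates."""
    up = meta_loss(train_with_sensitivity(alpha + eps)[0])
    down = meta_loss(train_with_sensitivity(alpha - eps)[0])
    return (up - down) / (2.0 * eps)


alpha = 0.1
print(forward_mode_meta_gradient(alpha), finite_difference_meta_gradient(alpha))
```

Forward mode carries one extra quantity per meta-parameter alongside training and never stores the trajectory, so memory stays constant in the number of inner steps. The cost grows with the dimension of \[\phi\], which is why reverse-mode and implicit approaches are preferred for high-dimensional meta-representations.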

Generalisation Within and Across Task Distributions. For meta-learning methods concerned with task families, such as few-shot learning, one aims to find learners that can solve any task drawn from a distribution over tasks \[D\sim p(D)\] (e.g., object categories to recognize, environments for autonomous agents to navigate). This can be particularly challenging in practical applications where \[p(D)\] is broad, such as solving diverse tasks with robots [13], or where there is a task-level domain shift between the training and testing task distributions \[p_{tr}\left(D\right)\neq p_{te}(D)\], such as deploying few-shot learners trained on general benchmarks such as ImageNet to specialist tasks such as medical image analysis [14]. Practical successes will increasingly depend on future meta-learners meeting these robustness challenges.

Conclusion

We have introduced the world of neural-network meta-learning, from high-level concepts to design considerations, contemporary applications, and outstanding challenges. For more details, we recommend the survey paper “Meta-learning in neural networks: A survey” [1]. Only time will tell whether meta-learning will drive AI’s next decade as deep learning did the last, but our money is on yes.

References

[1] T. M. Hospedales, A. Antoniou, P. Micaelli, and A. J. Storkey, “Meta-learning in neural networks: A survey,” IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021.

[2] E. Real, A. Aggarwal, Y. Huang, and Q. V. Le, “Regularized evolution for image classifier architecture search,” in AAAI, 2019.

[3] C. Finn, P. Abbeel, and S. Levine, “Model-agnostic meta-learning for fast adaptation of deep networks,” in ICML, 2017.

[4] J. Snell, K. Swersky, and R. S. Zemel, “Prototypical networks for few-shot learning,” in NIPS, 2017.

[5] M. Andrychowicz, M. Denil, S. G. Colmenarejo, M. W. Hoffman, D. Pfau, T. Schaul, and N. de Freitas, “Learning to learn by gradient descent by gradient descent,” in NIPS, 2016.

[6] N. Mishra, M. Rohaninejad, X. Chen, and P. Abbeel, “A simple neural attentive meta-learner,” in ICLR, 2018.

[7] J. Lorraine, P. Vicol, and D. Duvenaud, “Optimizing millions of hyperparameters by implicit differentiation,” in AISTATS, 2020.

[8] E. D. Cubuk, B. Zoph, D. Mané, V. Vasudevan, and Q. V. Le, “AutoAugment: Learning augmentation policies from data,” in CVPR, 2019.

[9] B. Zhao, K. R. Mopuri, and H. Bilen, “Dataset condensation with gradient matching,” in ICLR, 2021.

[10] Y. Li, Y. Yang, W. Zhou, and T. M. Hospedales, “Feature-critic networks for heterogeneous domain generalization,” in ICML, 2019.

[11] J.-M. Perez-Rua, X. Zhu, T. Hospedales, and T. Xiang, “Incremental few-shot object detection,” in CVPR, 2020.

[12] D. Li and T. Hospedales, “Online meta-learning for multi-source and semi-supervised domain adaptation,” in ECCV, 2020.

[13] T. Yu, D. Quillen, Z. He, R. Julian, K. Hausman, C. Finn, and S. Levine, “Meta-world: A benchmark and evaluation for multi-task and meta reinforcement learning,” in CoRL, 2019.

[14] E. Triantafillou, T. Zhu, V. Dumoulin, P. Lamblin, U. Evci, K. Xu, R. Goroshin, C. Gelada, K. Swersky, P.-A. Manzagol, and H. Larochelle, “Meta-dataset: A dataset of datasets for learning to learn from few examples,” in ICLR, 2020.