# 元学习算法MAML简介及代码分析
论文: Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks Chelsea (opens new window)
代码: https://github.com/cbfinn/maml (opens new window)
ICML2017的一篇论文,作者Chelsea Finn
是斯坦福的老师,一不小心去作者主页 (opens new window)看了下,MIT和伯克利的学生,真强。^_^
元学习MAML论文介绍
模型无关元学习算法,即Model-Agnostic Meta-Learning Algorithm
(MAML)。
# 1.元学习(meta learning)
元学习即学会学习,区别与普通的深度学习过程。普通的深度学习具体到某一任务,如图像分类,即训练一个模型实现一个数据集内的图像分类,这种方法有一定的局限性,即模型只能在当前任务(task)上工作,不能应用到其他任务。譬如基于手写字识别数据集训练的分类模型不能用来实现猫和狗的分类。有没有一种方法,可以学会完成分类这一任务,不针对具体是实现哪些对象的分类,学会分类任务后再基于少量的具体数据训练学会是具体给猫狗分类还是给手写字分类。相当于说一个模型实现了原来多个模型的功能。
元学习训练模型是为了获得一个可以快速应用到小样本数据的新任务上的模型,元学习通过初步训练获得模型比较好的初值,再基于初值对具体任务在小样本训练数据上少量更新权重即可取得好的效果。
元学习还可以理解成是寻找一组具有较高敏感度的参数,基于找到的参数,只需要进行少量的迭代即可在新的任务上取得理想的结果。
元学习可应用于训练数据有限的Few-Shot Learning
任务。
# 2.模型无关元学习
# 2.1 元学习问题建模
元学习是在一系列任务上学习,目标是学习得到一个比较敏感的模型,使该模型能够基于小样本数据简单训练快速应用到新任务上。也就是说,元学习将一系列学习任务当作训练样本。
譬如,识别一个动物是不是狗是任务是否是飞机
的训练数据,即可快速学会判断天空中的一个物体是否是飞机。
使用数学公式描述:
单个任务表示为:
是输入 是输出 是损失函数 是初始输入变量的概率分布 是输入变量的状态转移分布 输入变量序列的长度,对于监督学习问题,其值为1
,应用在强化学习等中。 是针对具体任务的损失函数,如回归问题通常是均方误差(Mean Square Error, MSE),分类问题通常是交叉商(Cross Entropy, CE)。
在元学习(meta-learning)中,考虑多个任务
上图中
# 2.2 MAML算法
算法中参数更新分成两步,一次是更新
第一步,针对任务
表示元学习模型
第二步,针对元学习模型的优化为:
# 3.将MAML应用到回归分类任务上的算法流程
方程2和方程3分别是均方误差和交叉熵。
# 4.代码解读
MAML
原作者的代码是基于tensorflow 1.x
版本实现的,结构比较清晰。
模型封装了一个MAML
类 (opens new window),数据的加载在类DataGenerator
(opens new window)中。
main.py
的train函数
(opens new window)中定义了metatrain
的过程:
# metatrain_iterations是元学习模型训练的迭代此数
for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations):
feed_dict = {}
# not for omniglot
if 'generate' in dir(data_generator):
batch_x, batch_y, amp, phase = data_generator.generate()
if FLAGS.baseline == 'oracle':
batch_x = np.concatenate([batch_x, np.zeros([batch_x.shape[0], batch_x.shape[1], 2])], 2)
for i in range(FLAGS.meta_batch_size):
batch_x[i, :, 1] = amp[i]
batch_x[i, :, 2] = phase[i]
"""
# a: training data for inner gradient,
# b: test data for meta gradient
这里 数据被分成两部分`inputa`和`inputb`
`inputa`用来训练针对具体任务的模型,更新其权重
`inputb`用来测试基于`inputa`训练的模型,并计算对具体任务的模型在`intputb`的`losses`
`inputb`上的测试`loss`用来更新元模型,具体实现见`maml.py`中`task_metalearn`函数
"""
inputa = batch_x[:, :num_classes*FLAGS.update_batch_size, :]
labela = batch_y[:, :num_classes*FLAGS.update_batch_size, :]
inputb = batch_x[:, num_classes*FLAGS.update_batch_size:, :] # b used for testing
labelb = batch_y[:, num_classes*FLAGS.update_batch_size:, :]
feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb}
if itr < FLAGS.pretrain_iterations:
# 前n步,预训练时只使用`loassa`更新元学习模型
input_tensors = [model.pretrain_op]
else:
input_tensors = [model.metatrain_op]
...
result = sess.run(input_tensors, feed_dict)
在 MAML
类construct_model
函数中定义有task_metalearn
函数,在这个函数中有使用num_updates
参数,num_updates
参数表示train
函数中的每个元模型训练迭代中针对某个任务的模型迭代次数,针对某个任务的模型每更新一次,在测试数据inputb
上计算1次losses
,更新某个任务的模型num_updates次
后,得到长度为num_updates
的list lossesb
,再用lossesb
来更新元模型。
def task_metalearn(inp, reuse=True):
""" Perform gradient descent for one task in the meta-batch. """
inputa, inputb, labela, labelb = inp
task_outputbs, task_lossesb = [], []
if self.classification:
task_accuraciesb = []
task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter
task_lossa = self.loss_func(task_outputa, labela)
grads = tf.gradients(task_lossa, list(weights.values()))
if FLAGS.stop_grad:
grads = [tf.stop_gradient(grad) for grad in grads]
gradients = dict(zip(weights.keys(), grads))
fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()]))
output = self.forward(inputb, fast_weights, reuse=True)
task_outputbs.append(output)
task_lossesb.append(self.loss_func(output, labelb))
for j in range(num_updates - 1):
loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
grads = tf.gradients(loss, list(fast_weights.values()))
if FLAGS.stop_grad:
grads = [tf.stop_gradient(grad) for grad in grads]
gradients = dict(zip(fast_weights.keys(), grads))
fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()]))
output = self.forward(inputb, fast_weights, reuse=True)
task_outputbs.append(output)
task_lossesb.append(self.loss_func(output, labelb))
task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]
训练结束得到元模型后,要将元模型应用到具体任务时,要先根据提供的样本数据(x,y)
对元模型进行微调test_num_updates
后,再使用微调后的模型在测试数据上输出测试结果,其过程参照task_metalearn
。这也就能解释测试时所用的类在训练时是没有的,为什么测试时模型可以输出测试的类别。正因为模型在测试时有个在少量测试数据上的微调的过程,可以理解成元学习模型先训练得到一个预训练权重,然后再在少量新的其他任务的训练数据上少里训练,然后在新任务的测试数据上验证。
flowchart LR
A(类别为a,b的训练数据) -->B[训练]--> C(元学习模型) -->E
D(类别为c,d的测试数据<少量>) -->E[微调fast_learning]-->F(类别为c,d的测试数据<大量>) -->G[测试]