Model-Agnostic Meta-Learning

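Meta-reinforcement learning has achieved significant success in a variety of applications, and Model-Agnostic Meta-Learning (MAML) [FAL17] is a pioneering method in this area. In this tutorial, we show step by step how to train MAML on few-shot Omniglot classification with TorchOpt. The full script is at examples/few-shot/maml_omniglot.py.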

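Existing differentiable optimizer libraries such as higher follow PyTorch's design, which leads to an inflexible API. In contrast, TorchOpt provides a simple way to build the training procedure at the code level.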

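Overview

The MAML training pipeline consists of six steps:

  1. Load Dataset: load the Omniglot dataset;

  2. Build the Network: build the neural network architecture of the model;

  3. Train: meta-train;

  4. Test: meta-test;

  5. Plot: plot the results;

  6. Pipeline: combine steps 3-5.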

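In the following sections, we set up dataset loading, network construction, training, testing, and plotting so that the MAML training and evaluation pipeline runs end to end.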
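To make the meta-train step concrete before the real code, here is a minimal sketch of the MAML inner and outer loops on a toy problem, written in plain Python. It assumes each task is a scalar quadratic loss 0.5 * (theta - target) ** 2 and uses the same target for the support and query sets, which a real few-shot task would not. The names inner_adapt, meta_gradient, and task_targets are hypothetical, and the meta-gradient is computed by hand with the chain rule instead of by TorchOpt.

```python
def inner_adapt(theta, target, lr, n_steps):
    """Run n_steps of gradient descent on 0.5 * (theta - target) ** 2.

    Returns the adapted parameter and its derivative with respect to the
    initial theta, which the outer loop needs for the chain rule.
    """
    adapted, jac = theta, 1.0
    for _ in range(n_steps):
        grad = adapted - target
        adapted = adapted - lr * grad
        jac *= 1.0 - lr  # each step is affine in theta with slope (1 - lr)
    return adapted, jac


def meta_gradient(theta, targets, inner_lr=0.1, n_inner_iter=5):
    """Average over tasks of d(query loss after adaptation) / d(theta)."""
    grads = []
    for target in targets:
        adapted, jac = inner_adapt(theta, target, inner_lr, n_inner_iter)
        # Chain rule: dL/d(adapted) * d(adapted)/d(theta).
        grads.append((adapted - target) * jac)
    return sum(grads) / len(grads)


task_targets = [1.0, -2.0, 4.0]  # made-up tasks for illustration
theta = 0.0  # meta-parameter (the shared initialization)
for _ in range(200):
    theta -= 1.0 * meta_gradient(theta, task_targets)  # outer SGD step
```

For this toy loss the meta-objective is minimized at the mean of the task targets, so theta should approach 1.0. The real pipeline replaces the hand-derived gradient with differentiation through the inner optimizer.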

Load Dataset

In your Python code, simply import torch and load the dataset. The full script is at examples/few-shot/support/omniglot_loaders.py.

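import torch

from .support.omniglot_loaders import OmniglotNShot

# `args` and `rng` are defined earlier in the full script.
# Use the first GPU if available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')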
db = OmniglotNShot(
    '/tmp/omniglot-data',
    batchsz=args.task_num,
    n_way=args.n_way,
    k_shot=args.k_spt,
    k_query=args.k_qry,
    imgsz=28,
    rng=rng,
    device=device,
)

The goal is to train a model for few-shot Omniglot classification.
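The n_way, k_shot, and k_query arguments describe how each few-shot task is assembled. The following is a minimal sketch of sampling one such task from a dictionary of examples grouped by class, assuming nothing about how OmniglotNShot is actually implemented. The function sample_episode and the toy_data dictionary are hypothetical, and the strings stand in for images.

```python
import random


def sample_episode(data_by_class, n_way, k_shot, k_query, rng):
    """Sample one N-way task with disjoint support and query examples."""
    # Pick the classes for this task; sort first for reproducibility.
    classes = rng.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw without replacement so support and query never overlap.
        examples = rng.sample(data_by_class[cls], k_shot + k_query)
        # Relabel the class as 0..n_way-1 within this task.
        support += [(x, label) for x in examples[:k_shot]]
        query += [(x, label) for x in examples[k_shot:]]
    return support, query


rng = random.Random(0)
# Toy stand-in for a dataset: 10 classes with 20 examples each.
toy_data = {
    f'char_{c}': [f'char_{c}/img_{i}' for i in range(20)] for c in range(10)
}
support, query = sample_episode(toy_data, n_way=5, k_shot=1, k_query=15, rng=rng)
```

With these settings the support set holds 5 examples (one per class) and the query set holds 75. Labels are local to the task, which is why the classifier head only needs n_way outputs.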

Build the Network

TorchOpt supports any user-defined PyTorch network. Here is an example:

import numpy as np
import torch
import torch.optim as optim
from torch import nn

net = nn.Sequential(
    nn.Conv2d(1, 64, 3),
    nn.BatchNorm2d(64, momentum=1.0, affine=True),
    nn.ReLU(inplace=False),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1.0, affine=True),
    nn.ReLU(inplace=False),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 64, 3),
    nn.BatchNorm2d(64, momentum=1.0, affine=True),
    nn.ReLU(inplace=False),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64, args.n_way),
).to(device)

# We will use Adam to (meta-)optimize the initial parameters
# to be adapted.
meta_opt = optim.Adam(net.parameters(), lr=1e-3)
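The final nn.Linear(64, args.n_way) layer expects 64 input features, which follows from the 28x28 input size and the three convolution and pooling stages. The arithmetic can be checked with a short sketch in plain Python. The helpers conv_out and pool_out are hypothetical names, and they assume the default settings used above: no padding, stride 1 for the convolutions, and floor rounding in the pooling layers.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Spatial output size of max pooling (floor mode) along one dimension."""
    return (size - kernel) // stride + 1


size = 28  # input images are 28x28 (imgsz=28)
trace = [size]
for _ in range(3):
    size = conv_out(size, kernel=3)  # nn.Conv2d(..., 3)
    size = pool_out(size, kernel=2, stride=2)  # nn.MaxPool2d(2, 2)
    trace.append(size)

# 64 channels times the remaining spatial area after nn.Flatten().
flat_features = 64 * size * size
```

The spatial size shrinks 28, 13, 5, 1 across the three stages, so flattening yields 64 * 1 * 1 = 64 features. A different imgsz would require a different input size for the linear layer.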

Train

Define the train function:

import time

import numpy as np
import torch
import torch.nn.functional as F
import torchopt


def train(db, net, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz
    inner_opt = torchopt.MetaSGD(net, lr=1e-1)

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num = x_spt.size(0)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Number of inner-loop adaptation steps taken on each task's
        # support set.
        n_inner_iter = 5

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()

        net_state_dict = torchopt.extract_state_dict(net)
        optim_state_dict = torchopt.extract_state_dict(inner_opt)
        for i in range(task_num):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            for _ in range(n_inner_iter):
                spt_logits = net(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                inner_opt.step(spt_loss)

            # The final set of adapted parameters will induce some
            # final loss and accuracy on the query dataset.
            # These will be used to update the model's meta-parameters.
            qry_logits = net(x_qry[i])
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
            qry_losses.append(qry_loss)
            qry_accs.append(qry_acc.item())

            torchopt.recover_state_dict(net, net_state_dict)
            torchopt.recover_state_dict(inner_opt, optim_state_dict)

        qry_losses = torch.mean(torch.stack(qry_losses))
        qry_losses.backward()
        meta_opt.step()
        qry_losses = qry_losses.item()
        qry_accs = 100.0 * np.mean(qry_accs)
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time

        print(
            f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
        )
        log.append(
            {
                'epoch': i,
                'loss': qry_losses,
                'acc': qry_accs,
                'mode': 'train',
                'time': time.time(),
            }
        )
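The train function snapshots the network and inner-optimizer state before the task loop and restores it after each task, so every task adapts from the same meta-parameters. The sketch below illustrates only that pattern, on a toy scalar parameter with copy.deepcopy. The names adapt_in_place, params, and task_targets are hypothetical, and this is an analogy for the idea, not a description of how torchopt.extract_state_dict and torchopt.recover_state_dict work internally.

```python
import copy


def adapt_in_place(params, target, lr=0.1, n_inner_iter=5):
    """Mutate params with a few gradient steps on 0.5 * (w - target) ** 2."""
    for _ in range(n_inner_iter):
        grad = params['w'] - target
        params['w'] -= lr * grad
    # Loss of the adapted parameter on the same toy task.
    return 0.5 * (params['w'] - target) ** 2


params = {'w': 0.0}  # toy stand-in for the meta-parameters
task_targets = [1.0, -2.0, 3.0]  # made-up tasks for illustration

snapshot = copy.deepcopy(params)  # taken once, before the task loop
task_losses = []
for target in task_targets:
    task_losses.append(adapt_in_place(params, target))
    # Restore so the next task starts from the unadapted parameters.
    params = copy.deepcopy(snapshot)
```

Without the restore step, the second task would start from parameters already adapted to the first task, and the per-task losses would no longer measure adaptation from a shared initialization.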

Test

Define the test function:

def test(db, net, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz
    inner_opt = torchopt.MetaSGD(net, lr=1e-1)

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num = x_spt.size(0)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5

        net_state_dict = torchopt.extract_state_dict(net)
        optim_state_dict = torchopt.extract_state_dict(inner_opt)
        for i in range(task_num):
            # Optimize the likelihood of the support set by taking
            # gradient steps w.r.t. the model's parameters.
            # This adapts the model's meta-parameters to the task.
            for _ in range(n_inner_iter):
                spt_logits = net(x_spt[i])
                spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                inner_opt.step(spt_loss)

            # The query loss and acc induced by these parameters.
            qry_logits = net(x_qry[i]).detach()
            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
            qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).float().mean()
            qry_losses.append(qry_loss.item())
            qry_accs.append(qry_acc.item())

            torchopt.recover_state_dict(net, net_state_dict)
            torchopt.recover_state_dict(inner_opt, optim_state_dict)

    qry_losses = np.mean(qry_losses)
    qry_accs = 100.0 * np.mean(qry_accs)

    print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
    log.append(
        {
            'epoch': epoch + 1,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'test',
            'time': time.time(),
        }
    )

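Plot

TorchOpt supports any user-defined PyTorch networks and optimizers, provided that their inputs and outputs comply with TorchOpt's API. The following function plots the logged training and test accuracy: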

import matplotlib.pyplot as plt
import pandas as pd


def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)

Pipeline

We can now put all the components together and plot the results.

log = []
for epoch in range(10):
    train(db, net, meta_opt, epoch, log)
    test(db, net, epoch, log)
    plot(log)
[Figure: train and test accuracy by epoch (maml-accs.png)]
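Besides plotting, the log list built by train and test can be summarized directly, since each entry is a dictionary with 'epoch', 'loss', 'acc', 'mode', and 'time' keys. The following is a minimal sketch in plain Python. The helper names best_test_entry and mean_train_acc_by_epoch are hypothetical, and the numbers in toy_log are made up for illustration, not results from an actual run.

```python
def best_test_entry(log):
    """Return the test entry with the highest accuracy, or None."""
    test_entries = [entry for entry in log if entry['mode'] == 'test']
    if not test_entries:
        return None
    return max(test_entries, key=lambda entry: entry['acc'])


def mean_train_acc_by_epoch(log):
    """Average training accuracy, grouped by the integer part of 'epoch'."""
    grouped = {}
    for entry in log:
        if entry['mode'] == 'train':
            # Train entries carry fractional epochs such as 0.5.
            grouped.setdefault(int(entry['epoch']), []).append(entry['acc'])
    return {epoch: sum(accs) / len(accs) for epoch, accs in grouped.items()}


# Made-up log entries with the same keys the tutorial functions append.
toy_log = [
    {'epoch': 0.0, 'loss': 1.2, 'acc': 60.0, 'mode': 'train', 'time': 0.0},
    {'epoch': 0.5, 'loss': 0.9, 'acc': 70.0, 'mode': 'train', 'time': 1.0},
    {'epoch': 1, 'loss': 0.7, 'acc': 80.0, 'mode': 'test', 'time': 2.0},
    {'epoch': 1.0, 'loss': 0.5, 'acc': 85.0, 'mode': 'train', 'time': 3.0},
    {'epoch': 1.5, 'loss': 0.4, 'acc': 87.0, 'mode': 'train', 'time': 4.0},
    {'epoch': 2, 'loss': 0.3, 'acc': 90.0, 'mode': 'test', 'time': 5.0},
]

best = best_test_entry(toy_log)
train_means = mean_train_acc_by_epoch(toy_log)
```

Grouping by the integer part of the epoch works because train logs the epoch as epoch + batch_idx / n_train_iter, while test logs epoch + 1.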

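References

[FAL17]

Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, ICML 2017, Sydney, NSW, Australia, 6-11 August 2017, volume 70 of Proceedings of Machine Learning Research, 1126–1135. PMLR, 2017. URL: http://proceedings.mlr.press/v70/finn17a.html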