显式梯度微分

显式梯度

../_images/explicit-gradient.png

显式梯度的思想是将梯度步骤视为可微函数,并尝试通过展开的优化路径进行反向传播。 即,给定

\[\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \boldsymbol{\theta}_0 - \alpha \sum_{i=0}^{K-1} \nabla_{\boldsymbol{\theta}_i} \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}_i),\]

我们想要计算梯度 \(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)。 这通常是通过AutoDiff通过内部优化的展开迭代来完成的。

可微分函数优化器

通过将参数inplace作为False传递给update函数,我们可以使优化可微分。 这里是一个使torchopt.adam()可微分的示例。

opt = torchopt.adam()
# Define meta and inner parameters
meta_params = ...
fmodel, params = make_functional(model)
# Initialize optimizer state
state = opt.init(params)

for iter in range(iter_times):
    loss = inner_loss(fmodel, params, meta_params)
    grads = torch.autograd.grad(loss, params)
    # Apply non-inplace parameter update
    updates, state = opt.update(grads, state, inplace=False)
    params = torchopt.apply_updates(params, updates)

loss = outer_loss(fmodel, params, meta_params)
meta_grads = torch.autograd.grad(loss, meta_params)

可微分的面向对象编程元优化器

对于类似PyTorch的API(例如,step()),我们设计了一个基类torchopt.MetaOptimizer,将我们的函数优化器包装成可微分的OOP元优化器。

torchopt.MetaOptimizer(module, impl)

高级可微分优化器的基类。

torchopt.MetaAdaDelta(module[, lr, rho, ...])

可微分的AdaDelta优化器。

torchopt.MetaAdadelta

MetaAdaDelta 的别名

torchopt.MetaAdaGrad(module[, lr, lr_decay, ...])

可微分的AdaGrad优化器。

torchopt.MetaAdagrad

MetaAdaGrad 的别名

torchopt.MetaAdam(module[, lr, betas, eps, ...])

可微分的Adam优化器。

torchopt.MetaAdamW(module[, lr, betas, eps, ...])

可微分的AdamW优化器。

torchopt.MetaAdaMax(module[, lr, betas, ...])

可微分的AdaMax优化器。

torchopt.MetaAdamax

MetaAdaMax 的别名

torchopt.MetaRAdam(module[, lr, betas, eps, ...])

可微分的RAdam优化器。

torchopt.MetaRMSProp(module[, lr, alpha, ...])

可微分的RMSProp优化器。

torchopt.MetaSGD(module, lr[, momentum, ...])

可微分的随机梯度下降优化器。

通过将低级API torchopt.MetaOptimizer 与之前的功能优化器结合,我们可以实现高级API:

# Low-level API
optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))

# High-level API
optim = torchopt.MetaSGD(net, lr=1.0)

这里是一个使用OOP API torchopt.MetaAdam 进行元梯度计算的示例。

# Define meta and inner parameters
meta_params = ...
model = ...
# Define differentiable optimizer
opt = torchopt.MetaAdam(model)

for iter in range(iter_times):
    # Perform the inner update
    loss = inner_loss(model, meta_params)
    opt.step(loss)

loss = outer_loss(model, meta_params)
loss.backward()

CPU/GPU加速优化器

TorchOpt 通过使用 C++ OpenMP(CPU)和 CUDA(GPU)手动编写前向和反向函数来执行符号缩减,这大大提高了元梯度的计算效率。 用户可以通过将 use_accelerated_op 设置为 True 来使用加速优化器。 TorchOpt 将自动检测设备并分配相应的加速优化器。

# Check whether the `accelerated_op` is available:
torchopt.accelerated_op_available(torch.device('cpu'))

torchopt.accelerated_op_available(torch.device('cuda'))

net = Net(1).cuda()
optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)

通用工具

我们提供了torchopt.extract_state_dict()torchopt.recover_state_dict()函数来提取和恢复网络和优化器的状态。 默认情况下,提取的状态字典是一个引用(此设计用于累积多任务批量训练的梯度,例如MAML)。 你也可以设置by='copy'来提取状态字典的副本,或者设置by='deepcopy'来获得一个独立的副本。

torchopt.extract_state_dict(target, *[, by, ...])

提取目标状态。

torchopt.recover_state_dict(target, state)

恢复状态。

torchopt.stop_gradient(target)

停止输入对象的梯度计算。

这是一个使用示例。

net = Net()
x = nn.Parameter(torch.tensor(2.0), requires_grad=True)

optim = torchopt.MetaAdam(net, lr=1.0)

# Get the reference of state dictionary
init_net_state = torchopt.extract_state_dict(net, by='reference')
init_optim_state = torchopt.extract_state_dict(optim, by='reference')
# If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies
init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)

# Set `copy` to get the copy of the state dictionary
init_net_state_copy = torchopt.extract_state_dict(net, by='copy')
init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')

# Set `deepcopy` to get the detached copy of state dictionary
init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')
init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')

# Conduct 2 inner-loop optimization
for i in range(2):
    inner_loss = net(x)
    optim.step(inner_loss)

print(f'a = {net.a!r}')

# Recover and reconduct 2 inner-loop optimization
torchopt.recover_state_dict(net, init_net_state)
torchopt.recover_state_dict(optim, init_optim_state)

for i in range(2):
    inner_loss = net(x)
    optim.step(inner_loss)

print(f'a = {net.a!r}')  # the same result

笔记本教程

查看笔记本教程在 Meta OptimizerStop Gradient