显式梯度微分
显式梯度
显式梯度的思想是将梯度步骤视为可微函数,并尝试通过展开的优化路径进行反向传播。 即,给定
我们想要计算梯度 \(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)。 这通常是通过AutoDiff通过内部优化的展开迭代来完成的。
可微分函数优化器
通过将参数inplace作为False传递给update函数,我们可以使优化可微分。
这里是一个使torchopt.adam()可微分的示例。
opt = torchopt.adam()
# Define meta and inner parameters
meta_params = ...
fmodel, params = make_functional(model)
# Initialize optimizer state
state = opt.init(params)
for iter in range(iter_times):
loss = inner_loss(fmodel, params, meta_params)
grads = torch.autograd.grad(loss, params)
# Apply non-inplace parameter update
updates, state = opt.update(grads, state, inplace=False)
params = torchopt.apply_updates(params, updates)
loss = outer_loss(fmodel, params, meta_params)
meta_grads = torch.autograd.grad(loss, meta_params)
可微分的面向对象编程元优化器
对于类似PyTorch的API(例如,step()),我们设计了一个基类torchopt.MetaOptimizer,将我们的函数优化器包装成可微分的OOP元优化器。
|
高级可微分优化器的基类。 |
|
可微分的AdaDelta优化器。 |
|
|
|
可微分的AdaGrad优化器。 |
|
|
|
可微分的Adam优化器。 |
|
可微分的AdamW优化器。 |
|
可微分的AdaMax优化器。 |
|
|
|
可微分的RAdam优化器。 |
|
可微分的RMSProp优化器。 |
|
可微分的随机梯度下降优化器。 |
通过将低级API torchopt.MetaOptimizer 与之前的功能优化器结合,我们可以实现高级API:
# Low-level API
optim = torchopt.MetaOptimizer(net, torchopt.sgd(lr=1.0))
# High-level API
optim = torchopt.MetaSGD(net, lr=1.0)
这里是一个使用OOP API torchopt.MetaAdam 进行元梯度计算的示例。
# Define meta and inner parameters
meta_params = ...
model = ...
# Define differentiable optimizer
opt = torchopt.MetaAdam(model)
for iter in range(iter_times):
# Perform the inner update
loss = inner_loss(model, meta_params)
opt.step(loss)
loss = outer_loss(model, meta_params)
loss.backward()
CPU/GPU加速优化器
TorchOpt 通过使用 C++ OpenMP(CPU)和 CUDA(GPU)手动编写前向和反向函数来执行符号缩减,这大大提高了元梯度的计算效率。
用户可以通过将 use_accelerated_op 设置为 True 来使用加速优化器。
TorchOpt 将自动检测设备并分配相应的加速优化器。
# Check whether the `accelerated_op` is available:
torchopt.accelerated_op_available(torch.device('cpu'))
torchopt.accelerated_op_available(torch.device('cuda'))
net = Net(1).cuda()
optim = torchopt.Adam(net.parameters(), lr=1.0, use_accelerated_op=True)
通用工具
我们提供了torchopt.extract_state_dict()和torchopt.recover_state_dict()函数来提取和恢复网络和优化器的状态。
默认情况下,提取的状态字典是一个引用(此设计用于累积多任务批量训练的梯度,例如MAML)。
你也可以设置by='copy'来提取状态字典的副本,或者设置by='deepcopy'来获得一个独立的副本。
|
提取目标状态。 |
|
恢复状态。 |
|
停止输入对象的梯度计算。 |
这是一个使用示例。
net = Net()
x = nn.Parameter(torch.tensor(2.0), requires_grad=True)
optim = torchopt.MetaAdam(net, lr=1.0)
# Get the reference of state dictionary
init_net_state = torchopt.extract_state_dict(net, by='reference')
init_optim_state = torchopt.extract_state_dict(optim, by='reference')
# If set `detach_buffers=True`, the parameters are referenced as references while buffers are detached copies
init_net_state = torchopt.extract_state_dict(net, by='reference', detach_buffers=True)
# Set `copy` to get the copy of the state dictionary
init_net_state_copy = torchopt.extract_state_dict(net, by='copy')
init_optim_state_copy = torchopt.extract_state_dict(optim, by='copy')
# Set `deepcopy` to get the detached copy of state dictionary
init_net_state_deepcopy = torchopt.extract_state_dict(net, by='deepcopy')
init_optim_state_deepcopy = torchopt.extract_state_dict(optim, by='deepcopy')
# Conduct 2 inner-loop optimization
for i in range(2):
inner_loss = net(x)
optim.step(inner_loss)
print(f'a = {net.a!r}')
# Recover and reconduct 2 inner-loop optimization
torchopt.recover_state_dict(net, init_net_state)
torchopt.recover_state_dict(optim, init_optim_state)
for i in range(2):
inner_loss = net(x)
optim.step(inner_loss)
print(f'a = {net.a!r}') # the same result
笔记本教程
查看笔记本教程在 Meta Optimizer 和 Stop Gradient。