零阶梯度微分
进化策略
当内循环过程不可微分或希望消除前两种模式中由Hessian带来的沉重计算负担时,可以选择零阶微分。 零阶微分通常基于零阶估计(如有限差分)或进化策略(ES)来获取梯度。 ES-MAML和NAC基于ES成功解决了不可微分的优化问题。
TorchOpt 提供了基于ES的微分API。 与优化目标 \(f (\boldsymbol{\theta}): \mathbb{R}^n \to \mathbb{R}\) 不同,ES 优化了一个高斯平滑目标,定义为 \(\tilde{f}_{\sigma} (\boldsymbol{\theta}) = \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ]\),其中 \(\sigma\) 表示精度。 该目标的梯度为 \(\nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ]\)。 基于这种技术,可以将双层过程视为一个整体,基于纯前向过程计算元梯度。 更多解释请参考 ES-MAML。
装饰器
|
返回一个用于应用零阶微分的装饰器。 |
类似于隐式梯度,我们也为ES方法使用装饰器。
功能API
基本的功能API是torchopt.diff.zero_order.zero_order(),它被用作前向过程零阶梯度程序的装饰器。
用户需要实现噪声采样函数,该函数将作为zero_order装饰器的输入。
这里我们展示了装饰器中使用的每个参数的具体含义。
distribution用于噪声采样分布。分布 \(\lambda\) 应该是球对称的,并且每个元素的方差恒定为 \(1\)。即:球对称:\(\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ \boldsymbol{z} ] = \boldsymbol{0}\).
每个元素的恒定方差为\(1\):\(\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ {\lvert z_i \rvert}^2 ] = 1\)。
例如,标准的多维正态分布 \(\mathcal{N} (\boldsymbol{0}, \boldsymbol{1})\)。
method用于不同类型的算法,我们支持'naive'(ES RL),'forward'(Forward-FD), 和'antithetic'(antithetic).\[\begin{align*} \text{naive} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ] \\ \text{forward} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ ( f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta}) ) \cdot \boldsymbol{z} ] \\ \text{antithetic} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{2 \sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ (f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ) \cdot \boldsymbol{z} ] \end{align*}\]argnums指定我们想要追踪元梯度的参数。num_samples指定了我们想要进行采样的次数。sigma用于精度。这是采样分布的缩放因子。
我们在以下部分展示伪代码。
# Functional API for zero-order differentiation
# 1. Customize the noise distribution via a distribution class
class Distribution:
def sample(self, sample_shape=torch.Size()):
# Sampling function for noise
# NOTE: The distribution should be spherical symmetric and with a constant variance of 1.
...
return noise_batch
distribution = Distribution()
# 2. Customize the noise distribution via a sampling function
def distribution(sample_shape=torch.Size()):
# Sampling function for noise
# NOTE: The distribution should be spherical symmetric and with a constant variance of 1.
...
return noise_batch
# 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`
distribution = torch.distributions.Normal(loc=0, scale=1)
# Decorator that wraps the function
@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)
def forward(params, data):
# Forward optimization process for params
...
return objective # the returned tensor should be a scalar tensor
# Define params and get data
params, data = ..., ...
# Forward pass
loss = forward(params, data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, params)
面向对象编程API
|
零阶梯度模型的基类。 |
结合PyTorch的torch.nn.Module,我们还为ES设计了OOP API nn.ZeroOrderGradientModule。
nn.ZeroOrderGradientModule的核心思想是使梯度流向前传递到self.parameters()(在计算元梯度时可以是元参数)。
用户需要定义前向过程的零阶梯度程序forward()和一个噪声采样函数sample()。
from torchopt.nn import ZeroOrderGradientModule
# Inherited from the class ZeroOrderGradientModule
# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling
class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):
def __init__(self, ...):
...
def forward(self, batch):
# Forward process
...
return objective # the returned tensor should be a scalar tensor
def sample(self, sample_shape=torch.Size()):
# Generate a batch of noise samples
# NOTE: The distribution should be spherical symmetric and with a constant variance of 1.
...
return noise_batch
# Get model and data
net = Net(...)
data = ...
# Forward pass
loss = Net(data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, net.parameters())
笔记本教程
更多详情,请查看zero-order的笔记本教程。