零阶梯度微分

进化策略

../_images/zero-order.png

当内循环过程不可微分或希望消除前两种模式中由Hessian带来的沉重计算负担时,可以选择零阶微分。 零阶微分通常基于零阶估计(如有限差分)或进化策略(ES)来获取梯度。 ES-MAMLNAC基于ES成功解决了不可微分的优化问题。

TorchOpt 提供了基于ES的微分API。 与优化目标 \(f (\boldsymbol{\theta}): \mathbb{R}^n \to \mathbb{R}\) 不同,ES 优化了一个高斯平滑目标,定义为 \(\tilde{f}_{\sigma} (\boldsymbol{\theta}) = \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ]\),其中 \(\sigma\) 表示精度。 该目标的梯度为 \(\nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \mathcal{N}( 0, {I}_d )} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ]\)。 基于这种技术,可以将双层过程视为一个整体,基于纯前向过程计算元梯度。 更多解释请参考 ES-MAML

装饰器

torchopt.diff.zero_order.zero_order(distribution)

返回一个用于应用零阶微分的装饰器。

类似于隐式梯度,我们也为ES方法使用装饰器。

功能API

基本的功能API是torchopt.diff.zero_order.zero_order(),它被用作前向过程零阶梯度程序的装饰器。 用户需要实现噪声采样函数,该函数将作为zero_order装饰器的输入。 这里我们展示了装饰器中使用的每个参数的具体含义。

  • distribution 用于噪声采样分布。分布 \(\lambda\) 应该是球对称的,并且每个元素的方差恒定为 \(1\)。即:

    • 球对称:\(\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ \boldsymbol{z} ] = \boldsymbol{0}\).

    • 每个元素的恒定方差为\(1\)\(\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ {\lvert z_i \rvert}^2 ] = 1\)

    • 例如,标准的多维正态分布 \(\mathcal{N} (\boldsymbol{0}, \boldsymbol{1})\)

  • method 用于不同类型的算法,我们支持 'naive' (ES RL), 'forward' (Forward-FD), 和 'antithetic' (antithetic).

    \[\begin{align*} \text{naive} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) \cdot \boldsymbol{z} ] \\ \text{forward} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{\sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ ( f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta}) ) \cdot \boldsymbol{z} ] \\ \text{antithetic} \qquad & \nabla_{\boldsymbol{\theta}} \tilde{f}_{\sigma} (\boldsymbol{\theta}) = \frac{1}{2 \sigma} \mathbb{E}_{\boldsymbol{z} \sim \lambda} [ (f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) - f (\boldsymbol{\theta} + \sigma \, \boldsymbol{z}) ) \cdot \boldsymbol{z} ] \end{align*}\]
  • argnums 指定我们想要追踪元梯度的参数。

  • num_samples 指定了我们想要进行采样的次数。

  • sigma 用于精度。这是采样分布的缩放因子。

我们在以下部分展示伪代码。

# Functional API for zero-order differentiation
# 1. Customize the noise distribution via a distribution class
class Distribution:
    def sample(self, sample_shape=torch.Size()):
        # Sampling function for noise
        # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.
        ...
        return noise_batch

distribution = Distribution()

# 2. Customize the noise distribution via a sampling function
def distribution(sample_shape=torch.Size()):
    # Sampling function for noise
    # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.
    ...
    return noise_batch

# 3. Distribution can also be an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`
distribution = torch.distributions.Normal(loc=0, scale=1)

# Decorator that wraps the function
@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)
def forward(params, data):
    # Forward optimization process for params
    ...
    return objective  # the returned tensor should be a scalar tensor

# Define params and get data
params, data = ..., ...

# Forward pass
loss = forward(params, data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, params)

面向对象编程API

torchopt.nn.ZeroOrderGradientModule(*args, ...)

零阶梯度模型的基类。

结合PyTorch的torch.nn.Module,我们还为ES设计了OOP API nn.ZeroOrderGradientModulenn.ZeroOrderGradientModule的核心思想是使梯度流向前传递到self.parameters()(在计算元梯度时可以是元参数)。 用户需要定义前向过程的零阶梯度程序forward()和一个噪声采样函数sample()

from torchopt.nn import ZeroOrderGradientModule

# Inherited from the class ZeroOrderGradientModule
# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling
class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):
    def __init__(self, ...):
        ...

    def forward(self, batch):
        # Forward process
        ...
        return objective  # the returned tensor should be a scalar tensor

    def sample(self, sample_shape=torch.Size()):
        # Generate a batch of noise samples
        # NOTE: The distribution should be spherical symmetric and with a constant variance of 1.
        ...
        return noise_batch

# Get model and data
net = Net(...)
data = ...

# Forward pass
loss = Net(data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, net.parameters())

笔记本教程

更多详情,请查看zero-order的笔记本教程。