隐式梯度微分

隐式微分

../_images/implicit-gradient.png

隐式微分是通过满足映射函数\(T\)的优化问题的解进行微分的任务,该映射函数捕捉了问题的最优条件。最简单的例子是通过最小化问题的解对其输入进行微分。即,给定

\[\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) \triangleq \underset{\boldsymbol{\theta}}{\mathop{\operatorname{argmin}}} ~ \mathcal{L}^{\text{in}} (\boldsymbol{\phi},\boldsymbol{\theta}).\]

通过将解\(\boldsymbol{\theta}^{\prime}\)视为\(\boldsymbol{\phi}\)的隐函数,隐函数微分的想法是通过隐函数定理直接获得解析的最佳响应导数\(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)

求根

这适用于当内部级别的最优性条件 \(T\) 由函数的根定义时的算法,例如:

\[T (\boldsymbol{\phi}, \boldsymbol{\theta}) = \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}}, \qquad T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \left. \frac{ \partial \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta})}{\partial \boldsymbol{\theta}} \right\rvert_{\boldsymbol{\theta} = \boldsymbol{\theta}^{\prime}} = \boldsymbol{0}.\]

IMAML中,图中的函数\(F\)表示通过展开梯度更新获得内层最优解:

\[\boldsymbol{\theta}_{k + 1} = F (\boldsymbol{\phi}, \boldsymbol{\theta}_k) = \boldsymbol{\theta}_k - \alpha \nabla_{\boldsymbol{\theta}_k} \mathcal{L}^{\text{in}} (\boldsymbol{\phi}, \boldsymbol{\theta}_k).\]

定点迭代

有时,内部层级的最优解也可以通过固定点实现,其中最优性\(T\)的形式为:

\[\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) \quad \Longleftrightarrow \quad T (\boldsymbol{\phi}, \boldsymbol{\theta}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}) - \boldsymbol{\theta}, \quad T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})) = \boldsymbol{0}.\]

DEQ中,图中的函数\(F\)表示通过固定点更新获得内层最优解:

\[\boldsymbol{\theta}_{k + 1} = F (\boldsymbol{\phi}, \boldsymbol{\theta}_k).\]

这可以被视为通过定义最优性函数为\(T (\boldsymbol{\phi}, \boldsymbol{\theta}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}) - \boldsymbol{\theta}\)的函数根的特定情况。 这可以通过以下方式实现:

def fixed_point_function(phi: TensorTree, theta: TensorTree) -> TensorTree:
    ...
    return new_theta

# A root function can be derived from the fixed point function
def root_function(phi: TensorTree, theta: TensorTree) -> TensorTree:
    new_theta = fixed_point_function(phi, theta)
    return torchopt.pytree.tree_sub(new_theta, theta)

自定义求解器

torchopt.diff.implicit.custom_root(...[, ...])

返回一个装饰器,用于为根求解器添加隐式微分。

\(T (\boldsymbol{\phi}, \boldsymbol{\theta}): \mathbb{R}^n \times \mathbb{R}^d \to \mathbb{R}^d\) 成为一个用户提供的映射函数,它捕捉了一个问题的最优条件。 一个最优解,表示为 \(\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\),应该是 \(T\) 的根:

\[T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})) = \boldsymbol{0}.\]

我们可以将\(\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)视为\(\boldsymbol{\phi} \in \mathbb{R}^n\)的隐式定义函数,即\(\boldsymbol{\theta}^{\prime}: \mathbb{R}^n \rightarrow \mathbb{R}^d\)。 更准确地说,根据隐函数定理,我们知道对于满足\(T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}\)\((\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)\),如果\(T\)是连续可微的,并且在\((\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)\)处评估的雅可比矩阵\(\nabla_{\boldsymbol{\theta}^{\prime}} T\)是一个可逆方阵,那么在\(\boldsymbol{\phi}_0\)的邻域内存在一个函数\(\boldsymbol{\theta}^{\prime} (\cdot)\),使得\(\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}_0) = \boldsymbol{\theta}^{\prime}_0\)。 此外,对于该邻域内的所有\(\boldsymbol{\phi}\),我们有\(T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}\)并且\(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)存在。使用链式法则,雅可比矩阵\(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})\)满足:

\[\frac{d T}{d \boldsymbol{\phi}} = \underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\theta}^{\prime}}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{\frac{d \boldsymbol{\theta}^{\prime}}{d \boldsymbol{\phi}}} + \underbrace{\nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{\frac{\partial T}{\partial \boldsymbol{\phi}}} = \boldsymbol{0}. \qquad ( T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}) = \boldsymbol{0} = \text{const})\]

计算 \(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})\) 因此归结为线性方程组的求解

\[\underbrace{\nabla_{\boldsymbol{\theta}^{\prime}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi}))}_{A \in \mathbb{R}^{d \times d}} \underbrace{\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})}_{J \in \mathbb{R}^{d \times n}} = \underbrace{- \nabla_{\boldsymbol{\phi}} T (\boldsymbol{\phi}, \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}))}_{B \in \mathbb{R}^{d \times n}}.\]

TorchOpt 提供了一个装饰器函数 custom_root(),用于轻松地在任何现有的内部优化求解器(也称为前向优化)之上添加隐式微分。custom_root() 装饰器要求用户定义问题解的稳态条件(例如,KKT 条件),并会自动计算梯度以进行反向梯度计算。

这里是一个custom_root()装饰器的示例,它也是隐式梯度的功能API

# Functional API for implicit gradient
def stationary(params, meta_params, data):
    # stationary condition construction
    return stationary condition

# Decorator that wraps the function
# Optionally specify the linear solver (conjugate gradient or Neumann series)
@torchopt.diff.implicit.custom_root(stationary)
def solve(params, meta_params, data):
    # Forward optimization process for params
    return optimal_params

# Define params, meta_params and get data
params, meta_prams, data = ..., ..., ...
optimal_params = solve(params, meta_params, data)
loss = outer_loss(optimal_params)

meta_grads = torch.autograd.grad(loss, meta_params)

面向对象编程API

torchopt.nn.ImplicitMetaGradientModule(...)

可微分隐式元梯度模型的基类。

结合PyTorch的torch.nn.Module,我们还设计了面向对象的API nn.ImplicitMetaGradientModule 用于隐式梯度。 nn.ImplicitMetaGradientModule 的核心思想是使梯度从 self.parameters()(通常是低层参数)流向 self.meta_parameters()(通常是高层参数)。 用户需要定义前向过程 forward()、一个稳态函数 optimality()(或 objective()),以及内循环优化 solve

这是OOP API的一个示例。

from torchopt.nn import ImplicitMetaGradientModule

# Inherited from the class ImplicitMetaGradientModule
class InnerNet(ImplicitMetaGradientModule):
    def __init__(self, meta_module):
        ...

    def forward(self, batch):
        # Forward process
        ...

    def optimality(self, batch, labels):
        # Stationary condition construction for calculating implicit gradient
        # NOTE: If this method is not implemented, it will be automatically derived from the
        # gradient of the `objective` function.
        ...

    def objective(self, batch, labels):
        # Define the inner-loop optimization objective
        # NOTE: This method is optional if method `optimality` is implemented.
        ...

    def solve(self, batch, labels):
        # Conduct the inner-loop optimization
        ...
        return self  # optimized module

# Get meta_params and data
meta_params, data = ..., ...
inner_net = InnerNet()

# Solve for inner-loop process related to the meta-parameters
optimal_inner_net = inner_net.solve(meta_params, *data)

# Get outer-loss and solve for meta-gradient
loss = outer_loss(optimal_inner_net)
meta_grad = torch.autograd.grad(loss, meta_params)

如果优化目标是最小化/最大化一个目标函数,我们提供了一个objective方法接口来简化实现。 用户只需要定义objective方法,而TorchOpt会自动从KKT条件中分析其平稳(最优)条件。

注意

__init__方法中,用户需要定义内部参数和元参数。 默认情况下,nn.ImplicitMetaGradientModule将方法输入中的所有张量和模块视为self.meta_parameters() / self.meta_modules()。 例如,语句self.yyy = xxxxxx分配为名为'yyy'的元参数,如果xxx存在于方法输入中(例如,def __init__(self, xxx, ...): ...)。 在__init__中定义的所有张量和模块被视为self.parameters() / self.modules()。 用户还可以通过分别调用self.register_parameter()self.register_meta_parameter()来注册参数和元参数。

线性系统求解器

torchopt.linear_solve.solve_cg(**kwargs)

返回一个求解器函数,用于使用共轭梯度法求解 A x = b

torchopt.linear_solve.solve_inv(**kwargs)

返回一个求解器函数,用于使用矩阵逆解A x = b

torchopt.linear_solve.solve_normal_cg(**kwargs)

返回一个求解器函数,用于使用共轭梯度法求解 A^T A x = A^T b

通常,隐式梯度的计算涉及逆Hessian矩阵的计算。 然而,高维Hessian矩阵也使得直接计算变得不可行,这就是线性求解器发挥作用的地方。 通过迭代求解线性系统问题,我们可以计算逆Hessian矩阵到一定的精度。我们提供了基于共轭梯度的求解器和基于纽曼级数的求解器。

这是一个线性求解器的示例。

import torch
from torchopt import linear_solve

torch.manual_seed(42)
A = torch.randn(3, 3)
b = torch.randn(3)

def matvec(x):
    return  torch.matmul(A, x)

solve_fn = linear_solve.solve_normal_cg(atol=1e-5)
solution = solve_fn(matvec, b)
print(solution)

solve_fn = linear_solve.solve_cg(atol=1e-5)
solution = solve_fn(matvec, b)
print(solution)

用户也可以在功能和面向对象的API中选择相应的求解器。

# For functional API
@torchopt.diff.implicit.custom_root(
    functorch.grad(objective_fn, argnums=0),  # optimality function
    argnums=1,
    solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
)
def solve_fn(...):
    ...

# For OOP API
class InnerNet(
    torchopt.nn.ImplicitMetaGradientModule,
    linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
):
    ...

笔记本教程

查看笔记本教程 Implicit Differentiation