隐式梯度微分
隐式微分
隐式微分是通过满足映射函数\(T\)的优化问题的解进行微分的任务,该映射函数捕捉了问题的最优条件。最简单的例子是通过最小化问题的解对其输入进行微分。即,给定
通过将解\(\boldsymbol{\theta}^{\prime}\)视为\(\boldsymbol{\phi}\)的隐函数,隐函数微分的想法是通过隐函数定理直接获得解析的最佳响应导数\(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)。
求根
这适用于当内部级别的最优性条件 \(T\) 由函数的根定义时的算法,例如:
在IMAML中,图中的函数\(F\)表示通过展开梯度更新获得内层最优解:
定点迭代
有时,内部层级的最优解也可以通过固定点实现,其中最优性\(T\)的形式为:
在DEQ中,图中的函数\(F\)表示通过固定点更新获得内层最优解:
这可以被视为通过定义最优性函数为\(T (\boldsymbol{\phi}, \boldsymbol{\theta}) = F (\boldsymbol{\phi}, \boldsymbol{\theta}) - \boldsymbol{\theta}\)的函数根的特定情况。 这可以通过以下方式实现:
def fixed_point_function(phi: TensorTree, theta: TensorTree) -> TensorTree:
...
return new_theta
# A root function can be derived from the fixed point function
def root_function(phi: TensorTree, theta: TensorTree) -> TensorTree:
new_theta = fixed_point_function(phi, theta)
return torchopt.pytree.tree_sub(new_theta, theta)
自定义求解器
|
返回一个装饰器,用于为根求解器添加隐式微分。 |
让 \(T (\boldsymbol{\phi}, \boldsymbol{\theta}): \mathbb{R}^n \times \mathbb{R}^d \to \mathbb{R}^d\) 成为一个用户提供的映射函数,它捕捉了一个问题的最优条件。 一个最优解,表示为 \(\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\),应该是 \(T\) 的根:
我们可以将\(\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)视为\(\boldsymbol{\phi} \in \mathbb{R}^n\)的隐式定义函数,即\(\boldsymbol{\theta}^{\prime}: \mathbb{R}^n \rightarrow \mathbb{R}^d\)。 更准确地说,根据隐函数定理,我们知道对于满足\(T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}\)的\((\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)\),如果\(T\)是连续可微的,并且在\((\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0)\)处评估的雅可比矩阵\(\nabla_{\boldsymbol{\theta}^{\prime}} T\)是一个可逆方阵,那么在\(\boldsymbol{\phi}_0\)的邻域内存在一个函数\(\boldsymbol{\theta}^{\prime} (\cdot)\),使得\(\boldsymbol{\theta}^{\prime} (\boldsymbol{\phi}_0) = \boldsymbol{\theta}^{\prime}_0\)。 此外,对于该邻域内的所有\(\boldsymbol{\phi}\),我们有\(T (\boldsymbol{\phi}_0, \boldsymbol{\theta}^{\prime}_0) = \boldsymbol{0}\)并且\(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime} (\boldsymbol{\phi})\)存在。使用链式法则,雅可比矩阵\(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})\)满足:
计算 \(\nabla_{\boldsymbol{\phi}} \boldsymbol{\theta}^{\prime}(\boldsymbol{\phi})\) 因此归结为线性方程组的求解
TorchOpt 提供了一个装饰器函数 custom_root(),用于轻松地在任何现有的内部优化求解器(也称为前向优化)之上添加隐式微分。custom_root() 装饰器要求用户定义问题解的稳态条件(例如,KKT 条件),并会自动计算梯度以进行反向梯度计算。
这里是一个custom_root()装饰器的示例,它也是隐式梯度的功能API。
# Functional API for implicit gradient
def stationary(params, meta_params, data):
# stationary condition construction
return stationary condition
# Decorator that wraps the function
# Optionally specify the linear solver (conjugate gradient or Neumann series)
@torchopt.diff.implicit.custom_root(stationary)
def solve(params, meta_params, data):
# Forward optimization process for params
return optimal_params
# Define params, meta_params and get data
params, meta_prams, data = ..., ..., ...
optimal_params = solve(params, meta_params, data)
loss = outer_loss(optimal_params)
meta_grads = torch.autograd.grad(loss, meta_params)
面向对象编程API
|
可微分隐式元梯度模型的基类。 |
结合PyTorch的torch.nn.Module,我们还设计了面向对象的API nn.ImplicitMetaGradientModule 用于隐式梯度。
nn.ImplicitMetaGradientModule 的核心思想是使梯度从 self.parameters()(通常是低层参数)流向 self.meta_parameters()(通常是高层参数)。
用户需要定义前向过程 forward()、一个稳态函数 optimality()(或 objective()),以及内循环优化 solve。
这是OOP API的一个示例。
from torchopt.nn import ImplicitMetaGradientModule
# Inherited from the class ImplicitMetaGradientModule
class InnerNet(ImplicitMetaGradientModule):
def __init__(self, meta_module):
...
def forward(self, batch):
# Forward process
...
def optimality(self, batch, labels):
# Stationary condition construction for calculating implicit gradient
# NOTE: If this method is not implemented, it will be automatically derived from the
# gradient of the `objective` function.
...
def objective(self, batch, labels):
# Define the inner-loop optimization objective
# NOTE: This method is optional if method `optimality` is implemented.
...
def solve(self, batch, labels):
# Conduct the inner-loop optimization
...
return self # optimized module
# Get meta_params and data
meta_params, data = ..., ...
inner_net = InnerNet()
# Solve for inner-loop process related to the meta-parameters
optimal_inner_net = inner_net.solve(meta_params, *data)
# Get outer-loss and solve for meta-gradient
loss = outer_loss(optimal_inner_net)
meta_grad = torch.autograd.grad(loss, meta_params)
如果优化目标是最小化/最大化一个目标函数,我们提供了一个objective方法接口来简化实现。
用户只需要定义objective方法,而TorchOpt会自动从KKT条件中分析其平稳(最优)条件。
注意
在__init__方法中,用户需要定义内部参数和元参数。
默认情况下,nn.ImplicitMetaGradientModule将方法输入中的所有张量和模块视为self.meta_parameters() / self.meta_modules()。
例如,语句self.yyy = xxx将xxx分配为名为'yyy'的元参数,如果xxx存在于方法输入中(例如,def __init__(self, xxx, ...): ...)。
在__init__中定义的所有张量和模块被视为self.parameters() / self.modules()。
用户还可以通过分别调用self.register_parameter()和self.register_meta_parameter()来注册参数和元参数。
线性系统求解器
|
返回一个求解器函数,用于使用共轭梯度法求解 |
|
返回一个求解器函数,用于使用矩阵逆解 |
|
返回一个求解器函数,用于使用共轭梯度法求解 |
通常,隐式梯度的计算涉及逆Hessian矩阵的计算。 然而,高维Hessian矩阵也使得直接计算变得不可行,这就是线性求解器发挥作用的地方。 通过迭代求解线性系统问题,我们可以计算逆Hessian矩阵到一定的精度。我们提供了基于共轭梯度的求解器和基于纽曼级数的求解器。
这是一个线性求解器的示例。
import torch
from torchopt import linear_solve
torch.manual_seed(42)
A = torch.randn(3, 3)
b = torch.randn(3)
def matvec(x):
return torch.matmul(A, x)
solve_fn = linear_solve.solve_normal_cg(atol=1e-5)
solution = solve_fn(matvec, b)
print(solution)
solve_fn = linear_solve.solve_cg(atol=1e-5)
solution = solve_fn(matvec, b)
print(solution)
用户也可以在功能和面向对象的API中选择相应的求解器。
# For functional API
@torchopt.diff.implicit.custom_root(
functorch.grad(objective_fn, argnums=0), # optimality function
argnums=1,
solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
)
def solve_fn(...):
...
# For OOP API
class InnerNet(
torchopt.nn.ImplicitMetaGradientModule,
linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
):
...
笔记本教程
查看笔记本教程 Implicit Differentiation。