Shortcuts

torch.func API 参考

函数转换

vmap

vmap 是向量化映射;vmap(func) 返回一个新函数,该函数将 func 映射到输入的某些维度上。

grad

grad 运算符帮助计算 func 相对于由 argnums 指定的输入的梯度。

grad_and_value

返回一个函数来计算梯度和原始或前向计算的元组。

vjp

代表向量-雅可比积,返回一个包含应用于primalsfunc结果的元组,以及一个函数,当给定cotangents时,计算关于primalsfunc的反向模式雅可比矩阵乘以cotangents

jvp

代表雅可比向量积,返回一个包含 func(*primals) 输出和“funcprimals 处的雅可比矩阵”乘以 tangents 的结果的元组。

linearize

返回funcprimals处的值以及在primals处的线性近似值。

jacrev

计算func相对于索引argnum处的参数的雅可比矩阵,使用反向模式自动微分

jacfwd

计算func相对于索引argnum处的参数的前向模式自动微分的雅可比矩阵

hessian

计算func相对于索引argnum处的参数的Hessian,采用前向-反向策略。

functionalize

functionalize 是一种转换,可用于从函数中移除(中间)突变和别名,同时保留函数的语义。

用于处理 torch.nn.Modules 的工具

通常情况下,您可以对调用 torch.nn.Module 的函数进行变换。 例如,以下是一个计算函数雅可比矩阵的示例, 该函数接受三个值并返回三个值:

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

然而,如果你想做一些类似计算模型参数的雅可比矩阵的事情,那么就需要有一种方法来构建一个函数,其中参数是该函数的输入。 这就是functional_call()的作用: 它接受一个nn.Module,转换后的parameters,以及模块前向传播的输入。它返回使用替换参数运行模块前向传播的值。

以下是我们如何计算参数上的雅可比矩阵

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

通过替换模块的参数和缓冲区为提供的参数和缓冲区,执行模块上的功能调用。

stack_module_state

准备一个 torch.nn.Modules 列表以与 vmap() 进行集成。

replace_all_batch_norm_modules_

通过将 running_meanrunning_var 设置为 None,并将 track_running_stats 设置为 False,对 root 中的任何 nn.BatchNorm 模块进行原地更新。

如果您正在寻找有关修复Batch Norm模块的信息,请遵循这里的指导

优云智算