torch.func API 参考¶
函数转换¶
vmap |
vmap 是向量化映射; |
grad |
|
grad_and_value |
返回一个函数来计算梯度和原始或前向计算的元组。 |
vjp |
代表向量-雅可比积,返回一个包含应用于 |
jvp |
代表雅可比向量积,返回一个包含 func(*primals) 输出和“ |
linearize |
返回 |
jacrev |
计算 |
jacfwd |
计算 |
hessian |
计算 |
functionalize |
functionalize 是一种转换,可用于从函数中移除(中间)突变和别名,同时保留函数的语义。 |
用于处理 torch.nn.Modules 的工具¶
通常情况下,您可以对调用 torch.nn.Module 的函数进行变换。
例如,以下是一个计算函数雅可比矩阵的示例,
该函数接受三个值并返回三个值:
model = torch.nn.Linear(3, 3)
def f(x):
return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
然而,如果你想做一些类似计算模型参数的雅可比矩阵的事情,那么就需要有一种方法来构建一个函数,其中参数是该函数的输入。
这就是functional_call()的作用:
它接受一个nn.Module,转换后的parameters,以及模块前向传播的输入。它返回使用替换参数运行模块前向传播的值。
以下是我们如何计算参数上的雅可比矩阵
model = torch.nn.Linear(3, 3)
def f(params, x):
return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
functional_call |
通过替换模块的参数和缓冲区为提供的参数和缓冲区,执行模块上的功能调用。 |
stack_module_state |
准备一个 torch.nn.Modules 列表以与 |
replace_all_batch_norm_modules_ |
通过将 |
如果您正在寻找有关修复Batch Norm模块的信息,请遵循这里的指导