Shortcuts

torch.func.functional_call

torch.func.functional_call(module, parameter_and_buffer_dicts, args, kwargs=None, *, tie_weights=True, strict=False)

通过替换模块参数和缓冲区为提供的参数和缓冲区,执行模块上的功能调用。

注意

如果模块有活动的参数化设置,在 parameter_and_buffer_dicts 参数中传递一个值,并将名称设置为常规参数名称,将完全禁用参数化。 如果你想将参数化函数应用于传递的值,请将键设置为 {submodule_name}.parametrizations.{parameter_name}.original

注意

如果模块对参数/缓冲区执行就地操作,这些操作将会反映在输入的 parameter_and_buffer_dicts 中。

示例:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # 将 self.foo 设置为 self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)

注意

如果模块有绑定的权重,是否功能调用尊重绑定由tie_weights标志决定。

示例:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # 同时具有 self.foo 和 self.foo_tied,它们是绑定的。返回 x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) 因为它也会改变 self.foo_tied
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.) -- self.foo_tied 没有更新
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)

传递多个字典的示例

a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # 两个独立的字典
mod = nn.Bar(1, 1)  # 返回 self.weight @ x + self.buffer
print(mod.weight)  # 张量(...)
print(mod.buffer)  # 张量(...)
x = torch.randn((1, 1))
print(x)
functional_call(mod, a, x)  # 与 x 相同
print(mod.weight)  # 与 functional_call 之前相同

这里是一个对模型参数应用grad变换的示例。

import torch
import torch.nn as nn
from torch.func import functional_call, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

注意

如果用户不需要在梯度变换之外进行梯度跟踪,他们可以分离所有参数以获得更好的性能和内存使用

示例:

>>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
>>> grad_weights = grad(compute_loss)(detached_params, x, t)
>>> grad_weights.grad_fn  # 无--它不在grad之外跟踪梯度

这意味着用户不能调用 grad_weight.backward()。然而,如果他们不需要在变换之外进行自动梯度跟踪,这将导致更少的内存使用和更快的速度。

Parameters
  • 模块 (torch.nn.Module) – 要调用的模块

  • parameters_and_buffer_dicts (Dict[str, Tensor] 或 tupleDict[str, Tensor]) – 将在模块调用中使用的参数。如果给定一个字典元组,它们必须具有不同的键,以便所有字典可以一起使用

  • args (任意元组) – 传递给模块调用的参数。如果不是元组,则视为单个参数。

  • kwargs (字典) – 传递给模块调用的关键字参数

  • tie_weights (bool, 可选) – 如果为 True,则原始模型中绑定的参数和缓冲区在重新参数化的版本中将被视为绑定。因此,如果为 True 并且为绑定的参数和缓冲区传递了不同的值,将会报错。如果为 False,则不会尊重原始绑定的参数和缓冲区,除非为两个权重传递的值相同。默认值:True。

  • 严格布尔值可选)– 如果为True,则传入的参数和缓冲区必须与原始模块中的参数和缓冲区匹配。因此,如果为True并且存在任何缺失或意外的键,将会报错。默认值:False。

Returns

调用 module 的结果。

Return type

任意

优云智算