Shortcuts

torch.func.linearize

torch.func.linearize(func, *primals)

返回funcprimals处的值和线性近似值。

Parameters
  • func (可调用对象) – 一个接受一个或多个参数的Python函数。

  • primals (Tensors) – 传递给 func 的位置参数,这些参数必须都是 Tensors。这些是函数进行线性近似的值。

Returns

返回一个 (output, jvp_fn) 元组,包含将 func 应用于 primals 的输出,以及一个计算在 primals 处评估的 func 的 jvp 的函数。

Return type

元组[任意, 可调用]

如果需要在 primals 处多次计算 jvp,linearize 会很有用。然而,为了实现这一点,linearize 会保存中间计算结果,并且比直接应用 jvp 具有更高的内存需求。因此,如果所有 tangents 都已知,计算 vmap(jvp) 可能比使用 linearize 更高效。

注意

linearize 对 func 进行两次评估。请提交一个问题以实现仅进行一次评估的版本。

Example::
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>
优云智算