torch.func.linearize¶
- torch.func.linearize(func, *primals)¶
返回
func在primals处的值和线性近似值。- Parameters
func (可调用对象) – 一个接受一个或多个参数的Python函数。
primals (Tensors) – 传递给
func的位置参数,这些参数必须都是 Tensors。这些是函数进行线性近似的值。
- Returns
返回一个
(output, jvp_fn)元组,包含将func应用于primals的输出,以及一个计算在primals处评估的func的 jvp 的函数。- Return type
如果需要在
primals处多次计算 jvp,linearize 会很有用。然而,为了实现这一点,linearize 会保存中间计算结果,并且比直接应用 jvp 具有更高的内存需求。因此,如果所有tangents都已知,计算 vmap(jvp) 可能比使用 linearize 更高效。注意
linearize 对
func进行两次评估。请提交一个问题以实现仅进行一次评估的版本。- Example::
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>