torch.autograd.function 的源代码
import functools
import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
import torch
import torch._C as _C
import torch._functorch as _functorch
import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call
__all__ = [
"FunctionCtx",
"BackwardCFunction",
"FunctionMeta",
"Function",
"once_differentiable",
"traceable",
"InplaceFunction",
"NestedIOFunction",
]
# 为每个继承自Function的类提供唯一的id
# 这在类定义期间在FunctionMeta中递增
AUTOGRAD_FUNCTION_COUNTER = itertools.count()
# 以前称为:_ContextMethodMixin
class FunctionCtx:
[docs] def save_for_backward(self, *tensors: torch.Tensor):
r"""保存给定的张量以供将来调用 :func:`~Function.backward`。
``save_for_backward`` 最多应调用一次,仅在 :func:`forward` 方法内部调用,并且仅使用张量。
所有打算在反向传播中使用的张量都应使用 ``save_for_backward`` 保存(而不是直接保存在 ``ctx`` 上),以防止梯度不正确和内存泄漏,并启用保存张量钩子的应用。请参阅 :class:`torch.autograd.graph.saved_tensors_hooks`。
请注意,如果保存了中间张量(既不是 :func:`forward` 的输入也不是输出的张量),您的自定义 Function 可能不支持双重反向传播。
不支持双重反向传播的自定义 Function 应在其 :func:`backward` 方法上使用 ``@once_differentiable`` 装饰器,以便在执行双重反向传播时引发错误。如果您希望支持双重反向传播,您可以根据反向传播期间的输入重新计算中间张量,或者将中间张量作为自定义 Function 的输出返回。有关更多详细信息,请参阅 `双重反向传播教程 `_。
在 :func:`backward` 中,保存的张量可以通过 :attr:`saved_tensors` 属性访问。在将它们返回给用户之前,会进行检查以确保它们没有在任何就地操作中被修改。
参数也可以是 ``None``。这是一个空操作。
有关如何使用此方法的更多详细信息,请参阅 :ref:`extending-autograd`。
示例::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>> @staticmethod
>>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>> w = x * z
>>> out = x * y + y * z + w * y
>>> ctx.save_for_backward(x, y, w, out)
>>> ctx.z = z # z 不是张量
>>> return out
>>>
>>> @staticmethod
>>> @once_differentiable
>>> def backward(ctx, grad_out):
>>> x, y, w, out = ctx.saved_tensors
>>> z = ctx.z
>>> gx = grad_out * (y + y * z)
>>> gy = grad_out * (x + z + w)
>>> gz = None
>>> return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
"""
self.to_save = tensors
def save_for_forward(self, *tensors: torch.Tensor):
r"""保存给定的张量以供将来调用 :func:`~Function.jvp`。
``save_for_forward`` 应仅调用一次,从 :func:`forward` 方法内部调用,并且仅使用张量。
在 :func:`jvp` 中,保存的对象可以通过 :attr:`saved_tensors` 属性访问。
参数也可以是 ``None``。这是一个空操作。
有关如何使用此方法的更多详细信息,请参阅 :ref:`extending-autograd`。
示例::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>> @staticmethod
>>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>> ctx.save_for_backward(x, y)
>>> ctx.save_for_forward(x, y)
>>> ctx.z = z
>>> return x * y * z
>>>
>>> @staticmethod
>>> def jvp(ctx, x_t, y_t, _):
>>> x, y = ctx.saved_tensors
>>> z = ctx.z
>>> return z * (y * x_t + x * y_t)
>>>
>>> @staticmethod
>>> def vjp(ctx, grad_out):
>>> x, y = ctx.saved_tensors
>>> z = ctx.z
>>> return z * grad_out * y, z * grad_out * x, None
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> t = torch.tensor(1., dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>>
>>> with fwAD.dual_level():
>>> a_dual = fwAD.make_dual(a, t)
>>> d = Func.apply(a_dual, b, c)
"""
for tensor in tensors:
assert isinstance(tensor, torch.Tensor) or tensor is None, (
"save_for_forward 期望所有参数都是张量;您应该将非张量保存为 ctx 的属性。"
)
self.saved_for_forward = tensors
[docs] def mark_dirty(self, *args: torch.Tensor):
r"""标记给定的张量在就地操作中被修改。
**这最多应调用一次,仅从 :func:`forward` 方法内部调用,并且所有参数都应是输入。**
在调用 :func:`forward` 时,所有在就地操作中被修改的张量都应传递给此函数,以确保我们的检查的正确性。调用此函数的时间(在修改之前或之后)无关紧要。
示例::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>> @static