Shortcuts

torch.autograd.function 的源代码

import functools
import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple

import torch
import torch._C as _C
import torch._functorch as _functorch
import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call

__all__ = [
    "FunctionCtx",
    "BackwardCFunction",
    "FunctionMeta",
    "Function",
    "once_differentiable",
    "traceable",
    "InplaceFunction",
    "NestedIOFunction",
]

# 为每个继承自Function的类提供唯一的id
# 这在类定义期间在FunctionMeta中递增
AUTOGRAD_FUNCTION_COUNTER = itertools.count()


# 以前称为:_ContextMethodMixin
class FunctionCtx:
[docs] def save_for_backward(self, *tensors: torch.Tensor): r"""保存给定的张量以供将来调用 :func:`~Function.backward`。 ``save_for_backward`` 最多应调用一次,仅在 :func:`forward` 方法内部调用,并且仅使用张量。 所有打算在反向传播中使用的张量都应使用 ``save_for_backward`` 保存(而不是直接保存在 ``ctx`` 上),以防止梯度不正确和内存泄漏,并启用保存张量钩子的应用。请参阅 :class:`torch.autograd.graph.saved_tensors_hooks`。 请注意,如果保存了中间张量(既不是 :func:`forward` 的输入也不是输出的张量),您的自定义 Function 可能不支持双重反向传播。 不支持双重反向传播的自定义 Function 应在其 :func:`backward` 方法上使用 ``@once_differentiable`` 装饰器,以便在执行双重反向传播时引发错误。如果您希望支持双重反向传播,您可以根据反向传播期间的输入重新计算中间张量,或者将中间张量作为自定义 Function 的输出返回。有关更多详细信息,请参阅 `双重反向传播教程 `_。 在 :func:`backward` 中,保存的张量可以通过 :attr:`saved_tensors` 属性访问。在将它们返回给用户之前,会进行检查以确保它们没有在任何就地操作中被修改。 参数也可以是 ``None``。这是一个空操作。 有关如何使用此方法的更多详细信息,请参阅 :ref:`extending-autograd`。 示例:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> w = x * z >>> out = x * y + y * z + w * y >>> ctx.save_for_backward(x, y, w, out) >>> ctx.z = z # z 不是张量 >>> return out >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, grad_out): >>> x, y, w, out = ctx.saved_tensors >>> z = ctx.z >>> gx = grad_out * (y + y * z) >>> gy = grad_out * (x + z + w) >>> gz = None >>> return gx, gy, gz >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) >>> c = 4 >>> d = Func.apply(a, b, c) """ self.to_save = tensors
def save_for_forward(self, *tensors: torch.Tensor): r"""保存给定的张量以供将来调用 :func:`~Function.jvp`。 ``save_for_forward`` 应仅调用一次,从 :func:`forward` 方法内部调用,并且仅使用张量。 在 :func:`jvp` 中,保存的对象可以通过 :attr:`saved_tensors` 属性访问。 参数也可以是 ``None``。这是一个空操作。 有关如何使用此方法的更多详细信息,请参阅 :ref:`extending-autograd`。 示例:: >>> # xdoctest: +SKIP >>> class Func(torch.autograd.Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> ctx.save_for_backward(x, y) >>> ctx.save_for_forward(x, y) >>> ctx.z = z >>> return x * y * z >>> >>> @staticmethod >>> def jvp(ctx, x_t, y_t, _): >>> x, y = ctx.saved_tensors >>> z = ctx.z >>> return z * (y * x_t + x * y_t) >>> >>> @staticmethod >>> def vjp(ctx, grad_out): >>> x, y = ctx.saved_tensors >>> z = ctx.z >>> return z * grad_out * y, z * grad_out * x, None >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) >>> t = torch.tensor(1., dtype=torch.double) >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) >>> c = 4 >>> >>> with fwAD.dual_level(): >>> a_dual = fwAD.make_dual(a, t) >>> d = Func.apply(a_dual, b, c) """ for tensor in tensors: assert isinstance(tensor, torch.Tensor) or tensor is None, ( "save_for_forward 期望所有参数都是张量;您应该将非张量保存为 ctx 的属性。" ) self.saved_for_forward = tensors
[docs] def mark_dirty(self, *args: torch.Tensor): r"""标记给定的张量在就地操作中被修改。 **这最多应调用一次,仅从 :func:`forward` 方法内部调用,并且所有参数都应是输入。** 在调用 :func:`forward` 时,所有在就地操作中被修改的张量都应传递给此函数,以确保我们的检查的正确性。调用此函数的时间(在修改之前或之后)无关紧要。 示例:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD) >>> class Inplace(Function): >>> @static