
Source code for torch.autograd.graph

```python
import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)


__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()  # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...
    @classmethod
    def __subclasshook__(cls, C):
        if cls is Node:
            if (
                C is not None and C is getattr(torch._C._functions, C.__name__, None)
            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
                return True
        return NotImplemented
```
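The abstract `Node` interface above is enough to inspect an autograd graph by hand: every `grad_fn` on a tensor is a `Node`, and its `next_functions` property yields the `(Node, input_nr)` pairs that feed into it. The sketch below uses only `name()` and `next_functions` as documented in the class above; the helper `walk_graph` is a hypothetical illustration, not part of `torch.autograd.graph`.

```python
import torch


def walk_graph(node: torch.autograd.graph.Node, depth: int = 0) -> None:
    """Recursively print every Node reachable from ``node`` via next_functions."""
    print("  " * depth + node.name())
    for next_node, _input_nr in node.next_functions:
        # Entries are None for graph inputs that do not require gradients.
        if next_node is not None:
            walk_graph(next_node, depth + 1)


a = torch.tensor([1.0, 2.0], requires_grad=True)
b = (a * 2).sum()
walk_graph(b.grad_fn)
# Prints an indented tree of grad_fn names (exact names may vary by version), e.g.:
# SumBackward0
#   MulBackward0
#     torch::autograd::AccumulateGrad
```

The leaf tensor `a` shows up as an `AccumulateGrad` node, while constants that do not require gradients appear as `None` entries, which is why the traversal skips them.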