Shortcuts

torch.autograd.grad_mode 的源代码

from typing import Any

import torch

from torch.utils._contextlib import (
    _DecoratorContextManager,
    _NoParamDecoratorContextManager,
    F,
)

__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]


[docs]class no_grad(_NoParamDecoratorContextManager): r"""上下文管理器,禁用梯度计算。 禁用梯度计算在推理时非常有用,当你确定不会调用 :meth:`Tensor.backward()` 时。它可以减少内存消耗,否则这些计算将具有 `requires_grad=True`。 在这种模式下,每个计算结果都将具有 `requires_grad=False`,即使输入具有 `requires_grad=True`。 有一个例外!所有工厂函数或创建新 Tensor 并接受 requires_grad 关键字参数的函数都不会受到此模式的影响。 此上下文管理器是线程局部的;它不会影响其他线程中的计算。 也可以作为装饰器使用。 .. 注意:: no-grad 是可以在本地启用或禁用梯度的几种机制之一,请参阅 :ref:`locally-disable-grad-doc` 以了解它们如何比较。 .. 注意:: 此 API 不适用于 :ref:`forward-mode AD `。 如果你想禁用前向 AD 进行计算,可以解包你的双张量。 示例:: >>> # xdoctest: +SKIP >>> x = torch.tensor([1.], requires_grad=True) >>> with torch.no_grad(): ... y = x * 2 >>> y.requires_grad False >>> @torch.no_grad() ... def doubler(x): ... return x * 2 >>> z = doubler(x) >>> z.requires_grad False >>> @torch.no_grad ... def tripler(x): ... return x * 3 >>> z = tripler(x) >>> z.requires_grad False >>> # 工厂函数例外 >>> with torch.no_grad(): ... a = torch.nn.Parameter(torch.rand(10)) >>> a.requires_grad True """ def __init__(self) -> None: if not torch._jit_internal.is_scripting(): super().__init__() self.prev = False def __enter__(self) -> None: self.prev = torch.is_grad_enabled() torch.set_grad_enabled(False) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: torch.set_grad_enabled(self.prev)
[docs]class enable_grad(_NoParamDecoratorContextManager): r"""上下文管理器,启用梯度计算。 如果梯度计算已通过 :class:`~no_grad` 或 :class:`~set_grad_enabled` 禁用,则启用梯度计算。 此上下文管理器是线程局部的;它不会影响其他线程中的计算。 也可以作为装饰器使用。 .. 注意:: enable_grad 是可以在本地启用或禁用梯度的几种机制之一,请参阅 :ref:`locally-disable-grad-doc` 以了解它们如何比较。 .. 注意:: 此 API 不适用于 :ref:`forward-mode AD `。 示例:: >>> # xdoctest: +SKIP >>> x = torch.tensor([1.], requires_grad=True) >>> with torch.no_grad(): ... with torch.enable_grad(): ... y = x * 2 >>> y.requires_grad True >>> y.backward() >>> x.grad tensor([2.]) >>> @torch.enable_grad() ... def doubler(x): ... return x * 2 >>> with torch.no_grad(): ... z = doubler(x) >>> z.requires_grad True >>> @torch.enable_grad ... def tripler(x): ... return x * 3 >>> with torch.no_grad(): ... z = tripler(x) >>> z.requires_grad True """ def __enter__(self) -> None: self.prev = torch.is_grad_enabled() torch._C._set_grad_enabled(True) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: torch._C._set_grad_enabled(self.prev)
[docs]class set_grad_enabled(_DecoratorContextManager): r"""上下文管理器,设置梯度计算的开关。 ``set_grad_enabled`` 将根据其参数 :attr:`mode` 启用或禁用梯度。 它可以作为上下文管理器或函数使用。 此上下文管理器是线程局部的;它不会影响其他线程中的计算。 参数: mode (bool): 标志是否启用梯度 (``True``) 或禁用 (``False``)。这可以用于条件性地启用梯度。 .. 注意:: set_grad_enabled 是可以在本地启用或禁用梯度的几种机制之一,请参阅 :ref:`locally-disable-grad-doc` 以了解它们如何比较。 .. 注意:: 此 API 不适用于 :ref:`forward-mode AD `。 示例:: >>> # xdoctest: +SKIP >>> x = torch.tensor([1.], requires_grad=True) >>> is_train = False >>> with torch.set_grad_enabled(is_train): ... y = x * 2 >>> y.requires_grad False >>> _ = torch.set_grad_enabled(True) >>> y = x * 2 >>> y.requires_grad True
优云智算