torch.autograd.grad_mode 的源代码
from typing import Any
import torch
from torch.utils._contextlib import (
_DecoratorContextManager,
_NoParamDecoratorContextManager,
F,
)
__all__ = [
"no_grad",
"enable_grad",
"set_grad_enabled",
"inference_mode",
"set_multithreading_enabled",
]
[docs]class no_grad(_NoParamDecoratorContextManager):
r"""上下文管理器,禁用梯度计算。
禁用梯度计算在推理时非常有用,当你确定不会调用 :meth:`Tensor.backward()` 时。它可以减少内存消耗,否则这些计算将具有 `requires_grad=True`。
在这种模式下,每个计算结果都将具有 `requires_grad=False`,即使输入具有 `requires_grad=True`。
有一个例外!所有工厂函数或创建新 Tensor 并接受 requires_grad 关键字参数的函数都不会受到此模式的影响。
此上下文管理器是线程局部的;它不会影响其他线程中的计算。
也可以作为装饰器使用。
.. 注意::
no-grad 是可以在本地启用或禁用梯度的几种机制之一,请参阅 :ref:`locally-disable-grad-doc` 以了解它们如何比较。
.. 注意::
此 API 不适用于 :ref:`forward-mode AD `。
如果你想禁用前向 AD 进行计算,可以解包你的双张量。
示例::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @torch.no_grad
... def tripler(x):
... return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # 工厂函数例外
>>> with torch.no_grad():
... a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True
"""
def __init__(self) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.prev = False
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)
[docs]class enable_grad(_NoParamDecoratorContextManager):
r"""上下文管理器,启用梯度计算。
如果梯度计算已通过 :class:`~no_grad` 或 :class:`~set_grad_enabled` 禁用,则启用梯度计算。
此上下文管理器是线程局部的;它不会影响其他线程中的计算。
也可以作为装饰器使用。
.. 注意::
enable_grad 是可以在本地启用或禁用梯度的几种机制之一,请参阅 :ref:`locally-disable-grad-doc` 以了解它们如何比较。
.. 注意::
此 API 不适用于 :ref:`forward-mode AD `。
示例::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... with torch.enable_grad():
... y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @torch.enable_grad()
... def doubler(x):
... return x * 2
>>> with torch.no_grad():
... z = doubler(x)
>>> z.requires_grad
True
>>> @torch.enable_grad
... def tripler(x):
... return x * 3
>>> with torch.no_grad():
... z = tripler(x)
>>> z.requires_grad
True
"""
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch._C._set_grad_enabled(True)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_grad_enabled(self.prev)
[docs]class set_grad_enabled(_DecoratorContextManager):
r"""上下文管理器,设置梯度计算的开关。
``set_grad_enabled`` 将根据其参数 :attr:`mode` 启用或禁用梯度。
它可以作为上下文管理器或函数使用。
此上下文管理器是线程局部的;它不会影响其他线程中的计算。
参数:
mode (bool): 标志是否启用梯度 (``True``) 或禁用 (``False``)。这可以用于条件性地启用梯度。
.. 注意::
set_grad_enabled 是可以在本地启用或禁用梯度的几种机制之一,请参阅 :ref:`locally-disable-grad-doc` 以了解它们如何比较。
.. 注意::
此 API 不适用于 :ref:`forward-mode AD `。
示例::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> _ = torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True