Source code for torch.cuda.amp.autocast_mode

```python
import collections
import functools

import torch

try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
from typing import Any

__all__ = ["autocast", "custom_fwd", "custom_bwd"]


class autocast(torch.amp.autocast_mode.autocast):
    r"""See :class:`torch.autocast`.

    ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "cuda"
            self.fast_dtype = dtype
            return
        super().__init__(
            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings
# and np.ndarrays, which may otherwise be falsely detected as "Iterables."
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_cuda
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value


# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
#     @custom_fwd
#     def forward(...):
# this also works:
#     @custom_fwd(cast_inputs=torch.float)
#     def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Create a helper decorator for ``forward`` methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    See the :ref:`example page` for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point CUDA Tensors to the target dtype (non-floating-point
            Tensors are not affected), then executes ``forward`` with autocast
            disabled. If ``None``, ``forward``'s internal ops execute with the
            current autocast state.
    """
```
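
A minimal usage sketch of the class above (not part of the listed source): per the docstring, ``torch.cuda.amp.autocast(args...)`` behaves the same as the device-generic ``torch.autocast("cuda", args...)``. The snippet assumes a CUDA-capable device is available.

```python
import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

# CUDA-specific wrapper defined in this module.
with torch.cuda.amp.autocast(dtype=torch.float16):
    c = a @ b  # matmul runs in float16 under autocast
assert c.dtype == torch.float16

# Device-generic form documented as equivalent.
with torch.autocast("cuda", dtype=torch.float16):
    d = a @ b
assert d.dtype == torch.float16
```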
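To illustrate the decorator comment above (``@custom_fwd`` used with or without arguments), here is a hedged sketch of a custom autograd function following the AMP examples pattern. The class name ``MyMM`` is hypothetical; ``custom_bwd`` (exported in ``__all__`` but not shown in this truncated listing) decorates ``backward``. A CUDA device is assumed.

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MyMM(torch.autograd.Function):
    # With cast_inputs set: inside an autocast-enabled region, incoming
    # floating-point CUDA tensors are cast to float32 (via _cast) and
    # forward executes with autocast disabled.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    # custom_bwd runs backward under the same autocast state forward saw.
    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)

a = torch.randn(4, 4, device="cuda", requires_grad=True)
b = torch.randn(4, 4, device="cuda", requires_grad=True)
with torch.cuda.amp.autocast():
    out = MyMM.apply(a, b)  # executes in float32 despite autocast
out.sum().backward()
```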