Shortcuts

torch.utils.checkpoint 的源代码

```html
import contextlib
import platform
import uuid
import warnings
import weakref
from collections import defaultdict
from itertools import count
from typing import (
    Any,
    Callable,
    ContextManager,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
)
from weakref import ReferenceType

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.utils._pytree import tree_map
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
]

_DEFAULT_DETERMINISM_MODE = "default"

_checkpoint_debug_enabled: Optional[bool] = None


[docs]@contextlib.contextmanager def set_checkpoint_debug_enabled(enabled: Optional[bool]): """ 上下文管理器,设置检查点是否应在运行时打印额外的调试信息。 有关更多信息,请参阅 :func:`~torch.utils.checkpoint.checkpoint` 的 ``debug`` 标志。 请注意,当设置时,此上下文管理器会覆盖传递给检查点的 ``debug`` 值。 要推迟到本地设置,请将 ``None`` 传递给此上下文。 参数: enabled (bool): 检查点是否应打印调试信息。 默认值为 'None'。 """ global _checkpoint_debug_enabled try: prev = _checkpoint_debug_enabled _checkpoint_debug_enabled = enabled yield finally: _checkpoint_debug_enabled = prev
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: if isinstance(inputs, tuple): out = [] for inp in inputs: if not isinstance(inp, torch.Tensor): out.append(inp) continue x = inp.detach() x.requires_grad = inp.requires_grad out.append(x) return tuple(out) else: raise RuntimeError( "仅支持元组张量。不支持的输入类型: ", type(inputs).__name__, ) def check_backward_validity(inputs: Iterable[Any]) -> None: if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): warnings.warn( "没有任何输入的 requires_grad=True。梯度将为 None" ) def _get_device_module(device="cuda"): device_module = getattr(torch, device) return device_module class DefaultDeviceType: r""" 一个管理检查点默认设备类型的类。 如果没有非CPU张量存在,将使用默认设备类型。默认值为 'cuda'。 设备类型在检查点过程中用于确定保存和恢复哪些设备状态以进行重新计算。 """ _default_device_type = "cuda" @staticmethod def set_device_type(device: str = "cuda"): """ 设置检查点的默认设备类型。 参数: device (str): 要设置为默认的设备类型。默认值为 'cuda'。 """ DefaultDeviceType._default_device_type = device @staticmethod def get_device_type() -> str: """ 获取当前的默认设备类型。 返回: str: 当前的默认设备类型。 """ return DefaultDeviceType._default_device_type def _infer_device_type(*args): device_types = list( { arg.device.type for arg in args if isinstance(arg, torch.Tensor) and not arg<span