Source code for torch.optim.optimizer
import math
import functools
import warnings
from collections import OrderedDict, defaultdict
from copy import deepcopy
from itertools import chain
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    cast,
    overload,
)
from typing_extensions import ParamSpec, Self, TypeAlias
import torch
import torch.utils.hooks as hooks
from torch.utils.hooks import RemovableHandle
from torch.utils._foreach_utils import (
    Indices,
    TensorListList,
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
)
from torch._utils import is_compiling
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
GlobalOptimizerPreHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]
__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
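# NOTE (illustrative sketch, not part of the upstream file): a minimal example of
# a function matching the GlobalOptimizerPreHook alias above. The name
# `_example_global_pre_hook` is an assumption for illustration; in user code it
# would be passed to register_optimizer_step_pre_hook (defined later in this
# file), which returns a RemovableHandle for later removal.
def _example_global_pre_hook(optimizer: "Optimizer", args: Args, kwargs: Kwargs) -> Optional[Tuple[Args, Kwargs]]:
    # Returning None keeps the original arguments; returning an (args, kwargs)
    # tuple replaces the arguments passed on to Optimizer.step().
    return None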
class _RequiredParameter:
"""Singleton class representing a required parameter for an Optimizer."""
def __repr__(self) -> str:
return ""
required = _RequiredParameter()
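# NOTE (illustrative sketch, not part of the upstream file): how the `required`
# sentinel is typically consumed by a concrete optimizer. It is shown as a
# comment because `Optimizer` is only defined further down in this file; the
# class name is an assumption, mirroring the historical `lr=required` pattern
# (e.g. in torch.optim.SGD) for a hyperparameter with no usable default.
#
#     class _ExampleOptimizer(Optimizer):
#         def __init__(self, params, lr=required):
#             if lr is not required and lr < 0.0:
#                 raise ValueError(f"Invalid learning rate: {lr}")
#             super().__init__(params, dict(lr=lr))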
def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        import torch._dynamo
        prev_grad = torch.is_grad_enabled()
        try:
            # Note on graph break below:
            # We need to graph break to ensure that aot respects the no_grad annotation.
            # This is important for perf because without this, functionalization will generate an epilogue
            # which updates the mutated parameters of the optimizer which is *not* visible to inductor;
            # as a result, inductor will allocate for every parameter in the model, which is horrible.
            # With this, aot correctly sees that this is an inference graph, and functionalization will generate
            # an epilogue which is appended to the graph and *is* visible to inductor; as a result, inductor sees that
            # step is in place and is able to avoid the extra allocation.
            # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter,
            # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
            # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
            # See https://github.com/pytorch/pytorch/issues/104053
            torch.set_grad_enabled(self.defaults['differentiable'])
            torch._dynamo.graph_break()
            ret = func(self, *args, **kwargs)
        finally:
            torch._dynamo.graph_break()
            torch.set_grad_enabled(prev_grad)
        return ret

    functools.update_wrapper(_use_grad, func)
    return _use_grad
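# NOTE (illustrative sketch, not part of the upstream file): how the decorator
# above is applied. Concrete optimizers wrap their step() methods with it so
# that grad mode follows self.defaults['differentiable'] while step runs; the
# class below is an assumption for illustration, shown as a comment because
# `Optimizer` is defined later in this file.
#
#     class _ExampleOptimizer(Optimizer):
#         @_use_grad_for_differentiable
#         def step(self, closure=None):
#             ...  # parameter updates run under the grad mode selected above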
def _get_value(x):
    # item is significantly faster than a cpu tensor in eager mode
    if not torch.jit.is_scripting() and is_compiling():
        return x
    else:
        return x.item()
def _stack_if_compiling(x):
    if not torch.jit.is_scripting() and is_compiling():
        return torch.stack(x)
    else:
        return x
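# NOTE (illustrative sketch, not part of the upstream file): how the two helpers
# above are meant to be used inside an optimizer's update math. `step_t` is an
# assumed name: when compiling, the tensor is returned unchanged so it stays in
# the traced graph; in eager mode, .item() is used since it is faster than a
# CPU tensor (per the comment in _get_value).
#
#     step_t = torch.tensor(3.0)
#     step = _get_value(step_t)              # Tensor when compiling, Python float in eager
#     steps = _stack_if_compiling([step_t])  # stacked into one Tensor only when compiling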
def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):