Source code for torch.optim.adam
from typing import List, Optional, Union, Tuple
import torch
from torch import Tensor
from .optimizer import (Optimizer, ParamsT, _use_grad_for_differentiable, _get_value,
                        _stack_if_compiling, _dispatch_sqrt, _default_to_fused_or_foreach,
                        _get_scalar_dtype, _capturable_doc, _differentiable_doc, _foreach_doc,
                        _fused_doc, _maximize_doc, _view_as_real)
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices
__all__ = ['Adam', 'adam']
class Adam(Optimizer):
    def __init__(self,
                 params: ParamsT,
                 lr: Union[float, Tensor] = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8,
                 weight_decay: float = 0,
                 amsgrad: bool = False,
                 *,
                 foreach: Optional[bool] = None,
                 maximize: bool = False,
                 capturable: bool = False,
                 differentiable: bool = False,
                 fused: Optional[bool] = None):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if isinstance(lr, Tensor) and foreach and not capturable:
            raise ValueError("lr as a Tensor is not supported for capturable=False and foreach=True")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach, capturable=capturable,
                        differentiable=differentiable, fused=fused)
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            fused_supported_devices = _get_fused_kernels_supported_devices()
            if not all(
                p.device.type in fused_supported_devices and
                torch.is_floating_point(p) for pg in self.param_groups for p in pg['params']
            ):
                raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
                                   f"supported devices: {fused_supported_devices}.")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
    def __setstate__(self, state):
        super().__setstate__(state)
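For reference, a minimal usage sketch of the optimizer defined above; the model, input, and hyperparameter values here are illustrative placeholders, not part of the source.

# Usage sketch (illustrative; `model` and the random input are made-up placeholders).
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                             eps=1e-8, weight_decay=0.0, amsgrad=False)

loss = model(torch.randn(4, 10)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Per the __init__ checks above: `fused=True` cannot be combined with
# `foreach=True` or `differentiable=True`, and passing `lr` as a Tensor with
# `foreach=True` requires `capturable=True`.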