
Source code for torch.optim.adam

from typing import List, Optional, Union, Tuple

import torch
from torch import Tensor
from .optimizer import (Optimizer, ParamsT, _use_grad_for_differentiable, _get_value,
                        _stack_if_compiling, _dispatch_sqrt, _default_to_fused_or_foreach,
                        _get_scalar_dtype, _capturable_doc, _differentiable_doc, _foreach_doc,
                        _fused_doc, _maximize_doc, _view_as_real)
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

__all__ = ['Adam', 'adam']


class Adam(Optimizer):
    def __init__(self, params: ParamsT,
                 lr: Union[float, Tensor] = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8,
                 weight_decay: float = 0,
                 amsgrad: bool = False,
                 *,
                 foreach: Optional[bool] = None,
                 maximize: bool = False,
                 capturable: bool = False,
                 differentiable: bool = False,
                 fused: Optional[bool] = None):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if isinstance(lr, Tensor) and foreach and not capturable:
            raise ValueError("lr as a Tensor is not supported for capturable=False and foreach=True")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach, capturable=capturable,
                        differentiable=differentiable, fused=fused)
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            fused_supported_devices = _get_fused_kernels_supported_devices()
            if not all(
                p.device.type in fused_supported_devices and
                torch.is_floating_point(p)
                for pg in self.param_groups for p in pg['params']
            ):
                raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
                                   f"supported devices: {fused_supported_devices}.")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
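
A minimal usage sketch (not part of the source listing above): it constructs the optimizer with the signature shown in ``__init__``, runs one update step, and demonstrates the hyperparameter validation. The model, data, and loss here are placeholder assumptions for illustration only.

    import torch
    from torch import nn

    # Toy model and data (hypothetical placeholders, not from the source above).
    model = nn.Linear(4, 2)
    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 2)

    # Constructor arguments mirror the signature shown above; all values besides
    # lr are the documented defaults.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3,
                                 betas=(0.9, 0.999), eps=1e-8)

    # One optimization step: zero gradients, compute the loss, backprop, update.
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # The validation in __init__ rejects out-of-range hyperparameters,
    # e.g. a negative learning rate.
    try:
        torch.optim.Adam(model.parameters(), lr=-0.1)
    except ValueError as e:
        print(e)  # Invalid learning rate: -0.1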