Shortcuts

Source code for torch.optim.adamw

import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
                        _stack_if_compiling, _get_scalar_dtype, _capturable_doc, _differentiable_doc,
                        _foreach_doc, _fused_doc, _maximize_doc, _default_to_fused_or_foreach,
                        ParamsT, _view_as_real)
from typing import List, Optional, Tuple, Union
from torch.utils._foreach_utils import _get_fused_kernels_supported_devices

# Names exported by ``from torch.optim.adamw import *``: the ``AdamW``
# optimizer class and ``adamw`` (the latter is defined outside the visible
# part of this file — presumably the functional form; confirm).
__all__ = [
    "AdamW",
    "adamw",
]


[docs]class AdamW(Optimizer): def __init__( self, params: ParamsT, lr: Union[float, Tensor] = 1e-3, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 1e-2, amsgrad: bool = False, *, maximize: bool = False, foreach: Optional[bool] = None, capturable: bool = False, differentiable: bool = False, fused: Optional[bool] = None, ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if isinstance(lr, Tensor) and foreach and not capturable: raise ValueError("lr as a Tensor is not supported for capturable=False and foreach=True") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, foreach=foreach, maximize=maximize, capturable=capturable, differentiable=differentiable, fused=fused, ) super().__init__(params, defaults) if fused: if differentiable: raise RuntimeError("`fused` does not support `differentiable`") self._step_supports_amp_scaling = True # TODO(crcrpar): [low prec params & their higher prec copy] # Suppor AMP with FP16/BF16 model params which would need # higher prec copy of params to do update math in higher prec to # alleviate the loss of information. fused_supported_devices = _get_fused_kernels_supported_devices() if not all( p.device.type in fused_supported_devices and torch.is_floating_point(p) for pg in self.param_groups for p in pg['params'] ): raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of " f"supported devices: {fused_supported_devices}.") if foreach: raise RuntimeError("`fused` and `foreach` cannot be `True` together.") def __setstate__(self, state): <
优云智算