Source code for torch.optim.adamax
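Before the listing, a minimal usage sketch may help: it constructs `torch.optim.Adamax` with the default hyperparameters accepted by `__init__` below and performs a single update. The toy model, data, and loss here are illustrative assumptions, not part of the source.

```python
import torch

# Hypothetical toy model and data, purely for illustration.
model = torch.nn.Linear(10, 1)
inputs, targets = torch.randn(4, 10), torch.randn(4, 1)

# Defaults match the constructor signature shown in the source below.
optimizer = torch.optim.Adamax(model.parameters(), lr=2e-3, betas=(0.9, 0.999))

loss = torch.nn.functional.mse_loss(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```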
```python
import torch
from torch import Tensor
from .optimizer import (
    Optimizer,
    _use_grad_for_differentiable,
    _get_value,
    _default_to_fused_or_foreach,
    _get_scalar_dtype,
    _differentiable_doc,
    _maximize_doc,
    _foreach_doc,
    _view_as_real,
    _capturable_doc,
)
from typing import List, Optional

__all__ = ["Adamax", "adamax"]


class Adamax(Optimizer):
    def __init__(
        self,
        params,
        lr=2e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        # Validate hyperparameters before storing them as per-group defaults.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        # Fill in options missing from checkpoints saved by older versions and
        # convert any scalar step counters to tensors.
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
    ):
        # Gather the parameters of this group that have gradients, together
        # with their optimizer state, for the functional adamax update.
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad