Source code for torch.optim.asgd
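For orientation before the listing, here is a minimal usage sketch of the `ASGD` optimizer defined below. The constructor arguments mirror the defaults in the source (`lr=1e-2`, `lambd=1e-4`, `alpha=0.75`, `t0=1e6`); the model, batch, and loss are placeholder assumptions, not part of the PyTorch source.

```python
import torch

# Placeholder model and batch, purely for illustration.
model = torch.nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

# ASGD is constructed and stepped like any other torch.optim optimizer.
opt = torch.optim.ASGD(model.parameters(), lr=1e-2, lambd=1e-4,
                       alpha=0.75, t0=1e6, weight_decay=0)

opt.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
```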
```python
import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
                        _default_to_fused_or_foreach, _get_scalar_dtype,
                        _view_as_real, _differentiable_doc, _foreach_doc,
                        _maximize_doc, _capturable_doc)
from typing import List, Optional

__all__ = ["ASGD", "asgd"]


def _to_tensor(x, device=None):
    # Wrap plain Python scalars as tensors; pass tensors through unchanged.
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x, device=device)
    return x


class ASGD(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-2,
        lambd=1e-4,
        alpha=0.75,
        t0=1e6,
        weight_decay=0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    # Older checkpoints stored step/eta/mu as Python scalars;
                    # promote them to tensors so tensor-based code paths work.
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
                    if not torch.is_tensor(p_state["eta"]):
                        p_state["eta"] = torch.tensor(p_state["eta"], dtype=_get_scalar_dtype(), device=p.device)
                    if not torch.is_tensor(p_state["mu"]):
                        p_state["mu"] = torch.tensor(p_state["mu"], dtype=_get_scalar_dtype(), device=p.device)

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        has_complex = False
        for p in group["params"]: