torch.optim.nadam 的源代码
```python
import torch from torch import Tensor from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling, _get_scalar_dtype, _default_to_fused_or_foreach, _view_as_real, _capturable_doc, _differentiable_doc, _foreach_doc,) from typing import List, Optional __all__ = ['NAdam', 'nadam'][docs]class NAdam(Optimizer): def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, momentum_decay=4e-3, decoupled_weight_decay: bool = False, *, foreach: Optional[bool] = None, capturable: bool = False, differentiable: bool = False): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= momentum_decay: raise ValueError(f"Invalid momentum_decay value: {momentum_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, momentum_decay=momentum_decay, decoupled_weight_decay=decoupled_weight_decay, foreach=foreach, capturable=capturable, differentiable=differentiable) super().__init__(params, defaults) def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault('foreach', None) group.setdefault('capturable', False) group.setdefault('differentiable', False) group.setdefault('decoupled_weight_decay', False) for p in group["params"]: p_state = self.state.get(p, []) if len(p_state) != 0: if not torch.is_tensor(p_state['step']): step_val = float(p_state["step"]) p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype())) if not torch.is_tensor(p_state['mu_product']): 
mu_prod_val = p_state["mu_product"] p_state["mu_product"]</