Source code for `torch.optim.radam`

```python
from typing import List, Optional

import torch
from torch import Tensor

from .optimizer import (
    Optimizer,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _capturable_doc,
    _dispatch_sqrt,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
    _use_grad_for_differentiable,
    _view_as_real,
)

# Public API of this module: the RAdam optimizer class and the
# functional radam() entry point it delegates to.
__all__ = ["RAdam", "radam"]


# NOTE(review): this is a scraped capture of torch/optim/radam.py, not the
# real file. The HTML extraction flattened the whole class body onto one
# line (the "[docs]" prefix is a Sphinx viewcode anchor, not Python), and
# the capture is TRUNCATED mid-statement inside `_init_group` — it ends at
# "torch.is_" (presumably `torch.is_complex(p)` in the upstream source —
# confirm against the real repository). Recover the actual file before
# making any code change; the line below is kept byte-identical.
#
# Visible, complete pieces:
#   RAdam.__init__  — validates lr >= 0, eps >= 0, 0 <= betas[0] < 1,
#                     0 <= betas[1] < 1, weight_decay >= 0 (raises
#                     ValueError otherwise), then builds the per-group
#                     defaults dict and calls Optimizer.__init__.
#   RAdam.__setstate__ — after unpickling, backfills options added in newer
#                     versions (foreach, differentiable,
#                     decoupled_weight_decay, capturable) via setdefault,
#                     and converts any legacy float `step` state entry to a
#                     tensor of _get_scalar_dtype() — placed on the param's
#                     device when the group is capturable, on CPU otherwise.
#   RAdam._init_group — truncated; only the start is visible (iterates
#                     group["params"] accumulating a has_complex flag).
[docs]class RAdam(Optimizer): def __init__( self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, decoupled_weight_decay: bool = False, *, foreach: Optional[bool] = None, capturable: bool = False, differentiable: bool = False, ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= eps: raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, capturable=capturable, decoupled_weight_decay=decoupled_weight_decay, differentiable=differentiable, ) super().__init__(params, defaults) def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("foreach", None) group.setdefault("differentiable", False) group.setdefault("decoupled_weight_decay", False) group.setdefault("capturable", False) for p in group["params"]: p_state = self.state.get(p, []) if len(p_state) != 0 and not torch.is_tensor(p_state['step']): step_val = float(p_state["step"]) p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype())) def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps): has_complex = False for p in group["params"]: if p.grad is not None: has_complex |= torch.is_
```