
Source code for torch.optim.sparse_adam

```python
import torch
from . import _functional as F
from .optimizer import Optimizer, _maximize_doc

__all__ = ['SparseAdam']

class SparseAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, maximize: bool = False):
        if not 0.0 < lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
        super().__init__(params, defaults)

        sparse_params = []
        complex_params = []
        for index, param_group in enumerate(self.param_groups):
            assert isinstance(
                param_group, dict
            ), f"param_groups must be a list of dicts, but got {type(param_group)}"
            # Given a param group, convert the given params to a list first before iterating
            for d_index, d_param in enumerate(param_group['params']):
                if d_param.is_sparse:
                    sparse_params.append([index, d_index])
                if d_param.is_complex():
                    complex_params.append([index, d_index])
        if sparse_params:
            raise ValueError(
                f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors"
            )
        if complex_params:
            raise ValueError(
                f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters"
            )
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            eps = group['eps']
            lr = group['lr']
            beta1, beta2 = group['betas']
            maximize = group.get('maximize', False)

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if not p.grad.is_sparse:
                        raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
                    grads.append(p.grad)

                    state = self.state[p]

                    # State initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    # Update the step count for each param being updated
                    state['step'] += 1
                    state_steps.append(state['step'])

            F.sparse_adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                eps=eps,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                maximize=maximize,
            )

        return loss
```
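For context, here is a minimal usage sketch (not part of the original listing). SparseAdam expects dense parameter tensors whose gradients are sparse, which is exactly what `nn.Embedding(..., sparse=True)` produces; the layer sizes, batch shape, and loss below are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Hypothetical example: an embedding layer with sparse=True produces sparse
# gradients for its (dense) weight tensor, which is what SparseAdam expects.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=16, sparse=True)
optimizer = torch.optim.SparseAdam(embedding.parameters(), lr=1e-3)

indices = torch.randint(0, 1000, (32,))  # a batch of token ids
target = torch.randn(32, 16)             # illustrative regression target

optimizer.zero_grad()
loss = (embedding(indices) - target).pow(2).mean()
loss.backward()   # embedding.weight.grad is a sparse tensor
optimizer.step()  # applies the sparse Adam update shown above
```

If the gradients were dense instead (e.g. with `sparse=False`), `step()` would raise the `RuntimeError` shown in the listing, and the regular `Adam` optimizer should be used.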