
Source code for torch.optim.sgd

import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                        _differentiable_doc, _foreach_doc, _maximize_doc, _fused_doc)
from typing import List, Optional

__all__ = ['SGD', 'sgd']


class SGD(Optimizer):
    def __init__(self, params, lr=1e-3, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize: bool = False,
                 foreach: Optional[bool] = None,
                 differentiable: bool = False,
                 fused: Optional[bool] = None):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach,
                        differentiable=differentiable, fused=fused)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)
            group.setdefault('fused', False)

    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
        has_sparse_grad = False
        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                momentum_buffer_list.append(state.get('momentum_buffer'))

        return has_sparse_grad
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
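
As a usage note (not part of the upstream file): a minimal sketch of how this optimizer is typically constructed and stepped, including the optional closure form that `step` accepts. The model, data, and loss used here are placeholders chosen for illustration.

# Usage sketch (illustrative only; `model`, `inputs`, and `targets` are placeholders).
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

inputs, targets = torch.randn(32, 10), torch.randn(32, 1)

# Plain step: backward() populates .grad, then step() applies the update.
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()

# Closure form: step() re-evaluates the loss itself via the callable it is given.
def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

optimizer.step(closure)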
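For reference, a simplified single-tensor sketch of the update that the gradients and momentum buffers collected above feed into; the actual work happens in the functional `sgd` routine, which is not shown in this excerpt. This follows the documented SGD-with-momentum recipe and is not the library's foreach/fused implementation.

# Simplified per-parameter SGD update (illustrative sketch, not the library code).
def sgd_update(param, grad, buf, lr, momentum, dampening, weight_decay, nesterov, maximize):
    d_p = grad if not maximize else -grad
    if weight_decay != 0:
        d_p = d_p.add(param, alpha=weight_decay)    # L2 penalty folded into the gradient
    if momentum != 0:
        if buf is None:
            buf = d_p.clone().detach()              # first step: buffer starts as the gradient
        else:
            buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
        d_p = d_p.add(buf, alpha=momentum) if nesterov else buf
    param.add_(d_p, alpha=-lr)                      # descent (or ascent, if maximize) step
    return buf                                      # stored back as state['momentum_buffer']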