Source code for torch.optim.sgd
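For orientation, a minimal usage sketch of the optimizer defined below (illustrative only, not part of the module source; `model`, `dataset`, and `loss_fn` are assumed placeholders):

    >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    >>> for input, target in dataset:
    ...     optimizer.zero_grad()
    ...     loss_fn(model(input), target).backward()
    ...     optimizer.step()
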
import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                        _differentiable_doc, _foreach_doc, _maximize_doc, _fused_doc)
from typing import List, Optional
__all__ = ['SGD', 'sgd']
class SGD(Optimizer):
    def __init__(self, params, lr=1e-3, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
                 differentiable: bool = False, fused: Optional[bool] = None):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach,
                        differentiable=differentiable, fused=fused)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

        if fused:
            self._step_supports_amp_scaling = True
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)
            group.setdefault('fused', False)
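
    # _init_group collects, for one parameter group, the parameters that have gradients,
    # their gradients, and any existing momentum buffers; it reports whether any gradient is sparse.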
    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
        has_sparse_grad = False

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                momentum_buffer_list.append(state.get('momentum_buffer'))

        return has_sparse_grad
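
    # The _use_grad_for_differentiable decorator (roughly) enables grad mode during the step
    # when the optimizer was constructed with differentiable=True, and runs it under no_grad otherwise.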
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []