Shortcuts

torch.optim.rprop 的源代码

import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                        _differentiable_doc, _foreach_doc, _maximize_doc, _view_as_real)
from typing import List, Optional

# Public names exported by ``from <this module> import *``.
# NOTE(review): only the ``Rprop`` class is visible in this part of the file;
# ``rprop`` is presumably a function defined further down — confirm.
__all__ = ["Rprop", "rprop"]


class Rprop(Optimizer):
    """Rprop optimizer: hyper-parameter validation and per-parameter state.

    Args:
        params: iterable of parameters (or parameter-group dicts) forwarded
            unchanged to ``Optimizer.__init__``.
        lr: non-negative value used to fill the initial per-element
            ``step_size`` state of every parameter.
        etas: pair ``(eta_minus, eta_plus)``; must satisfy
            ``0 < eta_minus < 1 < eta_plus``.
        step_sizes: pair stored in the group defaults. It is not read by the
            methods in this block (presumably min/max step-size bounds used
            by the update step — confirm against ``step``).
        foreach: stored in the group defaults; ``None`` by default.
        maximize: stored in the group defaults; ``False`` by default.
        differentiable: stored in the group defaults; ``False`` by default.

    Raises:
        ValueError: if ``lr`` is negative or ``etas`` violates the ordering
            constraint above.
    """

    def __init__(
        self,
        params,
        lr=1e-2,
        etas=(0.5, 1.2),
        step_sizes=(1e-6, 50),
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

        defaults = dict(
            lr=lr,
            etas=etas,
            step_sizes=step_sizes,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in options that may be missing.

        After delegating to the base class, every param group gets default
        values for ``foreach``, ``maximize`` and ``differentiable`` if the
        restored group does not already define them.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)

    def _init_group(self, group, params, grads, prevs, step_sizes):
        """Collect the tensors of one param group and lazily create state.

        For every parameter in ``group["params"]`` that has a gradient, the
        parameter, its gradient, and its ``prev`` / ``step_size`` state
        tensors are appended to the caller-supplied output lists, and the
        parameter's ``step`` counter is incremented.

        Args:
            group: param-group dict; reads ``group["params"]`` and
                ``group["lr"]``.
            params: output list, receives parameters that have a gradient.
            grads: output list, receives the matching gradients.
            prevs: output list, receives the matching ``state["prev"]``.
            step_sizes: output list, receives the matching
                ``state["step_size"]``.

        Returns:
            A truthy value if any collected parameter is complex-valued,
            otherwise ``False``.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                # Parameters without a gradient take no part in this step.
                continue
            has_complex |= torch.is_complex(p)
            params.append(p)
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError("Rprop does not support sparse gradients")

            grads.append(grad)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                state["prev"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if p.dtype.is_complex:
                    # Complex Number should be as if they are two independent
                    # real numbers. Hence the step_size shouldn't be zero for
                    # imaginary part.
                    state["step_size"] = torch.full_like(
                        grad, complex(group["lr"], group["lr"])
                    )
                else:
                    state["step_size"] = torch.full_like(grad, group["lr"])

            prevs.append(state["prev"])
            step_sizes.append(state["step_size"])

            state["step"] += 1
        return has_complex
[docs] @_use_grad_for_differentiable def step(self, closure=None): """执行单个优化步骤。 参数: closure (
优云智算