torch.optim.rprop 的源代码
import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
_differentiable_doc, _foreach_doc, _maximize_doc, _view_as_real)
from typing import List, Optional
__all__ = ["Rprop", "rprop"]
[docs]class Rprop(Optimizer):
def __init__(
self,
params,
lr=1e-2,
etas=(0.5, 1.2),
step_sizes=(1e-6, 50),
*,
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 < etas[0] < 1.0 < etas[1]:
raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")
defaults = dict(
lr=lr,
etas=etas,
step_sizes=step_sizes,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
def _init_group(self, group, params, grads, prevs, step_sizes):
has_complex = False
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
params.append(p)
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Rprop does not support sparse gradients")
grads.append(grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
state["prev"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if p.dtype.is_complex:
# Complex Number should be as if they are two independent real numbers.
# Hence the step_size shouldn't be zero for imaginary part.
state["step_size"] = (
torch.full_like(grad, complex(group["lr"], group["lr"]))
)
else:
state["step_size"] = torch.full_like(grad, group["lr"])
prevs.append(state["prev"])
step_sizes.append(state["step_size"])
state["step"] += 1
return has_complex
[docs] @_use_grad_for_differentiable
def step(self, closure=None):
"""执行单个优化步骤。
参数:
closure (