Source code for torch.cuda.amp.grad_scaler
import torch
from torch.amp.grad_scaler import OptState
__all__ = ["GradScaler", "OptState"]
class GradScaler(torch.amp.GradScaler):
    r"""
    See :class:`torch.amp.GradScaler`.

    ``torch.cuda.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cuda", args...)``.
    """
    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True,
    ) -> None:
        super().__init__(
            "cuda",
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
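For context, below is a minimal sketch of the standard AMP training-loop pattern this scaler is used in; the model, optimizer, and tensor shapes are placeholder assumptions, not part of the module above. Either constructor spelling yields the same scaler, as the docstring states.

import torch

# The torch.amp form is the device-generic spelling of the class defined above.
scaler = torch.amp.GradScaler("cuda")  # same as torch.cuda.amp.GradScaler()

model = torch.nn.Linear(8, 4).cuda()        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(16, 8, device="cuda")  # placeholder batch
targets = torch.randn(16, 4, device="cuda")

optimizer.zero_grad()
with torch.autocast(device_type="cuda"):    # run the forward pass in mixed precision
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscale gradients; skip the step if infs/NaNs appear
scaler.update()                # adjust the scale factor for the next iteration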