
Source code for torch.nn.utils.clip_grad

```python
import warnings
import functools
from typing import Union, Iterable, List, Dict, Tuple, Optional, cast

import torch
from torch import Tensor
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support, _device_has_foreach_support

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]

__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']

def _no_grad(func):
    """
    This wrapper avoids a circular import when using @torch.no_grad on the exposed functions
    clip_grad_norm_ and clip_grad_value_ themselves.
    """
    def _no_grad_wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)
    functools.update_wrapper(_no_grad_wrapper, func)
    return _no_grad_wrapper

@_no_grad
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
    r"""Clip the gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float): max norm of the gradients
        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
        foreach (bool): use the faster foreach-based implementation.
            If ``None``, use the foreach implementation for CUDA and CPU native tensors and
            silently fall back to the slow implementation for other device types.
            Default: ``None``

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.)
    first_device = grads[0].device
    grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]

    norms: List[Tensor] = []
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(device_grads, device))
            or (foreach and _device_has_foreach_support(device))
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coefficient is redundant when it equals 1, but doing so
    # avoids an `if clip_coef < 1:` conditional, which can require a CPU <=> device
    # synchronization when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
        if (
            (foreach is None and _has_foreach_support(device_grads, device))
            or (foreach and _device_has_foreach_support(device))
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm
```
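For context, here is a minimal sketch of how `clip_grad_norm_` is typically called in a training loop, between `loss.backward()` and `optimizer.step()`. The model, data shapes, learning rate, and `max_norm=1.0` are illustrative assumptions, not part of the module above.

```python
import torch
import torch.nn as nn

# Illustrative model and data; the shapes and hyperparameters here are assumptions.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()

# Clip gradients in place so their total 2-norm is at most 1.0;
# the pre-clipping norm is returned and is often logged for monitoring.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
print(f"grad norm before clipping: {total_norm:.4f}")
```

Because the scaling coefficient is clamped to 1.0, gradients whose total norm is already below `max_norm` are left unchanged; only larger gradients are rescaled.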