Shortcuts

torch.autograd.functional 的源代码

from typing import List, Tuple

import torch
from torch._vmap_internals import _vmap
from . import forward_ad as fwAD

__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]

# 实用函数


def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _as_tuple(inp, arg_name=None, fn_name=None):
    # 确保inp是一个Tensor元组
    # 返回原始inp是否为元组以及输入的元组版本
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, torch.Tensor):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp


def _tuple_postprocess(res, to_unpack):
    # 解包潜在嵌套的Tensor元组
    # to_unpack 应该是一个单一的布尔值或两个布尔值的元组。
    # 它用于:
    # - 当res应该匹配传递给_as_tuple的inp时,反转_as_tuple
    # - 可选地去除由多次调用_as_tuple创建的两个元组的嵌套
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res


def _grad_preprocess(inputs, create_graph, need_graph):
    # 预处理输入以确保它们需要梯度
    # inputs 是一个要预处理的Tensor元组
    # create_graph 指定用户是否希望梯度回流到inputs中的Tensor
    # need_graph 指定我们内部是否希望梯度回流到res中的Tensor
    # 注意,我们*总是*创建一个新的Tensor对象,以便能够区分作为参数给出的输入和用户函数自动捕获的相同Tensor。
    # 有关如何发生这种情况的更多详细信息,请参阅此问题:https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # 以可微分的方式创建至少一个新的Tensor对象
            if not inp.is_sparse:
                # 使用.view_as()获取浅拷贝
                res.append(inp.view_as(inp))
            else:
                # 我们不能对稀疏Tensor使用view,因此我们克隆
                res.append(inp.clone())
        else:
            res.append(inp.detach().requires_grad_(need_graph))
    return tuple(res)


def _grad_postprocess(inputs, create_graph):
    # 后处理生成的Tensor,以避免在用户未请求时返回具有历史的Tensor。
    if isinstance(inputs[0], torch.Tensor):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_
优云智算