torch.autograd.functional 的源代码
from typing import List, Tuple
import torch
from torch._vmap_internals import _vmap
from . import forward_ad as fwAD
__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
# 实用函数
def _as_tuple_nocheck(x):
if isinstance(x, tuple):
return x
elif isinstance(x, list):
return tuple(x)
else:
return (x,)
def _as_tuple(inp, arg_name=None, fn_name=None):
# 确保inp是一个Tensor元组
# 返回原始inp是否为元组以及输入的元组版本
if arg_name is None and fn_name is None:
return _as_tuple_nocheck(inp)
is_inp_tuple = True
if not isinstance(inp, tuple):
inp = (inp,)
is_inp_tuple = False
for i, el in enumerate(inp):
if not isinstance(el, torch.Tensor):
if is_inp_tuple:
raise TypeError(
f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
f" value at index {i} has type {type(el)}."
)
else:
raise TypeError(
f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
f" given {arg_name} has type {type(el)}."
)
return is_inp_tuple, inp
def _tuple_postprocess(res, to_unpack):
# 解包潜在嵌套的Tensor元组
# to_unpack 应该是一个单一的布尔值或两个布尔值的元组。
# 它用于:
# - 当res应该匹配传递给_as_tuple的inp时,反转_as_tuple
# - 可选地去除由多次调用_as_tuple创建的两个元组的嵌套
if isinstance(to_unpack, tuple):
assert len(to_unpack) == 2
if not to_unpack[1]:
res = tuple(el[0] for el in res)
if not to_unpack[0]:
res = res[0]
else:
if not to_unpack:
res = res[0]
return res
def _grad_preprocess(inputs, create_graph, need_graph):
# 预处理输入以确保它们需要梯度
# inputs 是一个要预处理的Tensor元组
# create_graph 指定用户是否希望梯度回流到inputs中的Tensor
# need_graph 指定我们内部是否希望梯度回流到res中的Tensor
# 注意,我们*总是*创建一个新的Tensor对象,以便能够区分作为参数给出的输入和用户函数自动捕获的相同Tensor。
# 有关如何发生这种情况的更多详细信息,请参阅此问题:https://github.com/pytorch/pytorch/issues/32576
res = []
for inp in inputs:
if create_graph and inp.requires_grad:
# 以可微分的方式创建至少一个新的Tensor对象
if not inp.is_sparse:
# 使用.view_as()获取浅拷贝
res.append(inp.view_as(inp))
else:
# 我们不能对稀疏Tensor使用view,因此我们克隆
res.append(inp.clone())
else:
res.append(inp.detach().requires_grad_(need_graph))
return tuple(res)
def _grad_postprocess(inputs, create_graph):
# 后处理生成的Tensor,以避免在用户未请求时返回具有历史的Tensor。
if isinstance(inputs[0], torch.Tensor):
if not create_graph:
return tuple(inp.detach() for inp in inputs)
else:
return inputs
else:
return tuple(_