Shortcuts

扩展 PyTorch

在本笔记中,我们将介绍扩展torch.nntorch.autogradtorch以及编写自定义C++扩展的方法。

扩展 torch.autograd

autograd 添加操作需要为每个操作实现一个新的 Function 子类。回想一下,Functions 是 autograd 用于编码操作历史和计算梯度的工具。

本文档的第一部分主要关注反向模式自动微分(AD),因为这是最广泛使用的功能。最后一部分讨论了正向模式自动微分的扩展。

何时使用

通常情况下,如果您希望在模型中执行不可微分或依赖于非PyTorch库(例如,NumPy)的计算,但仍希望您的操作能够与其他操作链接并使用自动求导引擎,则可以实现一个自定义函数。

在某些情况下,自定义函数也可以用于提高性能和内存使用:如果你使用C++扩展实现了前向和反向传播,你可以将它们封装在Function中以与自动求导引擎接口。如果你想减少为反向传播保存的缓冲区数量,可以使用自定义函数将操作组合在一起。

何时不使用

如果你已经能够用 PyTorch 的内置操作来编写你的函数,它的反向传播图(很可能)已经能够被 autograd 记录。在这种情况下,你不需要自己实现反向传播函数。考虑使用一个普通的 Python 函数。

如果你需要维护状态,即训练参数,你应该(也)使用一个自定义模块。有关扩展 torch.nn 的更多信息,请参见下面的章节。

如果您想在反向传播过程中改变梯度或执行副作用,请考虑注册一个 tensorModule 钩子。

如何使用

按照以下步骤操作: 1. 子类化 Function 并实现 forward() 方法, (可选)setup_context()backward() 方法。 2. 在 ctx 参数上调用适当的方法。 3. 声明您的函数是否支持 双重反向传播。 4. 使用 gradcheck 验证您的梯度是否正确。

步骤1:在子类化 Function 之后,你需要定义3个方法:

  • forward() 是执行操作的代码。它可以接受任意数量的参数,其中一些参数是可选的,如果你指定了默认值。这里接受所有类型的Python对象。Tensor 参数如果跟踪历史记录(即,requires_grad=True),将在调用前转换为不跟踪历史的张量,并且它们的使用将被注册在图中。请注意,此逻辑不会遍历列表/字典/任何其他数据结构,只会考虑作为直接参数传递给调用的张量。你可以返回单个 Tensor 输出,或者如果存在多个输出,则返回一个 tuple 的张量。此外,请参阅 Function 的文档,以找到只能在 forward() 中调用的有用方法的描述。

  • setup_context()(可选)。可以编写一个“组合”的forward(),该方法接受一个ctx对象,或者(从PyTorch 2.0开始)编写一个不接受ctx的单独forward()和一个setup_context()方法,其中ctx的修改发生。 forward()应该负责计算,而setup_context()应该只负责ctx的修改(而不进行任何计算)。 通常,单独的forward()setup_context()更接近PyTorch原生操作的工作方式,因此与各种PyTorch子系统更具组合性。 有关更多详细信息,请参阅组合或单独的forward()和setup_context()

  • backward()(或 vjp())定义了梯度公式。 它将被赋予与输出数量相同的 Tensor 参数,每个参数表示相对于该输出的梯度。重要的是永远不要就地修改这些参数。它应该返回与输入数量相同的张量,每个张量包含相对于其相应输入的梯度。如果你的输入不需要梯度(needs_input_grad 是一个布尔元组,指示每个输入是否需要梯度计算),或者是非 Tensor 对象,你可以返回 python:None。此外,如果你有可选的参数传递给 forward(),你可以返回比输入更多的梯度,只要它们都是 None

步骤2:您有责任正确使用ctx中的函数,以确保新的Function能够正确地与自动求导引擎配合工作。

  • save_for_backward() 必须用于保存任何在反向传播中使用的张量。非张量应直接存储在 ctx 上。如果保存的张量既不是输入也不是输出,您的 Function 可能不支持双重反向传播(见步骤3)。

  • mark_dirty() 必须用于标记前向函数中任何就地修改的输入。

  • mark_non_differentiable() 必须用于告知引擎某个输出是否不可微。默认情况下,所有可微类型的输出张量都将被设置为需要梯度。不可微类型的张量(即整数类型)永远不会被标记为需要梯度。

  • set_materialize_grads() 可以用于告诉自动求导引擎在输出不依赖于输入的情况下优化梯度计算,即不将传递给反向函数的梯度张量具体化。也就是说,如果设置为False,Python中的None对象或C++中的“未定义张量”(即x.defined()为False的张量x)在调用反向传播之前不会被转换为填充零的张量,因此您的代码需要将这些对象视为填充零的张量来处理。此设置的默认值为True。

步骤 3: 如果你的 Function 不支持双重反向传播,你应该通过使用 once_differentiable() 装饰器来明确声明这一点。使用此装饰器,尝试通过你的函数进行双重反向传播将会产生错误。有关双重反向传播的更多信息,请参阅我们的双重反向传播教程。

步骤 4:建议您使用 torch.autograd.gradcheck() 来检查您的反向函数是否正确计算了前向的梯度,方法是使用您的反向函数计算雅可比矩阵,并与使用有限差分法计算的数值雅可比进行逐元素比较。

示例

下面你可以找到一个线性函数的代码,并附有额外的注释:

# 继承自 Function
class LinearFunction(Function):

    # 注意 forward, setup_context, 和 backward 是 @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs 是一个包含传递给 forward 的所有输入的元组。
    # output 是 forward() 的输出。
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # 这个函数只有一个输出,所以它只得到一个梯度
    @staticmethod
    def backward(ctx, grad_output):
        # 这是一个非常方便的模式 - 在 backward 的顶部
        # 解包 saved_tensors 并初始化所有输入的梯度为
        # None。由于额外的尾部 None 被忽略,
        # 即使函数有可选输入,返回语句也很简单。
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # 这些 needs_input_grad 检查是可选的,仅用于
        # 提高效率。如果你想简化代码,可以
        # 跳过它们。返回不需要的输入的梯度
        # 不是错误。
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

现在,为了更方便地使用这些自定义操作,我们建议要么为它们设置别名,要么将它们封装在一个函数中。将它们封装在函数中可以让我们支持默认参数和关键字参数:

# 选项1: 别名
linear = LinearFunction.apply

# 选项2: 封装在一个函数中,以支持默认参数和关键字参数。
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)

在这里,我们给出了一个函数的附加示例,该函数由非Tensor参数参数化:

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx 是一个上下文对象,可以用来存储信息
        # 用于反向传播计算
        tensor, constant = inputs
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # 我们返回与参数数量相同的输入梯度。
        # 非 Tensor 参数的梯度必须为 None。
        return grad_output * ctx.constant, None

在这里,我们通过调用 set_materialize_grads(False) 来优化上述示例:

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        tensor, constant = inputs
        ctx.set_materialize_grads(False)
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # 这里我们必须处理None的grad_output张量。在这种情况下,我们可以跳过不必要的计算并直接返回None。
        if grad_output is None:
            return None, None

        # 我们返回与参数数量相同的输入梯度。
        # 非Tensor参数的梯度必须为None。
        return grad_output * ctx.constant, None

如果你需要在 forward() 中计算的任何“中间”张量被保存, 要么它们必须作为输出返回,要么结合 forwardsetup_context() (参见 Combined or separate forward() and setup_context())。 请注意,这意味着如果你希望梯度通过这些中间值流动,你需要为它们定义梯度公式(另请参阅 双重反向传播教程 ):

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # 我们希望保存dx以用于反向传播。为此,它必须作为输出返回。
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # 为了使autograd.Function能够处理高阶梯度,我们必须添加`dx`的梯度贡献,
        # 即 grad_dx * 6 * x。
        result = grad_output * dx + grad_dx * 6 * x
        return result

# 将MyCube包装在一个函数中,以便更清楚地了解输出是什么
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

注意

输入到 backward 的内容,即 grad_output,也可以是跟踪历史的张量。因此,如果 backward 是用可微分的操作实现的(例如,调用另一个自定义的 Function),高阶导数将会起作用。在这种情况下,使用 save_for_backward 保存的张量也可以在反向传播中使用,并且会有梯度回流,但保存在 ctx 中的张量不会有梯度回流。如果你需要对保存在 ctx 中的张量进行梯度回流,你应该将其作为自定义 Function 的输出,并使用 save_for_backward 保存它。

您可能希望检查您实现的反向传播方法是否确实计算了函数的导数。通过与使用小有限差分的数值近似进行比较,可以实现这一点:

from torch.autograd import gradcheck

# gradcheck 接受一个张量元组作为输入,检查使用这些张量计算的梯度是否
# 足够接近数值近似值,如果所有张量都满足此条件,则返回 True。
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)

有关有限差分梯度比较的更多详细信息,请参阅数值梯度检查。 如果你的函数用于高阶导数(反向传播的微分),你可以使用同一包中的gradgradcheck函数来检查高阶导数。

组合或分离 forward()setup_context()

有两种主要方式来定义Function。要么:

  • 定义一个forward(),它将前向计算逻辑与setup_context()结合在一起

  • (截至PyTorch 2.0)定义一个单独的forward()setup_context()

我们推荐第二种选择(单独的 forward()setup_context()),因为这更接近 PyTorch 原生操作的实现方式,并且与 torch.func 转换组合使用。然而,我们计划在未来支持这两种方法;将 forward()setup_context() 结合使用:可以带来更大的灵活性,因为您能够在不将它们作为输出返回的情况下保存中间结果。

请参阅上一节,了解如何使用单独的Function定义 forward()setup_context()

这里是一个如何定义一个结合了Functionforward()setup_context()的示例:

class LinearFunction(Function):
    @staticmethod
    # ctx 是 forward 的第一个参数
    def forward(ctx, input, weight, bias=None):
        # 前向传播可以使用 ctx
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

前向模式自动微分

覆盖前向模式自动微分公式的API与一些不同的细微差别非常相似。 您可以实现jvp()函数。

它将被赋予与输入数量相同的 Tensor 参数,每个参数代表相对于该输入的梯度。它应该返回与输出数量相同的张量,每个张量包含相对于其相应输出的梯度。jvp() 方法将在 forward() 方法之后立即被调用,在 apply() 返回之前。

jvp()backward() 函数有一些细微的差异:

  • 您可以使用ctx将任何数据从forward()传递到jvp()函数。 如果该状态不需要用于backward(), 您可以通过在jvp()函数的末尾执行del ctx.foo来显式释放它。

  • 实现 jvp() 必须是反向可微的,或者显式检查给定的前向模式梯度中没有任何一个设置了 requires_grad

  • jvp() 函数必须与 forward() 的视图/就地行为相匹配。 例如,如果第 i 个输入被就地修改,那么第 i 个梯度必须在就地更新。 同样,如果第 j 个输出是第 k 个输入的视图,那么返回的第 j 个输出梯度必须是给定的第 k 个输入梯度的视图。

  • 因为用户无法指定需要计算哪个梯度,所以jvp()函数应该始终计算所有输出的梯度。

  • 前向模式梯度确实遵循由set_materialize_grads()设置的标志,当此功能被禁用时,您可以获得None输入梯度。

扩展 torch.nn

nn 导出两种接口 - 模块及其功能版本。你可以通过两种方式扩展它,但我们建议对所有包含参数或缓冲区的层使用模块,并建议对无参数操作(如激活函数、池化等)使用功能形式。

添加操作的功能版本已经在上述部分中完全涵盖。

添加一个模块

由于 nn 大量使用了 autograd,添加一个新的 Module 需要实现一个 Function 来执行操作并能够计算梯度。从现在开始,我们假设我们想要实现一个 Linear 模块,并且我们已经按照上面的清单实现了该函数。添加这个模块所需的代码非常少。现在,有两个函数需要实现:

  • __init__(可选)- 接受诸如内核大小、特征数量等参数,并初始化参数和缓冲区。

  • forward() - 实例化一个Function并使用它来执行操作。它与上面显示的功能包装器非常相似。

这是一个如何实现线性模块的示例:

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter 是一种特殊的 Tensor,一旦被赋值为 Module 的属性,它就会被自动注册为 Module 的参数。
        # 参数和缓冲区需要被注册,否则它们不会出现在 .parameters() 中(不适用于缓冲区),并且在调用 .cuda() 等方法时不会被转换。
        # 你可以使用 .register_buffer() 来注册缓冲区。
        # nn.Parameters 默认需要梯度。
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # 你应该总是注册所有可能的参数,但可选的参数可以是 None。
            self.register_parameter('bias', None)

        # 这不是一种非常聪明的初始化权重的方法
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # 请参阅 autograd 部分以了解这里发生的情况。
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (可选)设置有关此模块的额外信息。你可以通过打印此类的一个对象来测试它。
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

扩展 torch Python API

你可以通过定义一个具有与Tensor匹配方法的自定义类来创建模拟Tensor的自定义类型。但是,如果你想将这些类型传递给顶级torch命名空间中接受Tensor操作数的函数,例如torch.add(),该怎么办?

如果你的自定义 Python 类型定义了一个名为 __torch_function__ 的方法,PyTorch 将在你的自定义类实例传递给 torch 命名空间中的函数时调用你的 __torch_function__ 实现。这使得你可以为 torch 命名空间中的任何函数定义自定义实现,你的 __torch_function__ 实现可以调用这些函数,允许你的用户在使用他们已经为 Tensor 编写的现有 PyTorch 工作流时使用你的自定义类型。这也适用于与 Tensor 无关的“鸭子”类型以及用户定义的 Tensor 子类。

使用类似Tensor的类型扩展torch

注意

此功能受NumPy __array_function__ 协议的启发。请参阅NumPy文档NEP-0018以获取更多详细信息。

为了具体说明,让我们从一个简单的例子开始,这个例子说明了API调度机制。我们将创建一个自定义类型,表示一个2D标量张量,由阶数N和对角线上的值value参数化:

class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value

   def __repr__(self):
       return "ScalarTensor(N={}, value={})".format(self._N, self._value)

   def tensor(self):
       return self._value * torch.eye(self._N)

设计的第一次迭代并不是非常有用。ScalarTensor的主要功能是提供比基本张量类更紧凑的标量张量字符串表示:

>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]])

如果我们尝试将这个对象与torch API一起使用,我们将会遇到问题:

>>> import torch
>>> torch.mean(d)
TypeError: mean(): 参数 'input' (位置 1) 必须是 Tensor,而不是 ScalarTensor

__torch_function__ 实现添加到 ScalarTensor 中,可以使上述操作成功。让我们重新实现一次,这次添加一个 __torch_function__ 实现:

HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

__torch_function__ 方法接受四个参数:func,一个对被覆盖的 torch API 函数的引用,types,实现 __torch_function__ 的 Tensor-like 类型的列表,args,传递给函数的参数元组,以及 kwargs,传递给函数的键值字典。它使用一个名为 HANDLED_FUNCTIONS 的全局调度表来存储自定义实现。该字典的键是 torch 命名空间中的函数,值是 ScalarTensor 的实现。

注意

使用全局调度表并不是__torch_function__ API的强制要求,它只是一个用于构建您的覆盖实现的有用设计模式。

这个类定义还不够让 torch.mean 在我们传递一个 ScalarTensor 时做正确的事情——我们还需要为 torch.mean 定义一个针对 ScalarTensor 操作数的实现,并将该实现添加到 HANDLED_FUNCTIONS 调度表字典中。一种实现方式是定义一个装饰器:

import functools
def implements(torch_function):
    """为 ScalarTensor 注册一个 torch 函数覆盖"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

这可以应用于我们的重写实现:

@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N

通过这一更改,我们现在可以使用 torch.meanScalarTensor

>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4

当然,torch.mean 是一个最简单的函数覆盖示例,因为它只接受一个操作数。我们可以使用相同的机制来覆盖一个接受多个操作数的函数,其中任何一个操作数可能是定义了 __torch_function__ 的张量或类张量对象,例如对于 torch.add()

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("形状不匹配!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))

此版本在两个操作数都是ScalarTensor实例时有一个快速路径,并且还有一个较慢的路径,当任一操作数不是ScalarTensor时,会降级为将数据转换为张量。这使得重写函数在任一操作数是ScalarTensor或常规Tensor时都能正确工作:

>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
        [1., 3.]])

请注意,我们的 add 实现并不像 torch.add() 那样将 alphaout 作为关键字参数:

>>> torch.add(s, s, alpha=2)
TypeError: add() 得到了一个意外的关键字参数 'alpha'

为了速度和灵活性,__torch_function__ 调度机制不会检查覆盖函数的签名是否与 torch API 中被覆盖的函数的签名匹配。对于某些应用程序来说,忽略可选参数是可以的,但要确保与 Tensor 完全兼容,用户实现的 torch API 函数应确保完全模拟被覆盖函数的 API。

torch API 中没有显式重写的函数将从 __torch_function__ 返回 NotImplemented。如果所有具有 __torch_function__ 定义的操作数都返回 NotImplemented,PyTorch 将引发 TypeError。这意味着大多数情况下,当传递此类类型的实例时,没有显式重写的操作将引发 TypeError

>>> torch.mul(s, 3)
TypeError: 没有为 'torch.mul' 找到实现,适用于实现 __torch_function__ 的类型: [ScalarTensor]

实际上,这意味着如果您希望使用__torch_function__实现来实现您的重写,您需要显式实现完整的torch API或您关心的用于您用例的API的整个子集。这可能是一个很高的要求,因为完整的torch API非常广泛。

另一个选项是不返回NotImplemented对于未处理的操作,而是将Tensor传递给原始的torch函数,当没有重写时。例如,如果我们更改ScalarTensor__torch_function__实现为以下内容:

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
    return HANDLED_FUNCTIONS[func](*args, **kwargs)

然后 torch.mul() 将正确工作,尽管返回类型将始终是一个 Tensor 而不是一个 ScalarTensor,即使两个操作数都是 ScalarTensor 实例:

>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
        [0., 4.]])

另请参见下面的 MetadataTensor 示例,这是该模式的另一种变体,但始终返回一个 MetadataTensor 以在 torch API 中的操作中传播元数据。

__torch_function__ 协议旨在全面覆盖API,部分覆盖可能会导致不良结果,特别是某些函数引发 TypeError。对于子类来说尤其如此,其中 torch.addtorch.Tensor.__add__torch.Tensor.add 都必须被覆盖,即使它们返回完全相同的结果。未能做到这一点也可能导致无限递归。如果需要从 torch.Tensor 子类实现某个函数,他们必须在其实现中使用 super().__torch_function__

子类化 torch.Tensor

自版本1.7.0起,应用于torch.Tensor子类的torch.Tensor方法和公共torch.*命名空间中的函数将返回子类实例,而不是torch.Tensor实例:

>>> class SubTensor(torch.Tensor):
...     pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'

如果存在多个子类,默认情况下将选择层次结构中最底层的那个。如果没有唯一的方法来确定这种情况,则会引发TypeError

>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
  File "", line 1, in 
TypeError: 没有找到 'torch.add' 在实现 __torch_function__ 的类型上的实现: [SubTensor, OtherSubTensor]

如果希望对所有张量方法进行全局覆盖,可以使用 __torch_function__。以下是一个记录所有函数/方法调用的示例:

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # 注意:日志调用 Tensor.__repr__,所以我们不能记录 __repr__ 而不导致无限递归
        if func is not torch.Tensor.__repr__:
            logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

然而,如果有人希望在Tensor子类上覆盖一个方法,可以通过直接覆盖该方法(为子类定义它),或者使用__torch_function__并匹配func来实现。

在子类的__torch_function__中应小心,始终调用super().__torch_function__(func, ...)而不是直接调用func,如1.7.0版本之前的情况。未能这样做可能会导致func递归回到__torch_function__,从而导致无限递归。

使用Tensor包装类型扩展torch

另一个有用的案例是包装一个Tensor的类型,无论是作为属性还是通过子类化。下面我们实现了一种特殊情况,即一个MetadataTensor,它将元数据字典附加到一个Tensor上,并通过torch操作传播。由于这是一种对整个torch API的通用包装,我们不需要单独实现每个重写,因此我们可以使__torch_function__的实现对允许的操作更加宽容:

```python class MetadataTensor(object): def __init__(self, data, metadata=None, **kwargs): self._t = torch.as_tensor(data, **kwargs) self._metadata = metadata def __repr__(self): return "元数据:\n{}\n\n数据:\n{}".format(self._metadata, self._t) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata')) args = [getattr(a, '_t', a) for a in args] assert len(metadatas) > 0 ret = func(*args, **kwargs) return MetadataTensor(ret, metadata=metadatas[0]) ```

这个简单的实现不一定适用于torch API中的每个函数,但它足以捕捉大多数常见操作:

>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
元数据:
{'owner': 'Ministry of Silly Walks'}

数据:
tensor([[2, 4],
        [4, 6]])
>>> torch.mul(t, m)
元数据:
{'owner': 'Ministry of Silly Walks'}

数据:
tensor([[1, 4],
        [3, 8]])

对定义了__torch_function__的多种类型的操作

可以使用 torch API 与多种不同类型一起使用,每种类型都有 __torch_function__ 实现,但需要特别注意。在这种情况下,规则如下:

  • 分发操作收集每个操作数的所有不同 __torch_function__ 实现,并按顺序调用它们:子类在父类之前,否则按操作符表达式中的从左到右顺序调用。

  • 如果返回的值不是 NotImplemented,则该值将作为结果返回。实现可以通过返回 NotImplemented 来注册它们不实现某个操作。

  • 如果所有的 __torch_function__ 实现都返回 NotImplemented,PyTorch 会引发一个 TypeError

PyTorch API 重写的测试覆盖率

实现__torch_function__的一个麻烦之处在于,如果某些操作有重写而其他操作没有重写,用户在最好的情况下会看到不一致的体验,在最坏的情况下会在使用没有重写的函数时遇到运行时错误。为了简化这个过程,PyTorch提供了一个面向开发者的API,用于确保对__torch_function__重写的全面支持。这个API是私有的,并且可能在将来未经警告的情况下发生变化。

首先,要获取所有可覆盖函数的列表,请使用 torch.overrides._get_overridable_functions。这将返回一个字典,其键是PyTorch Python API中的命名空间,其值是该命名空间中可以被覆盖的函数的列表。例如,让我们打印torch.nn.functional中前5个可以被覆盖的函数的名称:

>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
 'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']

这个函数列表使得可以迭代所有可覆盖的函数,然而在实践中,这并不足以在不费力地手动复制每个函数的签名的情况下为所有这些函数编写测试。为了简化这个过程,torch.overrides._get_testing_overrides 函数返回一个字典,该字典将 PyTorch API 中的可覆盖函数映射到具有与原始函数相同签名的虚拟 lambda 函数,但这些函数无条件地返回 -1。这些函数在与 inspect 一起使用时最有用,可以分析原始 PyTorch 函数的签名:

>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)

最后,torch.overrides.get_ignored_functions 返回一个元组,其中包含无法被 __torch_function__ 覆盖的函数。这个列表可以用来确认那些在 get_overridable_functions 返回的字典中不存在的函数无法被覆盖。

扩展 torch 原生 API

虽然 __torch_function__ 允许有效地扩展 PyTorch 的纯 Python 组件的行为,但它不允许扩展 PyTorch 中用 C++ 实现的部分。为此,Tensor 子类还可以定义 __torch_dispatch__,这将能够在 C++ 级别覆盖行为。

要有效地使用此功能,了解PyTorch的本地部分是如何实现的是非常重要的。其中最重要的组件是我们称之为“调度器”的部分(最好的描述可以在这篇博客文章中找到,尽管它有些过时)。正如其名所示,它负责为特定函数的调用调用正确的后端函数。例如,当调用torch.add(a, b)时,调度器会检查这两个参数,确定应为此特定调用使用的“功能”(如autograd、autocast、functionalization等)和“后端”(如CPU、CUDA、MPS等),并最终调用所有正确的内核。 一个非常常见的操作是由内核进行的“重新调度”。例如,当在GPU上使用autocast运行神经网络时,第一次调用将是autocast内核,它将处理任何潜在的autocast逻辑并向下重新调度。接下来是autograd,它将正确创建autograd图,然后向下重新调度。最后,我们到达CUDA的后端内核,它将启动正确的CUDA内核并返回最终结果。在返回的过程中,autograd会将图附加到输出上,最后,autocast将有机会在退出时进行任何需要的更新。

调度器的一个配置是所有这些特征和后端键被调用的顺序。最新的列表及其顺序可以在 DispatchKey.h 文件中的 DispatchKey 枚举中找到。为了扩展 torch,本次讨论中重要的顺序子集是:

vmap -> 自动转换 -> 自动梯度 -> 零张量 -> 负/共轭 -> 功能化 -> Python -> 后端

对于本次讨论的目的而言,最重要的关键点是 Python,因为每个定义了 __torch_dispatch__ 方法的 Tensor 子类都会调用此功能。正是在这里,用户定义的方法被调用,并且可以任意覆盖行为。从那里,再次调用提供的 func 将执行“重新分派”。

这一实现的一些重要含义包括:

  • 此代码运行“在所有功能之下”。因此,它仅负责生成每个张量的输出值(并且可以,也应该忽略所有高级功能,如自动求导、自动转换等),就像常规的后端一样。

  • 如果任何高级功能在没有重新分派的情况下实现了给定函数,它将永远不会到达 Python 键,因此 __torch_dispatch__ 回调将永远不会被触发。这在 CompositeImplicitAutograd 函数中尤其如此,这些函数在 Autograd 级别进行评估而无需重新分派。这是因为 CompositeImplicitAutograd 函数通过隐式调用其他原生操作来指定其自动求导公式,因此在 Autograd 级别,该函数被分解为其原生操作,并由这些操作进行评估。

  • 当回调到Python并包装结果时,使用的转换与常规的PyTorch Python/C++绑定相同。特别是,某些对象无法在Python中表示,需要特殊处理(例如,未定义的张量变为None)。

  • 我们的本地函数作为可调用的Python对象以torch.ops.{namespace}.{func_name}.{overload_name}的形式惰性填充,以便从Python轻松地与它们交互。传递给__torch_dispatch__func对象始终是此命名空间中的一个条目。此命名空间可用于直接调用本地操作,绕过通常的Python API和绑定代码。

__torch_function__能够介入所有torch的Python API和Tensor方法的方式类似,__torch_dispatch__能够拦截所有对aten原生API的调用。请注意,Tensor上的所有方法在进入调度器之前都被转换为函数调用,因此在这里它们将显示为函数调用:torch.add(a, 2)a + 2 将导致完全相同的aten调用。 这些函数中的大多数定义在native_functions.yaml中,该文件指定了这些函数的属性和它们的底层实现。它们的实现与指定的特性一起通过代码生成自动注册。 一些更奇特的函数或特性也在C++代码库的其他地方或用户定义的C++扩展中注册。

也可以使用torch.library添加的原生函数。这个Python功能允许定义和/或向原生函数添加新的实现。这可以用于添加缺失的内核、替换现有的内核或定义全新的原生函数。

你可以在 subclass zoo 仓库中找到许多基于 __torch_dispatch__ 的子类示例。

扩展所有 torch API 与模式

不幸的是,有些函数不接受Tensor输入。这意味着上述的子类方法不能用于覆盖PyTorch所有函数的行为。此外,如果使用场景需要拦截每个函数调用,将每个Tensor更改为子类可能会过于侵入性。

为了解决这个用例,我们引入了“模式”的概念。这些模式用于__torch_function____torch_dispatch__的重写,分别通过子类化torch.overrides.TorchFunctionModetorch.utils._python_dispatch.TorchDispatchMode来创建,并作为上下文管理器使用。

为了简化它与子类和其他模式的交互描述,每当进入某个模式的上下文管理器时,每个函数的行为就好像在参数列表的开头有一个额外的Tensor参数,该参数是该模式的子类。 这意味着特别是所有模式处理程序将在任何子类处理程序之前被调用,并且与内部上下文管理器对应的模式将始终首先运行。

同样重要的是要注意,在给定的模式处理器中,此特定模式被禁用,并且可以通过执行with self:手动重新启用。

以下是一个展示每种类型日志模式的示例:

import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode

class FunctionLog(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        print(f"函数日志: {resolve_name(func)}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        print(f"调度日志: {func}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

def f():
    a = torch.rand(10, requires_grad=True)
    b = a * 2
    b.sum().backward()

print("TorchFunctionMode 日志记录:")
with FunctionLog():
    f()

print("TorchDispatchMode 日志记录:")
with DispatchLog():
    f()

打印出以下内容,并附带额外注释:

TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
        0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
        1.1939], grad_fn=<MulBackward0>),), **None)
# 注意在python级别,我们只能看到对backward的调用,而看不到autograd引擎中发生的事情。
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})

TorchDispatchMode logging:
# 这里autograd的requires_grad标志被移除,同时默认参数被填充。
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
        0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
        1.2046], grad_fn=<MulBackward0>),), **{})
# 这里我们看不到对backward本身的调用,但可以看到它的组成部分。从这里开始,使用工厂函数创建初始梯度。
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# 这是sum的反向传播
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})

编写自定义C++扩展

请参阅此 PyTorch教程 以获取详细说明和示例。

文档可在 torch.utils.cpp_extension 获取。