• Docs >
  • Overloading Torch-TensorRT Converters with Custom Converters
Shortcuts

使用自定义转换器重载 Torch-TensorRT 转换器

如果出于某种原因,您希望更改特定PyTorch操作到TensorRT的转换行为,您可以通过编写自定义转换器并重载Torch-TensorRT的转换器来实现。 这可能是因为您希望使用自定义内核而不是TensorRT的内核,或者因为您希望在TensorRT中使用与Torch-TensorRT通常使用的不同的层实现。

在本教程中,我们将演示如何通过使用不同实现的GeLU层的自定义转换器来重载Torch-TensorRT对torch.nn.functional.gelu操作到TensorRT的转换。

import logging
import sys

import torch
import torch_tensorrt

GeLU 在 PyTorch 中有两种模式,一种使用 erf 函数,另一种使用 tanh 近似。 TensorRT 原生支持这两种实现作为激活层,但假设我们只想在 TensorRT 中使用自定义的 GeLU 实现,仅用于 tanh 模式。

class GeLU(torch.nn.Module):
    def __init__(self, mode="tanh"):
        super().__init__()
        self.mode = mode

    def forward(self, x):
        return torch.nn.functional.gelu(x, approximate=self.mode)


my_mod = GeLU(mode="tanh")
ex_input = torch.randn(2, 5).to("cuda")

作为基线,我们可以使用标准的 Torch-TensorRT GeLU 转换器(在 tanh 近似模式下)与我们的模块。

my_standard_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)
print(my_standard_gelu.graph)
print(my_standard_gelu(ex_input))

编写自定义转换器

转换器是将PyTorch图中的特定PyTorch操作实例转换为正在构建的TensorRT图中的等效TensorRT操作集的函数。 它们使用@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter装饰器注册到Torch-TensorRT中。 在代码层面,转换器接收当前的转换状态(ConversionCtx)、图中要转换的下一个操作符以及该节点的参数, 并返回该操作的占位符输出,同时作为副作用将必要的TensorRT层插入到TensorRT网络中。

from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.conversion import ConversionContext

import tensorrt as trt

转换器元数据

@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
    # The PyTorch operation to convert, when this operation is encountered, this converter will be called
    torch.ops.aten.gelu.default,
    # Validators are functions that determine that given a specific node, if it can be converted by the converter
    capability_validator=lambda node, settings: (
        "approximate" in node.kwargs and node.kwargs["approximate"] == "tanh"
    ),
    # Can this converter be used in cases where the input shapes are dynamic
    supports_dynamic_shapes=True,
    # Set the priority of the converter to supersede the default one
    priority=torch_tensorrt.dynamo.conversion.ConverterPriority.HIGH,
)

对于定义转换器的装饰器,有一个必需的参数和几个可选的参数。 所有转换器都需要一个它们将运行的目标操作符,其思想是当图中存在torch.ops.aten.gelu.default的实例时,将调用此转换器。

在目标操作符之后,您可以提供额外的元数据,这些元数据定义了转换器的功能以及该转换器与其他可能针对该目标的转换器的优先级。

定义转换器功能的主要工具是capability_validator参数, 这是一个lambda函数,它接受图中的特定节点以及用户编译设置,并返回一个布尔值,指示转换器是否可以用于该节点。 此验证器函数在图分区阶段之前针对转换器目标操作的每个实例运行。在此阶段没有通过验证器的转换器的节点将在运行时在PyTorch中执行。 这对于您只想在特定情况下使用自定义转换器的情况非常有用,例如在我们的情况下,我们只想在approximate == "tanh"时使用我们的转换器。

与验证器不同的是supports_dynamic_shapes参数,这是一个布尔值,指示转换器是否可以在输入形状动态的情况下使用。 如果设置为False,在用户提供的输入是动态的情况下,此转换器将被禁用。如果没有支持动态形状的替代方案,操作将在PyTorch中运行。

最后是priority参数,它是来自torch_tensorrt.dynamo.conversion.ConverterPriority类的枚举,用于定义转换器的优先级。两个选项是HIGHSTANDARD。 使用STANDARD注册的转换器将被附加到给定操作的转换器列表中,而使用HIGH注册的转换器将被前置到列表中。 候选转换器将按照此优先级顺序评估其适用性,第一个通过验证器的转换器将被使用。

转换器实现

转换器函数本身接受以下参数:当前的转换上下文、目标操作符、目标操作符的参数、目标操作符的关键字参数以及目标操作符的名称。 参数可以是任何Python原语、torch.Tensornp.ArraysITensor对象。 转换器函数应主要返回目标操作符的输出,以TensorRT的ITensor为主。这些输入和输出应与目标PyTorch操作符的模式相对应,可以在这里找到https://pytorch.org/docs/main/torch.compiler_ir.html

由于Torch-TensorRT涵盖了核心的ATen操作集,它已经将许多常见的低级操作抽象为辅助函数,这些函数可用于构建TensorRT网络。这使得开发者可以避免直接创建TensorRT层的样板代码,而是专注于转换的高级逻辑。辅助函数位于torch_tensorrt.dynamo.conversion.impl模块中,并且设计为可组合的,并且可以与原始的TensorRT实现互操作。在这种情况下,我们将使用Torch-TensorRT的muladdtanh函数从impl中实现我们的替代GeLU层。

def aten_ops_gelu(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
    # The schema for torch.ops.aten.gelu.default is gelu(Tensor self, *, str approximate=’none’) -> Tensor

    from torch_tensorrt.dynamo import SourceIR
    from torch_tensorrt.dynamo.conversion import impl

    # Cheap way to allow layer names to be unqiue
    op_count = 0

    def get_op_count():
        nonlocal op_count
        op_count += 1
        return op_count

    mul = lambda x, y: impl.elementwise.mul(
        ctx,
        target,
        name=f"mul_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    add = lambda x, y: impl.elementwise.add(
        ctx,
        target,
        name=f"add_{get_op_count()}",
        source_ir=SourceIR.ATEN,
        lhs_val=x,
        rhs_val=y,
    )
    tanh = lambda x: impl.activation.tanh(
        ctx, target, name=f"tanh_{get_op_count()}", source_ir=SourceIR.ATEN, input_val=x
    )

    # So we know that our custom converter is being run instead of the standard one
    print("\n\n---------------------------")
    print("Using custom GeLU converter")
    print("---------------------------\n\n")

    x_7 = mul(args[0], 0.5)
    x_8 = mul(args[0], 0.79788456080000003)
    x_9 = mul(args[0], 0.044714999999999998)
    x_10 = mul(x_9, args[0])
    x_11 = add(x_10, 1.0)
    x_12 = mul(x_8, x_11)
    x_13 = tanh(x_12)
    x_14 = add(x_13, 1.0)
    x_15 = mul(x_7, x_14)

    return x_15

使用我们的自定义转换器

我们现在可以重新编译并看到我们的自定义转换器被调用来将GeLU转换为TensorRT。

my_custom_gelu = torch_tensorrt.compile(
    my_mod, arg_inputs=(ex_input,), min_block_size=1
)

print(my_custom_gelu.graph)
print(my_custom_gelu(ex_input))

我们可以验证我们的实现与TensorRT实现对于tanh近似的匹配情况。

print(
    f"tanh approximations are close: {torch.allclose(my_standard_gelu(ex_input), my_custom_gelu(ex_input))}"
)

最后,我们想验证在approximate参数未设置为tanh的情况下,我们的自定义转换器未被使用。

my_mod_erf = GeLU(mode="none")
my_gelu_erf = torch_tensorrt.compile(
    my_mod_erf, arg_inputs=(ex_input,), min_block_size=1
)

注意,我们没有看到自定义转换器的打印语句,这表明它没有被使用。然而,查看图表,我们仍然可以看到创建了一个TensorRT引擎来运行GeLU操作。 在这种情况下,我们的自定义转换器的验证器返回了False,因此转换系统继续使用列表中的下一个转换器,即标准的GeLU转换器,并使用它来转换操作。

print(my_gelu_erf.graph)
print(my_gelu_erf(ex_input))

脚本总运行时间: ( 0 分钟 0.000 秒)

Gallery generated by Sphinx-Gallery