图重写模块

TensorRT-LLM 使用声明式方法来定义神经网络,并包含优化底层图的技术。它提供了一个类似于 PyTorch 的 Module 的包装器。当用户调用 forward 方法时,层会被降级为 TensorRT 的 ILayer,并成为 INetworkDefinition 的一部分。图重写(GW)模块可用于在 ILayer/INetworkDefinition 级别操作网络。

何时使用图重写?

对于网络操作,TensorRT-LLM 提供了两种选择:

  1. 模块重写: 此方法在触发forward方法(即创建TensorRT图)之前修改Module实例的成员。它在网络表示的最高级别上工作,并有助于修改操作序列(如修改SmoothQuant的GEMM +激活),

  2. 图重写: 图重写在触发forward方法后操作TensorRT的INetworkDefinition。它在更细粒度的ILayer级别上操作,并且可以跨多个模块实例改变结构。它通常用于层融合。

图重写(GW)在以下条件下理想使用:

  1. 当只有ILayer/INetworkDefinition可用时,

  2. 当模块重写会导致嵌套控制流或功能分散时。

图重写API

提供了几个核心API用于图重写:

FLayerInfo 用于检索功能的高级信息

对于所有位于functional.py中的层,一旦降低到INetworkDefinition,原始输入信息就会丢失,特别是对于在Python世界中不透明的TensorRT插件。FLayerInfo保存了它们的原始信息,作为一个包含Tensor、Python属性等输入的高级签名。有一个网络范围的单例称为FLayerInfoMemo,用于将每个ILayer映射到其对应的FLayerInfo

对于 FLayerInfo:

  • FLayerInfo.replace_input_with: 用另一个张量替换某些输入张量,

  • FLayerInfo.replace_output_uses_with: 将原始输出张量的使用重定向到一组新的张量。

对于 FLayerInfoMemo:

  • FLayerInfoMemo.instance(): 获取单例实例,

  • FLayerInfoMemo.get: 获取与ILayer对应的FLayerInfo

FLayerInfo 在GW期间与实际 ILayer 保持一致,因此可以安全使用。

模式和模式管理器

有两种模式:

  • PatternRewriter: 用于定义重写模式,实际上会改变网络。

    • match: 匹配模式;如果匹配到图层,则返回true,

    • rewrite: 操作一个层,

    • match_and_rewrite: 结合了matchrewrite,用于需要从match传递到rewrite的复杂状态。

  • PatternAnalyzer: 用于定义分析模式,该模式从网络中收集信息。

    • match: 匹配模式,

    • analyze: 对图层列表执行分析。

有两个管理器用于管理多个PatternRewriterPatternAnalyzer

  • RewritePatternManager:

    • add: 添加一个模式及其标签和收益;收益指定其特权,

    • get: 通过标签获取模式,

    • rewrite: 应用包含的重写模式到网络。

  • AnalysisPatternManager:

    • add: 添加一个模式及其标签和收益;收益指定了其特权,

    • get: 通过标签获取模式,

    • analyze: 将包含的分析模式应用于网络。

@record_signature 用于装饰需要 FLayerInfo 的功能

@record_signature 装饰器用于记录功能层的 FLayerInfo。虽然 FLayerInfo 在 GW 分析或重写某些功能层时至关重要,但它是以“按需添加”的方式使用的。如果您正在添加 GW 模式,请确保该功能层需要 @record_signature 装饰器。

经典工作流程

有特定的例程用于定义GW模式。让我们从一个简单的例子开始:将求和层替换为减法层,这也可以在test_graph_rewriting.py文件中找到。

class NaivePatternRewriter_ReplaceAddWithSub(PatternRewriter):

    def __init__(self):
        super().__init__('replace_add_with_sub',
                         root_layer={trt.LayerType.ELEMENTWISE},
                         separate_match_rewrite=True)

    def match(self, layer: Layer):
        # The rewriter will stop at the first matched layer, and then the Rewriter will enter the rewrite() to do the rewriting.
        return layer.as_layer().op == trt.ElementWiseOperation.SUM

    def rewrite(self, layer: Layer) -> None:
        # The layer here should be an Elementwise_SUM layer.
        with net_guard(layer.network):
            # There are several stages to replace some subgraph with another subgraph:

            # Stage 1: Get the input tensors and output tensors of the subgraph to replace.
            # - For Elementwise_SUM, there are two inputs and one output.
            a, b = layer.get_inputs(0, 1)
            o = layer.get_outputs(0)[0]

            # Stage 2: Create a new subgraph that takes the old one's inputs.
            # - Here we insert an Elementwise_SUB layer, and 'c' is the output.
            c = a - b

            # Stage 3: Redirect all the layers depending on the outputs of the old subgraph to the new subgraph's.
            # - After this, the SUM becomes dangling and will be pruned by TensorRT when building the engine.
            # - Note that there is no API in TensorRT python to remove a layer explicitly; `replace_all_uses_with` is the only way to "remove" a layer.
            o.replace_all_uses_with(c)

            # Stage 4: Mark all the layers in the old subgraph as removed.
            # - This helps the PatternRewriter to skip the removed layers.
            layer.mark_as_removed()

在这个例子中,我们处理的是ILayer而不是插件,因此FLayerInfo是不必要的。如rewrite方法所示,几乎所有重写模式都共享四个阶段。

请注意,在GW中,我们从不直接重写一个层。相反,我们分两步进行:首先,创建另一个具有相同输入的层,并剥夺所有用户对原始输出的访问,将他们重定向到新层的输出。这样,旧层将在引擎构建阶段自动悬空并被TensorRT修剪。这是TensorRT的一个限制,因为在Python中没有类似删除层的API。

在第二阶段,我们依赖于网络构建阶段常用的操作符和层。理想情况下,您可以在GW期间用任何网络结构替换它们。

关于FLayerInfo的使用,让我们重写gpt_attention以启用remove-padding功能。gpt_attention实际上是

一个TensorRT插件,因此我们需要FLayerInfo来保存原始的Tensor-wise输入,以帮助创建新的gpt_attention层。

class GPTAttentionPluginRemovePaddingRewritePass(PatternRewriter):

    def __init__(self):
        super().__init__('gpt_attention_plugin_remove_padding',
                         root_layer={trt.LayerType.PLUGIN_V2})

    def match_and_rewrite(self, layer: Layer) -> bool:
        if layer.as_layer().type != trt.LayerType.PLUGIN_V2 or \
                layer.as_layer().plugin.plugin_namespace != 'tensorrt_llm' or \
                layer.as_layer().plugin.plugin_type != 'GPTAttention':
            return False

        # Retrieve the FLayerInfo
        flayer = FLayerInfoMemo.instance().get(layer.name)
        assert flayer
        # Although the layer is a plugin, which is a black box, we get some high-level input information from the FLayerInfo.
        tensor_input: Tensor = flayer.get_input('qkv')
        if tensor_input.shape[0] == 1:  # Already in remove-padding mode
            return False

        # Some information could be passed in from external
        assert self.args is not None, "args should be passed in from RewritePatternManager.rewrite()"
        batch_size, in_len, hidden_size = self.args['batch_size'], self.args['in_len'], self.args['hidden_size']

        with net_guard(layer.network):
            new_inputs = flayer.clone_inputs()

            # Step 1: Create new inputs and replace the original arglist.
            input = Tensor(
                name='qkv',
                dtype=trt.float16,
                shape=(1, batch_size * in_len, hidden_size),
            )
            new_inputs['qkv'] = input

            # Step 2: Create a new plugin instance.
            new_outs = gpt_attention(**new_inputs)

            # Step 3: Deprive all the users of the old plugin instance.
            flayer.replace_outputs_uses_with(layer.network, new_outs)

            # Step 4: Remove the old plugin instance.
            layer.mark_as_removed()

        return True

这与第一个示例非常相似,重点是FLayerInfo部分。通过下面的代码,我们可以获取该层的原始输入,从而能够更改与移除填充相关的输入,并创建一个新层来替换它。

flayer = FLayerInfoMemo.instance().get(layer.name)
assert flayer
new_inputs = flayer.clone_inputs()

# Step 1: Create new inputs and replace the original arglist.
input = Tensor(
    name='tensor',
    dtype=trt.float16,
    shape=(1, batch_size * in_len, hidden_size),
)
new_inputs['tensor'] = input

# Step 2: Create a new plugin instance.
new_outs = gpt_attention(**new_inputs)

有关实际示例,请参考graph_rewriting.py中的FuseAttentionWithBiasPass