图重写模块
TensorRT-LLM 使用声明式方法来定义神经网络,并包含优化底层图的技术。它提供了一个类似于 PyTorch 的 Module 的包装器。当用户调用 forward 方法时,层会被降级为 TensorRT 的 ILayer,并成为 INetworkDefinition 的一部分。图重写(GW)模块可用于在 ILayer/INetworkDefinition 级别操作网络。
何时使用图重写?
对于网络操作,TensorRT-LLM 提供了两种选择:
模块重写: 此方法在触发
forward方法(即创建TensorRT图)之前修改Module实例的成员。它在网络表示的最高级别上工作,并有助于修改操作序列(如修改SmoothQuant的GEMM +激活),图重写: 图重写在触发
forward方法后操作TensorRT的INetworkDefinition。它在更细粒度的ILayer级别上操作,并且可以跨多个模块实例改变结构。它通常用于层融合。
图重写(GW)在以下条件下理想使用:
当只有
ILayer/INetworkDefinition可用时,当模块重写会导致嵌套控制流或功能分散时。
图重写API
提供了几个核心API用于图重写:
FLayerInfo 用于检索功能的高级信息
对于所有位于functional.py中的层,一旦降低到INetworkDefinition,原始输入信息就会丢失,特别是对于在Python世界中不透明的TensorRT插件。FLayerInfo保存了它们的原始信息,作为一个包含Tensor、Python属性等输入的高级签名。有一个网络范围的单例称为FLayerInfoMemo,用于将每个ILayer映射到其对应的FLayerInfo。
对于 FLayerInfo:
FLayerInfo.replace_input_with: 用另一个张量替换某些输入张量,FLayerInfo.replace_output_uses_with: 将原始输出张量的使用重定向到一组新的张量。
对于 FLayerInfoMemo:
FLayerInfoMemo.instance(): 获取单例实例,FLayerInfoMemo.get: 获取与ILayer对应的FLayerInfo。
FLayerInfo 在GW期间与实际 ILayer 保持一致,因此可以安全使用。
模式和模式管理器
有两种模式:
PatternRewriter: 用于定义重写模式,实际上会改变网络。match: 匹配模式;如果匹配到图层,则返回true,rewrite: 操作一个层,match_and_rewrite: 结合了match和rewrite,用于需要从match传递到rewrite的复杂状态。
PatternAnalyzer: 用于定义分析模式,该模式从网络中收集信息。match: 匹配模式,analyze: 对图层列表执行分析。
有两个管理器用于管理多个PatternRewriter或PatternAnalyzer:
RewritePatternManager:add: 添加一个模式及其标签和收益;收益指定其特权,get: 通过标签获取模式,rewrite: 应用包含的重写模式到网络。
AnalysisPatternManager:add: 添加一个模式及其标签和收益;收益指定了其特权,get: 通过标签获取模式,analyze: 将包含的分析模式应用于网络。
@record_signature 用于装饰需要 FLayerInfo 的功能
@record_signature 装饰器用于记录功能层的 FLayerInfo。虽然 FLayerInfo 在 GW 分析或重写某些功能层时至关重要,但它是以“按需添加”的方式使用的。如果您正在添加 GW 模式,请确保该功能层需要 @record_signature 装饰器。
经典工作流程
有特定的例程用于定义GW模式。让我们从一个简单的例子开始:将求和层替换为减法层,这也可以在test_graph_rewriting.py文件中找到。
class NaivePatternRewriter_ReplaceAddWithSub(PatternRewriter):
def __init__(self):
super().__init__('replace_add_with_sub',
root_layer={trt.LayerType.ELEMENTWISE},
separate_match_rewrite=True)
def match(self, layer: Layer):
# The rewriter will stop at the first matched layer, and then the Rewriter will enter the rewrite() to do the rewriting.
return layer.as_layer().op == trt.ElementWiseOperation.SUM
def rewrite(self, layer: Layer) -> None:
# The layer here should be an Elementwise_SUM layer.
with net_guard(layer.network):
# There are several stages to replace some subgraph with another subgraph:
# Stage 1: Get the input tensors and output tensors of the subgraph to replace.
# - For Elementwise_SUM, there are two inputs and one output.
a, b = layer.get_inputs(0, 1)
o = layer.get_outputs(0)[0]
# Stage 2: Create a new subgraph that takes the old one's inputs.
# - Here we insert an Elementwise_SUB layer, and 'c' is the output.
c = a - b
# Stage 3: Redirect all the layers depending on the outputs of the old subgraph to the new subgraph's.
# - After this, the SUM becomes dangling and will be pruned by TensorRT when building the engine.
# - Note that there is no API in TensorRT python to remove a layer explicitly; `replace_all_uses_with` is the only way to "remove" a layer.
o.replace_all_uses_with(c)
# Stage 4: Mark all the layers in the old subgraph as removed.
# - This helps the PatternRewriter to skip the removed layers.
layer.mark_as_removed()
在这个例子中,我们处理的是ILayer而不是插件,因此FLayerInfo是不必要的。如rewrite方法所示,几乎所有重写模式都共享四个阶段。
请注意,在GW中,我们从不直接重写一个层。相反,我们分两步进行:首先,创建另一个具有相同输入的层,并剥夺所有用户对原始输出的访问,将他们重定向到新层的输出。这样,旧层将在引擎构建阶段自动悬空并被TensorRT修剪。这是TensorRT的一个限制,因为在Python中没有类似删除层的API。
在第二阶段,我们依赖于网络构建阶段常用的操作符和层。理想情况下,您可以在GW期间用任何网络结构替换它们。
关于FLayerInfo的使用,让我们重写gpt_attention以启用remove-padding功能。gpt_attention实际上是
一个TensorRT插件,因此我们需要FLayerInfo来保存原始的Tensor-wise输入,以帮助创建新的gpt_attention层。
class GPTAttentionPluginRemovePaddingRewritePass(PatternRewriter):
def __init__(self):
super().__init__('gpt_attention_plugin_remove_padding',
root_layer={trt.LayerType.PLUGIN_V2})
def match_and_rewrite(self, layer: Layer) -> bool:
if layer.as_layer().type != trt.LayerType.PLUGIN_V2 or \
layer.as_layer().plugin.plugin_namespace != 'tensorrt_llm' or \
layer.as_layer().plugin.plugin_type != 'GPTAttention':
return False
# Retrieve the FLayerInfo
flayer = FLayerInfoMemo.instance().get(layer.name)
assert flayer
# Although the layer is a plugin, which is a black box, we get some high-level input information from the FLayerInfo.
tensor_input: Tensor = flayer.get_input('qkv')
if tensor_input.shape[0] == 1: # Already in remove-padding mode
return False
# Some information could be passed in from external
assert self.args is not None, "args should be passed in from RewritePatternManager.rewrite()"
batch_size, in_len, hidden_size = self.args['batch_size'], self.args['in_len'], self.args['hidden_size']
with net_guard(layer.network):
new_inputs = flayer.clone_inputs()
# Step 1: Create new inputs and replace the original arglist.
input = Tensor(
name='qkv',
dtype=trt.float16,
shape=(1, batch_size * in_len, hidden_size),
)
new_inputs['qkv'] = input
# Step 2: Create a new plugin instance.
new_outs = gpt_attention(**new_inputs)
# Step 3: Deprive all the users of the old plugin instance.
flayer.replace_outputs_uses_with(layer.network, new_outs)
# Step 4: Remove the old plugin instance.
layer.mark_as_removed()
return True
这与第一个示例非常相似,重点是FLayerInfo部分。通过下面的代码,我们可以获取该层的原始输入,从而能够更改与移除填充相关的输入,并创建一个新层来替换它。
flayer = FLayerInfoMemo.instance().get(layer.name)
assert flayer
new_inputs = flayer.clone_inputs()
# Step 1: Create new inputs and replace the original arglist.
input = Tensor(
name='tensor',
dtype=trt.float16,
shape=(1, batch_size * in_len, hidden_size),
)
new_inputs['tensor'] = input
# Step 2: Create a new plugin instance.
new_outs = gpt_attention(**new_inputs)
有关实际示例,请参考graph_rewriting.py中的FuseAttentionWithBiasPass。