数值精度

本文档描述了在TensorRT-LLM中实现的不同量化方法,并包含不同模型的支持矩阵。

FP32、FP16 和 BF16

在TensorRT-LLM中实现的不同模型使用32位IEEE浮点数(FP32)。当检查点可用时,模型还支持16位IEEE浮点数(FP16)和16位Bfloat16(BF16),如这里所述。

量化和反量化 (Q/DQ)

给定一个浮点数 x 和一个浮点数缩放因子 s, TensorRT-LLM 实现 INT8 量化的方式如下:

q = int8.satfinite(x * s)

给定一个INT8数字 q 和一个浮点缩放因子 s,TensorRT-LLM 将INT8反量化为浮点(FP)类型的实现如下:

x = static_cast<FP>(q) * s

给定一个形状为M x N的矩阵(2D张量)(M行和N列),其中 M是标记的数量,N是通道的数量。TensorRT-LLM有以下三种模式来量化和反量化张量的元素:

  • 每张量:它为所有元素使用单一的缩放因子,

  • 每个令牌:它为每个令牌使用不同的缩放因子。在这种情况下,有M个缩放因子,

  • 每通道:它为每个通道使用不同的缩放因子。在这种情况下,有N个缩放因子。

请注意,每个令牌和每个通道的缩放模式可以一起使用(即它们不是互斥的)。

在伪代码中,量化可以针对三种不同的模式实现如下:

# Per-tensor scaling.
for mi in range(M):
    for ni in range(N):
        q[mi][ni] = int8.satfinite(x[mi][ni] * s)

# Per-token scaling.
for mi in range(M):
    for ni in range(N):
        q[mi][ni] = int8.satfinite(x[mi][ni] * s[mi])

# Per-channel scaling.
for mi in range(M):
    for ni in range(N):
        q[mi][ni] = int8.satfinite(x[mi][ni] * s[ni])

INT8 平滑量化 (W8A8)

SmoothQuant技术是在 https://arxiv.org/abs/2211.10438中引入的。它是一种 在保持网络(在下游任务上)准确性的同时,使用INT8进行激活和权重推理的方法。

正如研究论文中所解释的,必须对模型的权重进行预处理。TensorRT-LLM 包含了使用 SmoothQuant 方法准备模型运行的脚本。

如何为GPT、GPT-J和LLaMA启用SmoothQuant的示例可以在该版本的examples/quantization文件夹中找到。

INT4 和 INT8 仅权重(W4A16 和 W8A16)

INT4 和 INT8 仅权重技术包括量化模型的权重,并在线性层(矩阵乘法)中即时反量化这些权重。激活值使用浮点值(FP16 或 BF16)进行编码。

要使用INT4/INT8仅权重方法,用户必须确定用于量化和反量化模型权重的缩放因子。

此版本包括GPTLLaMA的示例。

GPTQ 和 AWQ (W4A16)

GPTQ和AWQ技术在 https://arxiv.org/abs/2210.17323https://arxiv.org/abs/2306.00978 中分别进行了介绍。TensorRT-LLM支持在线性层中使用每组的缩放因子和零偏移来实现GPTQ和AWQ方法。详情请参见 WeightOnlyGroupwiseQuantMatmulPlugin 插件和相应的 weight_only_groupwise_quant_matmul Python函数。

此版本包括将GPTQ应用于GPT-NeoXLLaMA-v2的示例,以及使用AWQ与GPT-J的示例。这些示例是实验性的实现,并可能在未来的版本中有所发展。

FP8 (Hopper)

此版本的 TensorRT-LLM 包含了 GPT-NeMo、GPT-J 和 LLaMA 的 FP8 实现。这些示例可以在 examples/quantization 中找到。

支持矩阵

此版本的 TensorRT-LLM 包含以下示例:

模型

FP32

FP16

BF16

FP8

W8A8 SQ

W8A16

W4A16

W4A16 AWQ

W4A16 GPTQ

百川

Y

Y

Y

Y

Y

Y

Y

Y

Y

BERT

Y

Y

Y

.

.

.

.

.

.

BLIP-2

Y

Y

Y

.

.

.

.

.

.

BLOOM

Y

Y

Y

Y

Y

Y

Y

.

.

ChatGLM

Y

Y

Y

.

.

.

.

.

.

ChatGLM-v2

Y

Y

Y

.

.

.

.

.

.

ChatGLM-v3

Y

Y

Y

.

.

.

.

.

.

DBRX

Y

Y

Y

.

.

Y

Y

.

.

猎鹰

Y

Y

Y

Y

.

Y

Y

Y

.

Flan-T5

Y

Y

Y

.

.

.

.

.

.

Gemma

Y

Y

Y

Y

Y

Y

Y

Y

.

GPT

Y

Y

Y

Y

Y

Y

Y

.

.

GPT-J

Y

Y

Y

Y

Y

Y

Y

Y

.

GPT-NeMo

Y

Y

Y

.

.

.

.

.

.

GPT-NeoX

Y

Y

Y

.

.

.

.

.

Y

InternLM

Y

Y

Y

.

Y

Y

Y

.

.

InternLM2

Y

Y

Y

.

.

.

.

.

.

LLaMA

Y

Y

Y

Y

Y

Y

Y

Y

Y

LLaMA-v2

Y

Y

Y

Y

Y

Y

Y

Y

Y

曼巴

Y

Y

Y

.

.

.

.

.

.

Mistral

Y

Y

Y

Y

Y

Y

Y

Y

.

Mixtral

Y

Y

Y

Y

.

Y

Y

.

.

MPT

Y

Y

Y

Y

Y

Y

Y

Y

.

OPT

Y

Y

Y

.

.

.

.

.

.

Phi

Y

Y

Y

.

.

.

.

.

.

Qwen

Y

Y

Y

.

Y

Y

Y

Y

Y

循环宝石

Y

Y

Y

Y

Y

.

.

Y

.

Replit 代码

Y

Y

Y

.

.

.

.

.

.

圣诞编码器

Y

Y

Y

.

.

Y

Y

.

.

天工

Y

Y

Y

.

.

.

.

.

.

StarCoder1

Y

Y

Y

.

.

Y

Y

.

.

StarCoder2

Y

Y

Y

Y

.

Y

Y

.

.

T5

Y

Y

Y

.

.

.

.

.

.

耳语

Y

Y

Y

.

.

Y

Y

.

.

BLIP2-OPT

Y

Y

Y

.

.

.

.

.

.

BLIP2-T5

Y

Y

Y

.

.

.

.

.

.

LLaVA

Y

Y

Y

Y

Y

Y

Y

Y

Y

VILA

Y

Y

Y

Y

Y

Y

Y

Y

Y

牛轧糖

Y

Y

Y

.

.

.

.

.

.

注意:多模态模型(BLIP2-OPT/BLIP2-T5/LLaVA/VILA/Nougat)的视觉组件默认使用FP16。 语言组件决定了给定多模态模型支持哪些量化方法。

技术细节:QuantMode 标志

量化方法由 QuantMode 标志控制。不同的字段 是:

  • INT4_WEIGHTS,权重被量化为4位(W4A*),

  • INT8_WEIGHTS,权重被量化为8位(W8A*),

  • ACTIVATIONS,激活值被量化为8位(W*A8),

  • PER_CHANNEL,缩放因子是按通道定义的,

  • PER_TOKEN,缩放因子是按每个令牌定义的,

  • PER_GROUP,缩放因子是按组定义的。

有三个额外的标志用于控制 TensorRT-LLM:

  • INT8_KV_CACHE,K/V缓存使用8位整数存储K和V,

  • FP8_KV_CACHE,K/V缓存使用8位浮点数存储K和V,

  • FP8_QDQ, TensorRT-LLM 依赖于 TensorRT 中 Q/DQ 节点的自动融合。