优化大型语言模型
随着大语言模型(LLMs)成为多个领域的热门研究课题,在云端和边缘设备上部署这些模型已成为一项具有挑战性的任务。本教程将展示如何使用Apache TVM优化大语言模型。我们将采用Hugging Face提供的预训练TinyLlama模型,并将其部署到各类设备上。
整体流程回顾
整体流程包含以下步骤:
构建或导入模型: 构建神经网络模型或从其他框架(如PyTorch、ONNX)导入预训练模型,并创建TVM IRModule,其中包含编译所需的所有信息,包括用于计算图的高级Relax函数和用于张量程序的低级TensorIR函数。
执行可组合优化: 执行一系列优化转换,如图优化、张量程序优化和库调度。
构建与通用部署: 将优化后的模型构建为可部署模块,适配通用运行时环境,并可在不同设备上执行,例如CPU、GPU或其他加速器。
构建模型架构
我们将使用来自Hugging Face的预训练TinyLlama模型。但通常情况下,我们只会加载Hugging Face的预训练权重,而不会直接使用其模型架构。我们需要自行构建模型架构。Apache TVM提供了一套类似PyTorch的API来构建模型架构,我们可以通过这些API来完成模型架构的搭建。
import dataclasses
import enum
import os
from pathlib import Path
from pprint import pprint
from typing import List, Optional
import tvm
from tvm import dlight, relax, te, tir
from tvm.relax import register_pipeline
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op
from tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache, TIRPagedKVCache
from tvm.runtime import ShapeTuple
首先,我们需要定义模型配置。该配置包含模型的关键参数,例如隐藏层大小、中间层大小等。为方便起见,我们专门为TinyLlama模型定义了一个常量配置。
@dataclasses.dataclass
class LlamaConfig:
hidden_size: int = 2048
intermediate_size: int = 5632
num_attention_heads: int = 32
num_hidden_layers: int = 22
rms_norm_eps: float = 1e-05
vocab_size: int = 32000
rope_theta: int = 10000
context_window_size: int = 2048
prefill_chunk_size: int = 2048
num_key_value_heads: int = 4
head_dim: int = 64 # hidden_size // num_attention_heads
dev = tvm.device("cuda", 0)
target = tvm.target.Target.from_device(dev)
接下来,我们定义分页KV缓存的RoPE模式。RoPE模式用于对查询和键张量应用相对位置编码(RoPE)。RoPE模式可设置为NONE、NORMAL或INLINE。如果RoPE模式为NONE,KV缓存将不会对查询和键张量应用RoPE。如果RoPE模式为NORMAL,在将键张量加入缓存前会先对其应用RoPE。如果RoPE模式为INLINE,则会在注意力内核中实时对查询和键张量应用RoPE。
class RopeMode(enum.IntEnum):
"""The RoPE mode of the Paged KV cache.
If it is none, the KV cache will not apply RoPE to q and k.
If it is normal, RoPE will be applied to k before adding k to cache.
Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
"""
NONE = 0
NORMAL = 1
INLINE = 2
其次,我们定义模型架构。模型架构由三部分组成:
嵌入层:嵌入层将输入的标记ID转换为隐藏状态。
解码器层:解码器层是模型的核心。每个解码器层由一个自注意力层和一个前馈网络(FFN)层组成。
输出层:输出层将隐藏状态转换为逻辑值。
首先我们定义FFN层。请注意以下FFN层是经过优化的实现,其中我们将门控和向上投影融合到一个内核中。
FFN层的原始实现是:FFN(x) = down_proj(silu(gate(x)) * up(x))
我们可以将gate和up投影合并到一个内核中以获得更好的性能。
优化后的实现是:
concat_x = gate_up(x)
gate_x, up_x = split(concat_x, 2, axis=-1)
FFN(x) = down_proj(silu(gate_x) * up_x)
class LlamaFFN(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.gate_up_proj = nn.Linear(
in_features=config.hidden_size,
out_features=2 * config.intermediate_size,
bias=False,
)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, x: Tensor):
concat_x1_x2 = self.gate_up_proj(x)
x1, x2 = op.split(concat_x1_x2, 2, axis=-1)
return self.down_proj(op.silu(x1) * x2)
然后我们定义自注意力层。自注意力层由三部分组成:
QKV投影:QKV投影将输入的隐藏状态转换为查询(query)、键(key)和值(value)张量。
注意:注意力层负责计算注意力分数并应用softmax运算。
输出投影:输出投影将注意力输出转换为隐藏状态。
我们对自注意力层的不同部分进行优化:
QKV投影:我们利用QKV投影的水平融合技术,将它们融合到一个核心中。
注意:我们利用注意力机制的水平融合,将QKV投影和
class LlamaAttention(nn.Module): # pylint: disable=too-many-instance-attributes
def __init__(self, config: LlamaConfig):
self.head_dim = config.head_dim
self.num_q_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
# horizontal fusion on QKV projection
self.qkv_proj = nn.Linear(
in_features=config.hidden_size,
out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,
bias=False,
)
self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)
def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads
b, s, _ = hidden_states.shape
# QKV Projection
qkv = self.qkv_proj(hidden_states)
qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
# Attention
output = op.reshape(
paged_kv_cache.attention_with_fused_qkv(
layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5
),
(b, s, h_q * d),
)
# Output Projection
return self.o_proj(output)
最后,我们定义了包含FFN和自注意力层的模型架构。
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
rms_norm_eps = config.rms_norm_eps
self.self_attn = LlamaAttention(config)
self.mlp = LlamaFFN(config)
self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)
def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
hidden_states += self.self_attn(
self.input_layernorm(hidden_states), paged_kv_cache, layer_id
)
hidden_states += self.mlp(self.post_attention_layernorm(hidden_states))
return hidden_states
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig):
assert config.hidden_size % config.num_attention_heads == 0
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
hidden_states = input_embed
for layer_id, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
hidden_states = self.norm(hidden_states)
return hidden_states
class LlamaForCasualLM(nn.Module):
def __init__(self, config: LlamaConfig):
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.num_hidden_layers = config.num_hidden_layers
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.vocab_size = config.vocab_size
self.rope_theta = config.rope_theta
self.dtype = "float32"
def to(self, dtype: Optional[str] = None):
super().to(dtype=dtype)
if dtype is not None:
self.dtype = dtype
def embed(self, input_ids: Tensor):
return self.model.embed_tokens(input_ids)
def get_logits(self, hidden_states: Tensor):
logits = self.lm_head(hidden_states)
if logits.dtype != "float32":
logits = logits.astype("float32")
return logits
def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
def _index(x: te.Tensor): # x[:-1,:]
b, s, d = x.shape
return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index")
hidden_states = self.model(input_embed, paged_kv_cache)
hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
logits = self.get_logits(hidden_states)
return logits, paged_kv_cache
def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
hidden_states = self.model(input_embed, paged_kv_cache)
logits = self.get_logits(hidden_states)
return logits, paged_kv_cache
def create_tir_paged_kv_cache(
self,
max_batch_size: tir.Var,
max_total_seq_len: tir.Var,
prefill_chunk_size: tir.Var,
page_size: tir.Var,
) -> PagedKVCache:
return TIRPagedKVCache(
attn_kind="mha",
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
prefill_chunk_size=prefill_chunk_size,
page_size=page_size,
support_sliding_window=0,
layer_partition=relax.ShapeExpr([0, self.num_hidden_layers]),
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
qk_head_dim=self.head_dim,
v_head_dim=self.head_dim,
mla_original_qk_head_dim=0,
mla_original_v_head_dim=0,
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
rope_scaling={},
rope_ext_factors=relax.PrimValue(0),
rotary_dim=self.head_dim,
dtype=self.dtype,
target=target,
enable_disaggregation=False,
)
def get_default_spec(self):
mod_spec = {
"embed": {
"input_ids": nn.spec.Tensor(["seq_len"], "int32"),
"$": {
"param_mode": "packed",
"effect_mode": "none",
},
},
"prefill": {
"input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype),
"paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
"$": {
"param_mode": "packed",
"effect_mode": "none",
},
},
"decode": {
"input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),
"paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
"$": {
"param_mode": "packed",
"effect_mode": "none",
},
},
"create_tir_paged_kv_cache": {
"max_batch_size": int,
"max_total_seq_len": int,
"prefill_chunk_size": int,
"page_size": int,
"$": {
"param_mode": "none",
"effect_mode": "none",
},
},
}
return nn.spec.ModuleSpec.from_raw(mod_spec, self)
将模型导出为Relax IRModule
在定义模型架构后,我们可以将模型导出到Relax IRModule。 为便于演示,我们仅展示模型架构和参数的部分内容。
model_config = LlamaConfig()
model = LlamaForCasualLM(model_config)
model.to("float16")
mod, named_params = model.export_tvm(spec=model.get_default_spec())
prefill_str = mod["prefill"].script()
print(*prefill_str.split("\n")[3:20], sep="\n") # Only show the first 10 lines for demonstration
print(" ...")
print("\nParameters:")
pprint(named_params[:5]) # Only show the first 5 parameters for demonstration
@R.function
def prefill(input_embed: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32000, 2048), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 2048), dtype="float16"), R.Tensor((2048, 2048), dtype="float16"), R.Tensor((11264, 2048), dtype="float16"), R.Tensor((2048, 5632), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((32000, 2048), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 32000), dtype="float32"), R.Object):
seq_len = T.int64()
R.func_attr({"num_input": 2})
with R.dataflow():
model_embed_tokens_weight1: R.Tensor((32000, 2048), dtype="float16") = packed_params[0]
model_layers_0_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[1]
model_layers_0_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[2]
model_layers_0_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[3]
model_layers_0_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[4]
model_layers_0_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[5]
model_layers_0_post_attention_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[6]
model_layers_1_self_attn_qkv_proj_weight1: R.Tensor((2560, 2048), dtype="float16") = packed_params[7]
model_layers_1_self_attn_o_proj_weight1: R.Tensor((2048, 2048), dtype="float16") = packed_params[8]
model_layers_1_mlp_gate_up_proj_weight1: R.Tensor((11264, 2048), dtype="float16") = packed_params[9]
model_layers_1_mlp_down_proj_weight1: R.Tensor((2048, 5632), dtype="float16") = packed_params[10]
model_layers_1_input_layernorm_weight1: R.Tensor((2048,), dtype="float16") = packed_params[11]
...
Parameters:
[('model.embed_tokens.weight', Tensor([32000, 2048], "float16")),
('model.layers.0.self_attn.qkv_proj.weight', Tensor([2560, 2048], "float16")),
('model.layers.0.self_attn.o_proj.weight', Tensor([2048, 2048], "float16")),
('model.layers.0.mlp.gate_up_proj.weight', Tensor([11264, 2048], "float16")),
('model.layers.0.mlp.down_proj.weight', Tensor([2048, 5632], "float16"))]
定义优化流程
我们定义了一系列优化步骤来优化模型。该优化流程是专门为LLMs设计的。
@register_pipeline("opt_llm")
def _pipeline( # pylint: disable=too-many-arguments
ext_mods: List[nn.ExternModule] = None,
):
ext_mods = ext_mods or []
@tvm.transform.module_pass(opt_level=0)
def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
seq = tvm.transform.Sequential(
[
# Phase 1. Passes on high-level operator graph
# We can enable cublas for further optimization
relax.transform.FuseTransposeMatmul(),
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
relax.transform.LegalizeOps(),
relax.transform.AnnotateTIROpPattern(),
relax.transform.FoldConstant(),
relax.transform.FuseOps(),
relax.transform.FuseTIR(),
# Phase 3. Passes on TIR
relax.transform.DeadCodeElimination(),
# Phase 4. Low-level Optimizations
dlight.ApplyDefaultSchedule(
dlight.gpu.Matmul(),
dlight.gpu.GEMV(),
dlight.gpu.Reduction(),
dlight.gpu.GeneralReduction(),
dlight.gpu.Fallback(),
),
# Phase 5. Lowering to VM bytecode
relax.transform.RewriteDataflowReshape(),
relax.transform.ToNonDataflow(),
relax.transform.RemovePurityChecking(),
relax.transform.CallTIRRewrite(),
relax.transform.StaticPlanBlockMemory(),
relax.transform.RewriteCUDAGraph(),
relax.transform.LowerAllocTensor(),
relax.transform.KillAfterLastUse(),
relax.transform.LowerRuntimeBuiltin(),
relax.transform.VMShapeLower(),
relax.transform.AttachGlobalSymbol(),
relax.transform.AttachExternModules(ext_mods),
]
)
mod = seq(mod)
return mod
return _pipeline
with target:
ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline("opt_llm"))
vm = relax.VirtualMachine(ex, dev)
准备模型权重
我们从Hugging Face加载预训练权重并准备模型权重。 预训练权重以Hugging Face格式存储。我们需要加载这些权重 并准备模型参数。
注意
请注意,本教程中不会执行以下代码,因为在CI环境中没有预训练权重可用。
IS_IN_CI = os.getenv("CI", "") == "true"
HF_WEIGHT_PATH = None
# HF_WEIGHT_PATH = Path("/path/to/TinyLlama-1.1B-Chat-v1.0/")
if not IS_IN_CI:
import numpy as np
import safetensors.torch
import torch
if HF_WEIGHT_PATH is None or not HF_WEIGHT_PATH.exists():
raise ValueError("Please set the HF_WEIGHT_PATH to the path of the pre-trained weights.")
# Torch format weights
param_dict = safetensors.torch.load_file(HF_WEIGHT_PATH / "model.safetensors", device="cpu")
# Numpy format weights
param_dict = {
k: v.half().numpy() if v.dtype == torch.bfloat16 else v.numpy()
for k, v in param_dict.items()
}
named_params = dict(named_params)
for i in range(model_config.num_hidden_layers):
# Add QKV in self attention
attn = f"model.layers.{i}.self_attn"
param_dict[f"{attn}.qkv_proj.weight"] = np.concatenate(
[
param_dict.pop(f"{attn}.q_proj.weight"), # Pop the old parameters to save memory
param_dict.pop(f"{attn}.k_proj.weight"),
param_dict.pop(f"{attn}.v_proj.weight"),
],
axis=0,
)
# Add gates in MLP
mlp = f"model.layers.{i}.mlp"
param_dict[f"{mlp}.gate_up_proj.weight"] = np.concatenate(
[
param_dict.pop(f"{mlp}.gate_proj.weight"),
param_dict.pop(f"{mlp}.up_proj.weight"),
],
axis=0,
)
# Convert params into ndarray
params = [
tvm.nd.array(param_dict[k].astype("float16"), device=dev) for k in named_params.keys()
]
部署已编译的模型
在模型和权重准备就绪后,我们可以将编译好的模型部署到目标设备上。 语言模型推理包括两个步骤:预填充(prefill)和解码(decode)。预填充步骤用于 处理输入标记并存储KVCache。解码步骤用于生成标记,直到生成结束标记。
分词处理
第一步是将输入提示进行标记化,并将这些标记嵌入到隐藏状态中。标记化和嵌入过程与原模型相同。我们使用HF标记器对输入提示进行标记化,并将标记嵌入到隐藏状态中。需要注意的是,不同模型需要不同的标记化和提示格式,请参考模型文档以获取正确的标记化和提示格式。
if not IS_IN_CI:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_WEIGHT_PATH)
messages = [
{"role": "user", "content": "What's your name?"},
]
prompt = tokenizer.apply_chat_template(messages)
input_len = len(prompt)
# Load prompt tokens into TVM ndarray on the target device
tokens = tvm.nd.array(np.array(prompt).astype("int32"), device=dev)
创建KVCache
在开始推理之前,我们需要创建KVCache。KVCache用于存储注意力层的键和值张量。Apache TVM提供了PagedKVCache来存储这些键和值张量。我们使用指定的参数创建PagedKVCache。
嵌入
下一步是将令牌嵌入到隐藏状态中。我们使用Relax IRModule中编译的embed函数将令牌嵌入到隐藏状态中。
nd_view_func = tvm.get_global_func("vm.builtin.reshape")
def embed(tokens, params):
_embed = vm["embed"](tokens, params)
# Reshape hidden from [seq_len, hidden_size] to [1, seq_len, hidden_size]
_embed = nd_view_func(_embed, ShapeTuple([1, _embed.shape[0], _embed.shape[1]]))
return _embed
预填充
在执行前向传递之前,我们首先获取一些辅助函数进行准备工作。
add_sequence_func = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
begin_forward_func = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
end_forward_func = tvm.get_global_func("vm.builtin.kv_state_end_forward")
当我们创建一个新序列时,需要调用add_sequence_func来初始化请求。此外,还需要调用begin_forward_func来开始前向传递,以及调用end_forward_func来结束前向传递。
现在我们获得了预填充步骤输出的logits。这些logits将通过采样用于生成token。让我们从这些logits中采样token。
在本教程中,我们简化了采样过程,直接选择概率最高的token。实际应用中,应该根据概率分布进行采样。此外,为了使教程简洁明了,我们在CPU上执行采样过程。
def sample_token(logits):
logits_np = logits.numpy()
return np.argmax(logits_np)
if not IS_IN_CI:
last_token = sample_token(logits)
output_tokens = [last_token]
解码
在预填充步骤之后,我们可以开始解码步骤。解码步骤用于持续生成标记,直到生成结束标记为止。我们使用Relax IRModule中编译的decode函数来生成这些标记。
if not IS_IN_CI:
print("The generated token:")
while last_token != tokenizer.eos_token_id:
tokens = tvm.nd.array(np.array([last_token]).astype("int32"), device=dev)
hidden_states = embed(tokens, params)
begin_forward_func(kv_cache, ShapeTuple([seq_id]), ShapeTuple([1]))
logits, kv_cache = vm["decode"](hidden_states, kv_cache, params)
end_forward_func(kv_cache)
last_token = sample_token(logits)
output_tokens.append(last_token)
print(tokenizer.decode(output_tokens))
脚本总运行时间: (1 分钟 2.556 秒)