添加模型
本文档描述了如何在TensorRT-LLM中添加一个典型的仅解码器模型。
步骤 1. 编写建模部分
TensorRT-LLM 提供了不同级别的 API:
低级函数,例如,
concat
,add
,和sum
。基础层,例如,
Linear
和LayerNorm
。高级层,例如,
MLP
和Attention
。典型仅解码器模型的基类,例如,
DecoderModelForCausalLM
。
在
tensorrt_llm/models
中创建一个模型目录,例如my_model
。使用TensorRT-LLM的API编写一个
model.py
class MyDecoderLayer(Module):
def __init__(self, config: PretrainedConfig, layer_idx: int):
self.layer_idx = layer_idx
self.config = config
self.input_layernorm = LayerNorm(...)
self.attention = Attention(...)
self.post_layernorm = LayerNorm(...)
self.mlp = MLP(...)
def forward(self, hidden_states, ...):
# decoder layer forward
return hidden_states
class MyModel(Module):
def __init__(self, config: PretrainedConfig):
self.config = config
self.vocab_embedding = Embedding(...)
self.layers = DecoderLayerList(MyDecoderLayer, config)
self.ln_f = LayerNorm(...)
def forward(self, input_ids, ...):
# model forward
return hidden_states
class MyModelForCausalLM(DecoderModelForCausalLM):
def __init__(self, config: PretrainedConfig):
transformer = MyModel(config)
lm_head = ColumnLinear(...)
super().__init__(config, transformer, lm_head)
步骤2. 实现重量转换
源框架的权重需要转换并绑定到新添加的TensorRT-LLM模型。以下是一个转换HuggingFace权重的示例:
class MyModelForCausalLM(DecoderModelForCausalLM):
@classmethod
def from_hugging_face(
cls,
hf_model_dir,
dtype='float16',
mapping: Optional[Mapping] = None) -> MyModelForCausalLM
# create a TensorRT-LLM MyModelForCausalLM model object
# convert HuggingFace checkpoint to TensorRT-LLM expected weights dict
# load the weights to MyModelForCausalLM object
为了方便离线权重转换,在examples/my_model/
目录中开发一个convert_checkpoint.py
脚本是可选的。
步骤 3. 注册新模型
请在tensorrt_llm/models/__init__.py
中注册新的模型类MyModelForCausalLM
。
步骤 4. 验证新模型
最后,让我们验证新模型。典型的命令如下:
cd examples/my_model/
python convert_checkpoint.py --model_dir hf_model_dir --output_dir tllm_ckpt_dir
trtllm-build --checkpoint_dir tllm_ckpt_dir --output_dir tllm_engine_dir
# try the model with a single prompt
python ../run.py --engine_dir tllm_engine_dir --tokenizer_dir hf_model_dir --input_text "Born in north-east France, Soyer trained as a"
# run summarization task
python ../summarize.py --engine_dir tllm_engine_dir --hf_model_dir hf_model_dir --test_trt_llm
参考
建议阅读工作流程[./workflow.md]和检查点[./checkpoint.md]文档以获取更多详细信息。