Shortcuts

模型类型

class torchtune.training.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]

ModelType 被检查点用于区分不同的模型架构。

如果您正在添加一个与仓库中已有模型格式不同的新模型,您可以添加一个新的ModelType来控制该模型特有的权重转换逻辑。

Variables:
  • GEMMA (str) – Gemma 模型系列。参见 gemma()

  • GEMMA2 (str) – Gemma 2 模型系列。参见 gemma2()

  • LLAMA2 (str) – Llama2 模型系列。参见 llama2()

  • LLAMA3 (str) – Llama3 模型系列。参见 llama3()

  • LLAMA3_2 (str) – Llama3.2 模型系列。参见 llama3_2()

  • LLAMA3_VISION (str) – LLama3 视觉模型系列。参见 llama3_2_vision_decoder()

  • MISTRAL (str) – Mistral 模型系列。参见 mistral()

  • PHI3_MINI (str) – Phi-3 模型系列。参见 phi3()

  • 奖励 (str) – 一个带有分类头的Llama2、Llama3或Mistral模型,用于奖励建模,投影到单一类别。 参见 mistral_reward_7b()llama2_reward_7b()

  • QWEN2 (str) – Qwen2 模型系列。参见 qwen2()

  • CLIP_TEXT (str) – CLIP 文本编码器。参见 clip_text_encoder_large()

示例

>>> # Usage in a checkpointer class
>>> def load_checkpoint(self, ...):
>>>     ...
>>>     if self._model_type == MY_NEW_MODEL:
>>>         state_dict = my_custom_state_dict_mapping(state_dict)