FullModelHFCheckpointer¶
- class torchtune.training.FullModelHFCheckpointer(checkpoint_dir: str, checkpoint_files: Union[List[str], Dict[str, str]], model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False, safe_serialization: bool = True)[source]¶
读取和写入HF格式检查点的检查点器。对于LoRA模型,这包括以可以通过例如
from_pretrained加载到PEFT的格式保存检查点。示例包括来自meta-llama仓库的Llama-2-7b-hf模型(https://huggingface.co/meta-llama/Llama-2-7b-hf)。注意
HF检查点名称通常按ID排序(例如:0001_of_0003, 0002_of_0003等)。为了确保我们按正确的顺序读取文件,我们在读取之前对检查点文件名进行排序。
注意
将检查点转换为HF格式或从HF格式转换需要访问模型参数,这些参数直接从
config.json文件中读取。这有助于确保我们正确加载权重,或在HF检查点文件与torchtune的模型实现之间存在差异时出错。- Parameters:
checkpoint_dir (str) – 包含检查点文件的目录
checkpoint_files (Union[List[str], Dict[str, str]]) – 要加载的检查点文件列表或包含键 ["filename_format", "max_filename"] 的字典。由于检查点器会按文件ID排序,因此此列表中的顺序无关紧要。
model_type (str) – 正在加载检查点的模型的模型类型,例如 LLAMA3。
output_dir (str) – 保存检查点文件的目录
adapter_checkpoint (Optional[str]) – 适配器权重的路径。如果为None, 并且resume_from_checkpoint=True,则在output_dir/epoch_{largest_epoch}中查找adapter_model.pt。 默认值为None。
recipe_checkpoint (可选[str]) – 配方状态检查点文件的路径。如果为 None, 并且 resume_from_checkpoint=True,则在 output_dir/RECIPE_STATE_DIRNAME 中查找 recipe_state.pt。 默认值为 None。
resume_from_checkpoint (bool) – 如果为True,检查点将加载额外的检查点文件以从先前的运行中恢复训练。默认值为False
safe_serialization (bool) – 如果为True,检查点器将使用safetensors保存检查点文件。 默认值为True。
- load_checkpoint() Dict[str, Any][source]¶
从文件加载HF检查点。
所有检查点文件中的键和权重被合并到一个单一的state_dict中。 我们在weight_map中保留了“state_dict键” <-> “检查点文件”的映射,以便我们可以在
save_checkpoint中正确地写入状态字典。在返回之前,模型状态字典使用适当的convert_weights函数转换为torchtune兼容的格式(取决于
self._model_type)。- Returns:
torchtune 检查点状态字典
- Return type:
state_dict (字典[str, 任意])
- Raises:
ValueError – 如果输入状态字典中的值不是张量
- save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) None[source]¶
将HF检查点保存到文件。如果
intermediate_checkpoint为True,则会在_output_dir/RECIPE_STATE_DIRNAME中创建一个额外的检查点文件recipe_state.pt,其中包含配方状态。state_dict 首先被转换回 HF 格式,然后根据
_weight_map分割成单独的检查点文件。- Parameters:
- Raises:
ValueError – 如果
adapter_only为 True 并且在 state_dict 中找不到适配器检查点。