Shortcuts

FullModelTorchTuneCheckpointer

class torchtune.training.FullModelTorchTuneCheckpointer(checkpoint_dir: str, checkpoint_files: List[str], model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, resume_from_checkpoint: bool = False)[source]

检查点读取器,以与torchtune兼容的格式读写检查点。无需转换权重。

目前这仅支持读取单个检查点文件。随着我们增加对更大模型的支持,这可能会有所改变。

Parameters:
  • checkpoint_dir (str) – 包含检查点文件的目录

  • checkpoint_files (List[str]) – 要加载的检查点文件列表。由于检查点处理器会负责按文件ID排序,因此此列表中的顺序无关紧要

  • model_type (str) – 正在加载检查点的模型的模型类型,例如 LLAMA3。

  • output_dir (str) – 保存检查点文件的目录

  • adapter_checkpoint (Optional[str]) – 适配器权重的路径。如果为None, 并且resume_from_checkpoint=True,则在output_dir/epoch_{largest_epoch}中查找adapter_model.pt。 默认值为None。

  • recipe_checkpoint (可选[str]) – 配方状态检查点文件的路径。如果为 None, 并且 resume_from_checkpoint=True,则在 output_dir/RECIPE_STATE_DIRNAME 中查找 recipe_state.pt。 默认值为 None。

  • resume_from_checkpoint (bool) – 如果为True,检查点将加载额外的检查点文件以从先前的运行中恢复训练。默认值为False

Raises:

ValueError – 如果提供了多个检查点文件

load_checkpoint(weights_only: bool = True) Dict[str, Any][source]

从文件加载torchtune检查点。目前仅支持从单个文件加载。

输出的state_dict具有以下格式,只有在resume_from_checkpoint为True时,才会出现“model”以外的键:

>>>     {
>>>         "model": {
>>>             "key_1": weight
>>>             ...
>>>         },
>>>         "optimizer": {...},
>>>         ...
>>>     }
Parameters:

weights_only (bool) – 传递给torch.load的标志。我们暴露这个标志,因为量化模型无法在weights_only=True的情况下加载

Returns:

来自输入检查点的state_dict

Return type:

字典[str, 任意]

save_checkpoint(state_dict: Dict[str, Any], epoch: int, intermediate_checkpoint: bool = False, adapter_only: bool = False) None[source]

将torchtune检查点保存到文件。如果intermediate_checkpoint为True,则会在_output_dir/RECIPE_STATE_DIRNAME中创建一个额外的检查点文件recipe_state.pt,其中包含配方状态。输出的状态字典具有以下格式:

>>> # Model
>>> {
>>>     "key_1": weight
>>>     ...
>>> }
>>>
>>> # Recipe state
>>> {
>>>     "optimizer": ...,
>>>     "epoch": ...,
>>>     ...
>>> }
Parameters:
  • state_dict (Dict[str, Any]) – 包含模型和(可选)配方状态的状态字典

  • epoch (int) – 当前周期数。这被添加到检查点文件名中,以确保我们不会覆盖中间的检查点文件。

  • intermediate_checkpoint (bool) – 如果为True,保存一个额外的检查点文件,包含配方状态

  • adapter_only (bool) – 如果为True,仅保存适配器权重。默认为False

Raises:

ValueError – 如果 adapter_only 为 True 并且在 state_dict 中找不到适配器检查点。