OptimizerInBackwardWrapper¶
- class torchtune.training.OptimizerInBackwardWrapper(optim_map: Dict[str, Optimizer])[source]¶
一个用于在反向运行时优化器的检查点保存和加载的简单类。使用仅限于以下情况:
注意
此包装器仅适用于单设备使用场景。 不支持需要专门优化器状态检查点的分布式使用场景,如FSDP。
- Parameters:
optim_map (Dict[str, torch.optim.Optimizer]) – 从参数名称到优化器的映射。
示例
>>> optim_dict = { >>> p: config.instantiate(cfg_optimizer, [p]) >>> for p in self._model.parameters() >>> } >>> >>> # Save checkpoint >>> ckpt = OptimizerInBackwardWrapper(optim_dict).state_dict() >>> torch.save("/tmp/optim_ckpt", ckpt) >>> >>> # Load checkpoint >>> placeholder_optim_dict = { >>> p: config.instantiate(cfg_optimizer, [p]) >>> for p in self._model.parameters() >>> } >>> >>> wrapper = OptimInBackwardWrapper(placeholder_optim_dict) >>> >>> # load_state_dict expects a dict produced by this class's >>> # state_dict method. >>> wrapper.load_state_dict(torch.load("/tmp/optim_ckpt")) >>> # placeholder_optim_dict now has updated optimizer states.
- get_last_lr() float[source]¶
如果存在,从调度器中获取最后的学习率。
- Returns:
最后的学习率。
- Return type:
- Raises:
RuntimeError – 如果LR调度器尚未设置。
- get_optim_key(key: str) Any[source]¶
返回在反向传播中运行的任意优化器的键值。请注意,这假设所有反向传播中的优化器对于该键具有相同的值,即,使用相同的超参数进行初始化。
- load_state_dict(optim_ckpt_map: Dict[str, Any])[source]¶
从由该类的state_dict方法生成的状态字典中加载优化器状态。
- Parameters:
optim_ckpt_map (Dict[str, Any]) – 状态字典,将参数名称映射到优化器状态。
- Raises:
RuntimeError – 如果优化器状态字典不包含所有预期的参数。
- set_lr_scheduler(lr_scheduler: LRScheduler) None[source]¶
设置学习率调度器并修改其步骤方法以更新所有优化器。
- Parameters:
lr_scheduler (LRScheduler) – 使用的学习率调度器。
- state_dict() Dict[str, Any][source]¶
返回一个将参数名称映射到优化器状态的状态字典。这个状态字典只能由相同的类加载。
- Returns:
状态字典将参数名称映射到优化器状态。
- Return type:
字典[str, 任意]
- step_lr_scheduler(epoch: int = None)[source]¶
如果存在学习率调度器,则执行其步骤。
- Parameters:
epoch (int, 可选) – 当前的epoch编号。默认为None。
- Raises:
RuntimeError – 如果LR调度器尚未设置。