转换
用于处理优化方法的模型转换和恢复的模块。
在应用模型优化算法时,我们通常需要在算法的每一步(模式)中修改模型。该模块提供了状态管理器,这是一个标准化的接口(类),用于记录和存储模型中的状态信息。
在状态管理器的基础上,该模块提供了保存这些修改历史(“modelopt 状态字典”)的工具,并将未修改的模型恢复到状态字典中指示的状态。
类
一个类,用于处理为每个任务/模式对应的模式存储的modelopt状态。 |
函数
将提供的模式应用于模型,记录更改,并返回模型。 |
|
返回描述模型修改的modelopt状态字典。 |
|
保存模型的状态字典以及模型优化器的状态字典,以便恢复其架构。 |
|
根据用户提供的模型,从modelopt状态字典中恢复模型架构。 |
|
加载检查点,恢复模型优化模型的修改,并加载模型的权重。 |
- class ModeloptStateManager
基础类:
object一个类,用于处理为每个任务/模式对应的模式存储的modelopt状态。
- __init__(model=None, init_state=False)
初始化状态管理器。
- Parameters:
model (Module | None) – 存储了modelopt_state的模块。如果为None,则会创建一个假模块来存储可能通过管理器添加的任何状态。
init_state (bool) – 如果模型不存在,是否初始化模型的modelopt状态。
- Return type:
无
- add_mode(mode, config, metadata)
添加模式并就地更新状态。
请注意,self._state 是一个列表(保留了键的插入顺序),因此我们可以回忆起模式的顺序!
- Parameters:
mode (_ModeDescriptor | str) –
config (ModeloptBaseConfig) –
metadata (Dict[str, Any]) –
- Return type:
无
- check_mode(mode)
检查提议的模式是否与当前状态兼容。
- Parameters:
mode (_ModeDescriptor | str) –
- Return type:
无
- static get_config_class(mode, config)
将提供的配置标准化为相应的配置类。
- Parameters:
mode (_ModeDescriptor | str) –
config (Dict[str, Any]) –
- Return type:
- property has_state: bool
返回模型是否具有非平凡模型优化状态。
- classmethod is_converted(model, is_root=False)
检查模型是否已转换。
- Parameters:
model (Module) – 一个模型,用于检查从转换过程中获取的状态/元数据。
is_root (bool) – 额外检查具有状态的模块是否是根模块。
- Returns:
如果模型包含指示其已被转换的modelopt状态,则为True。
- Return type:
bool
当检测到多个modelopt_states或当is_root设置为True但具有状态的模块不是根模块时,此方法会引发断言。
- property last_mode: _ModeDescriptor | None
返回应用于模型的最后模式(最后存储的模式)。
- load_state_dict(state_dict, version)
将提供的
state_dict加载到模型opt_state中。- Parameters:
state_dict (List[Tuple[str, Dict[str, Dict[str, Any]]]]) –
版本 (str) –
- Return type:
无
- modes_with_states()
从状态中生成模式以及完整的配置和元数据。
- Return type:
迭代器[元组[_ModeDescriptor, ModeloptBaseConfig, 字典[字符串, 任意类型]]]
- state_dict()
返回模型的元数据。
- Return type:
列表[元组[字符串, 字典[字符串, 字典[字符串, 任意类型]]]]
- property state_version: str
返回状态管理器中包含的modelopt状态的版本。
- classmethod transfer_state_dict(model_from, model_to)
将状态(相同实例)从一个模型转移到另一个模型。
- Parameters:
model_from (模块) –
model_to (模块) –
- Return type:
无
- update_last_state_before_new_mode(model)
更新应用于模型的最后一个模式的元数据和配置。
- Parameters:
模型 (模块) –
- Return type:
无
- update_last_state_before_save(model)
更新应用于模型的最后一个模式的元数据和配置。
- Parameters:
模型 (模块) –
- Return type:
无
- apply_mode(model, mode, registry=None, init_state=None)
将提供的模式应用于模型,记录更改,并返回模型。
- Parameters:
model (Module | Type[Module] | Tuple | Callable) – 一个类似模型的对象。可以是 nn.Module、模型类类型或元组。 元组必须是
(model_cls,)或(model_cls, args)或(model_cls, args, kwargs)的形式。模型将被初始化为model_cls(*args, **kwargs)。mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) – 一个模式、模式列表或包含模式及其配置的元组列表。模式可以指定为字符串或实际的
_ModeDescriptor类,例如QuantizeModeDescriptor类。registry (_ModeRegistryCls | None) – 一个可选的模式注册表,用于从中检索模式。如果未提供,将搜索所有注册表。
init_state (bool | None) – 标志,指示我们是否应该为模型初始化状态管理器。如果未提供,将从模型中推断。此标志可用于强制执行某种行为。例如,对于
init_state=True,如果模型已经包含状态,状态管理器将引发错误。
- Returns:
应用所需模式后的转换模型。
- Return type:
模块
- modelopt_state(model)
返回描述模型修改的modelopt状态字典。
请注意,返回的
modelopt_state不包含模型参数,如权重和偏差。modelopt_state对于保存和加载各种模型优化状态(与模型参数分开)非常有用。例如:import modelopt.torch.opt as mto # Save the modelopt state and model weights separately torch.save(mto.modelopt_state(model), "modelopt_state.pt") # Save the modelopt state torch.save(model.state_dict(), "model_weights.pt") # Save the model weights
如果你想同时保存模型权重和模型优化状态,请使用
mto.save()。- Parameters:
model (Module) – 经过modelopt修改的模型。
- Returns:
描述模型修改的modelopt状态字典。
- Return type:
Dict[str, Any]
- restore(model, f, **kwargs)
加载检查点,恢复模型优化模型的修改,并加载模型的权重。
- Parameters:
模型 (模块 | 类型[模块] | 元组 | 可调用对象) – 一个类似模型的对象。可以是一个nn.Module、一个模型类类型,或者一个元组。 元组必须是以下形式之一:
(model_cls,)或(model_cls, args)或(model_cls, args, kwargs)。 模型将被初始化为model_cls(*args, **kwargs)。f (str | PathLike | BinaryIO) – 由
mto.save()生成的目标文件位置。**kwargs –
torch.load()的额外参数。
- Returns:
具有原始权重和存储架构的模型。
- Return type:
模块
注意
请注意,在恢复过程中不支持诸如DistributedDataParallel之类的包装器。请在恢复过程之后包装模型。
- restore_from_modelopt_state(model, modelopt_state)
根据用户提供的模型,从modelopt状态字典中恢复模型架构。
此方法不会恢复模型参数,如权重、偏置和量化比例。 请在恢复模型优化状态后,使用原始检查点加载方法加载权重和偏置 使用restore_from_modelopt_state。例如:
import modelopt.torch.opt as mto model = ... # Create the model-like object # Restore the previously saved modelopt state followed by model weights mto.restore_from_modelopt_state( model, torch.load("modelopt_state.pt") ) # Restore modelopt state model.load_state_dict(torch.load("model_weights.pt"), ...) # Load the model weights
如果你想恢复模型权重和模型优化状态以及保存的比例,请使用
mto.restore()。- Parameters:
model (Module | Type[Module] | Tuple | Callable) – 一个类似模型的对象。可以是 nn.Module、模型类类型或元组。 元组必须是
(model_cls,)或(model_cls, args)或(model_cls, args, kwargs)的形式。模型将被初始化为model_cls(*args, **kwargs)。modelopt_state (Dict[str, Any]) – 描述模型修改的modelopt状态字典。可以通过
mto.modelopt_state()生成modelopt_state。
- Returns:
基于恢复的修改的模型架构,使用提供的
model参数中存储的未修改权重。- Return type:
模块
注意
请注意,在恢复过程中不支持诸如DistributedDataParallel之类的包装器。请在恢复过程之后包装模型。
- save(model, f, **kwargs)
保存模型的状态字典以及模型优化状态字典,以恢复其架构。
- Parameters:
模型 (模块) – 任何模型。
f (str | PathLike | BinaryIO) – 目标文件位置。
**kwargs –
torch.save()的额外参数。
- Return type:
无
注意
如果模型是一个包装器,例如 DistributedDataParallel,它将被解包以进行保存。