工具
用于修剪相关和搜索空间相关任务的实用函数。
注意
通常,modelopt.torch.nas模块中的方法应直接使用这些实用函数,而不是访问SearchSpace类。这是为了确保正确执行可能需要的预处理和后处理操作。
类
上下文管理器用于启用 |
|
上下文管理器用于禁用对模型的 |
|
上下文管理器,用于设置补丁的开启或关闭。 |
函数
获取PyTorch模型的推理FLOPs。 |
|
打印搜索空间摘要。 |
|
返回所有超参数的配置字典。 |
|
使用提供的sample_func进行可搜索的hparams示例,并返回结果配置。 |
|
根据提供的配置字典选择子网。 |
|
检查是否为模型启用了modelopt补丁。 |
|
上下文管理器,用于临时替换模型基础类型的forward方法。 |
- class enable_modelopt_patches
基础类:
_DecoratorContextManager上下文管理器用于启用
modelopt补丁,例如用于autonas/fastnas的补丁。它也可以用作装饰器(确保使用括号实例化)。
例如:
modelopt_model.train() modelopt_model(inputs) # architecture changes with mtn.no_modelopt(): with mtn.enable_modelopt(): modelopt_model(inputs) # architecture changes @mtn.enable_modelopt() def forward(model, inputs): return model(inputs) with mtn.no_modelopt(): forward(modelopt_model, inputs) # architecture changes because of decorator on forward
- __init__()
构造函数。
- get_subnet_config(model, configurable=None)
返回所有超参数的配置字典。
- Parameters:
model (Module) – 包含DynamicModule的模型。
可配置 (bool | None) – None -> 所有超参数, True -> 可配置的超参数,无重复
- Returns:
一个
(parameter_name, choice)的字典,用于指定一个活动的子网。- Return type:
Dict[str, Any]
- inference_flops(network, dummy_input=None, data_shape=None, unit=1000000.0, return_str=False)
获取PyTorch模型的推理FLOPs。
- Parameters:
network (Module) – PyTorch 模型。
args – 在
mtn.convert()中定义的虚拟输入。data_shape (Tuple | None) – 如果虚拟输入是单个张量,则虚拟输入的形状。如果提供了,
args必须为None。单位 (浮点数) – 返回参数数量的单位。默认值为1e6(百万)。
return_str (bool) – 是否将FLOPs的数量作为字符串返回。
dummy_input (Any | Tuple | None) –
- Returns:
给定单位中的推理FLOPs数量,可以是字符串或浮点数。
- Return type:
float | str
- is_modelopt_patches_enabled()
检查是否为模型启用了modelopt补丁。
- Return type:
bool
- class no_modelopt_patches
基础类:
_DecoratorContextManager上下文管理器用于禁用对模型的
modelopt补丁。禁用
modelopt补丁在您希望使用模型的原始行为时非常有用。例如,您可以使用此功能来执行没有NAS操作的前向传递。它也可以用作装饰器(确保使用括号实例化)。
例如:
modelopt_model.train() modelopt_model(inputs) # architecture changes with mtn.no_modelopt(): modelopt_model(inputs) # architecture does not change @mtn.no_modelopt() def forward(model, inputs): return model(inputs) forward(modelopt_model, inputs) # architecture does not change
- __init__()
构造函数。
- print_search_space_summary(model, skipped_hparams=['kernel_size'])
打印搜索空间摘要。
- Parameters:
model (Module) – 包含DynamicModule的模型。
skipped_hparams (List[str]) –
- Return type:
无
- replace_forward(model, new_forward)
上下文管理器,用于临时替换模型基础类型的forward方法。
原始的前向函数暂时可以通过
model.forward_original访问。- Parameters:
model (Module) – 其类型的 forward 方法将被临时替换的模型。
new_forward (Callable) – 新的前向方法。前向方法应该是模型实例的绑定方法,或者将模型 (
self) 作为第一个参数。
- Return type:
迭代器[无]
例如:
fake_forward = lambda _: None with replace_forward(model, fake_forward): out = model(inputs) # this output is None out_original = model(inputs) # this output is the original output
- sample(model, sample_func=<function choice>)
使用提供的sample_func进行可搜索的hparams示例,并返回结果配置。
- Parameters:
model (Module) – 一个可搜索的模型,包含一个或多个DynamicModule。
sample_func (Callable[[Sequence[T]], T] | Dict[str, Callable[[Sequence[T]], T]]) – 用于超参数的采样函数。默认值:随机采样。
- Returns:
一个
(parameter_name, choice)的字典,用于指定一个活动的子网。- Return type:
Dict[str, Any]
- select(model, config, strict=True)
根据提供的配置字典选择子网。
- Parameters:
model (Module) – 包含DynamicModule的模型。
config (Dict[str, Any]) – 目标子网的配置,由
mtn.config()和mtn.search()返回。strict (bool) – 当配置不包含所有必要键时引发错误。
- Return type:
无
- class set_modelopt_patches_enabled
基础类:
_DecoratorContextManager上下文管理器,用于设置补丁的开启或关闭。
它可以用作上下文管理器或函数。如果用作函数,操作将在全局范围内(线程本地)被禁用。
- Parameters:
enabled – 是否启用(
True)或禁用(False)修补方法。
例如:
modelopt_model.train() modelopt_model(inputs) # architecture changes mtn.set_modelopt_enabled(False) modelopt_model(inputs) # architecture does not change with mtn.set_modelopt_enabled(True): modelopt_model(inputs) # architecture changes modelopt_model(inputs) # architecture does not change
- __init__(enabled)
构造函数。
- Parameters:
enabled (bool) –
- clone()
克隆上下文管理器。