工具

用于修剪相关和搜索空间相关任务的实用函数。

注意

通常,modelopt.torch.nas模块中的方法应直接使用这些实用函数,而不是访问SearchSpace类。这是为了确保正确执行可能需要的预处理和后处理操作。

enable_modelopt_patches

上下文管理器用于启用modelopt补丁,例如用于autonas/fastnas的补丁。

no_modelopt_patches

上下文管理器用于禁用对模型的modelopt补丁。

set_modelopt_patches_enabled

上下文管理器,用于设置补丁的开启或关闭。

函数

inference_flops

获取PyTorch模型的推理FLOPs。

print_search_space_summary

打印搜索空间摘要。

get_subnet_config

返回所有超参数的配置字典。

sample

使用提供的sample_func进行可搜索的hparams示例,并返回结果配置。

select

根据提供的配置字典选择子网。

is_modelopt_patches_enabled

检查是否为模型启用了modelopt补丁。

replace_forward

上下文管理器,用于临时替换模型基础类型的forward方法。

class enable_modelopt_patches

基础类:_DecoratorContextManager

上下文管理器用于启用modelopt补丁,例如用于autonas/fastnas的补丁。

它也可以用作装饰器(确保使用括号实例化)。

例如:

modelopt_model.train()
modelopt_model(inputs)  # architecture changes

with mtn.no_modelopt():
    with mtn.enable_modelopt():
        modelopt_model(inputs)  # architecture changes


@mtn.enable_modelopt()
def forward(model, inputs):
    return model(inputs)


with mtn.no_modelopt():
    forward(modelopt_model, inputs)  # architecture changes because of decorator on forward
__init__()

构造函数。

get_subnet_config(model, configurable=None)

返回所有超参数的配置字典。

Parameters:
  • model (Module) – 包含DynamicModule的模型。

  • 可配置 (bool | None) – None -> 所有超参数, True -> 可配置的超参数,无重复

Returns:

一个(parameter_name, choice)的字典,用于指定一个活动的子网。

Return type:

Dict[str, Any]

inference_flops(network, dummy_input=None, data_shape=None, unit=1000000.0, return_str=False)

获取PyTorch模型的推理FLOPs。

Parameters:
  • network (Module) – PyTorch 模型。

  • args – 在mtn.convert()中定义的虚拟输入。

  • data_shape (Tuple | None) – 如果虚拟输入是单个张量,则虚拟输入的形状。如果提供了, args 必须为 None

  • 单位 (浮点数) – 返回参数数量的单位。默认值为1e6(百万)。

  • return_str (bool) – 是否将FLOPs的数量作为字符串返回。

  • dummy_input (Any | Tuple | None) –

Returns:

给定单位中的推理FLOPs数量,可以是字符串或浮点数。

Return type:

float | str

is_modelopt_patches_enabled()

检查是否为模型启用了modelopt补丁。

Return type:

bool

class no_modelopt_patches

基础类:_DecoratorContextManager

上下文管理器用于禁用对模型的modelopt补丁。

禁用 modelopt 补丁在您希望使用模型的原始行为时非常有用。例如,您可以使用此功能来执行没有NAS操作的前向传递。

它也可以用作装饰器(确保使用括号实例化)。

例如:

modelopt_model.train()
modelopt_model(inputs)  # architecture changes

with mtn.no_modelopt():
    modelopt_model(inputs)  # architecture does not change


@mtn.no_modelopt()
def forward(model, inputs):
    return model(inputs)


forward(modelopt_model, inputs)  # architecture does not change
__init__()

构造函数。

print_search_space_summary(model, skipped_hparams=['kernel_size'])

打印搜索空间摘要。

Parameters:
  • model (Module) – 包含DynamicModule的模型。

  • skipped_hparams (List[str]) –

Return type:

replace_forward(model, new_forward)

上下文管理器,用于临时替换模型基础类型的forward方法。

原始的前向函数暂时可以通过 model.forward_original 访问。

Parameters:
  • model (Module) – 其类型的 forward 方法将被临时替换的模型。

  • new_forward (Callable) – 新的前向方法。前向方法应该是模型实例的绑定方法,或者将模型 (self) 作为第一个参数。

Return type:

迭代器[]

例如:

fake_forward = lambda _: None

with replace_forward(model, fake_forward):
    out = model(inputs)  # this output is None

out_original = model(inputs)  # this output is the original output
sample(model, sample_func=<function choice>)

使用提供的sample_func进行可搜索的hparams示例,并返回结果配置。

Parameters:
  • model (Module) – 一个可搜索的模型,包含一个或多个DynamicModule。

  • sample_func (Callable[[Sequence[T]], T] | Dict[str, Callable[[Sequence[T]], T]]) – 用于超参数的采样函数。默认值:随机采样。

Returns:

一个(parameter_name, choice)的字典,用于指定一个活动的子网。

Return type:

Dict[str, Any]

select(model, config, strict=True)

根据提供的配置字典选择子网。

Parameters:
  • model (Module) – 包含DynamicModule的模型。

  • config (Dict[str, Any]) – 目标子网的配置,由 mtn.config()mtn.search() 返回。

  • strict (bool) – 当配置不包含所有必要键时引发错误。

Return type:

class set_modelopt_patches_enabled

基础类:_DecoratorContextManager

上下文管理器,用于设置补丁的开启或关闭。

它可以用作上下文管理器或函数。如果用作函数,操作将在全局范围内(线程本地)被禁用。

Parameters:

enabled – 是否启用(True)或禁用(False)修补方法。

例如:

modelopt_model.train()
modelopt_model(inputs)  # architecture changes

mtn.set_modelopt_enabled(False)
modelopt_model(inputs)  # architecture does not change

with mtn.set_modelopt_enabled(True):
    modelopt_model(inputs)  # architecture changes

modelopt_model(inputs)  # architecture does not change
__init__(enabled)

构造函数。

Parameters:

enabled (bool) –

clone()

克隆上下文管理器。