算法
高级搜索和模型设计算法,帮助您优化模型。
函数
转换模型或常规模型的搜索空间的配置文件统计信息。 |
|
搜索给定的可修剪模型以找到最佳子网并返回搜索模型。 |
- profile(model, dummy_input=None, constraints=None, deployment=None, strict=False, verbose=True, use_centroid=False)
转换模型或常规模型的搜索空间的配置文件统计信息。
- Parameters:
model (Module) – 要进行分析的模型。可以是已转换的或未转换的。
dummy_input (Any | Tuple | None) –
Arguments of
model.forward(). This is used for exporting and calculating inference-based metrics, such as latency/FLOPs. The format ofdummy_inputsfollows the convention of theargsargument in torch.onnx.export. Specifically,dummy_inputcan be:a single argument (
type(dummy_input) != tuple) corresponding tomodel.forward(dummy_input)
a tuple of arguments corresponding to
model.forward(*dummy_input)
a tuple of arguments such that
type(dummy_input[-1]) == dictcorresponding tomodel.forward(*dummy_input[:-1], **dummy_input[-1])
Warning
In this case the model’s
forward()method cannot contain keyword-only arguments (e.g.forward(..., *, kw_only_args)) or variable keyword arguments (e.g.forward(..., **kwargs)) since these cannot be sorted into positional arguments.
Note
In order to pass a dict as last non-keyword argument, you need to use a tuple as
dummy_inputand add an empty dict as the last element, e.g.,dummy_input = (x, {"y": y, "z": z}, {})
The empty dict at the end will then be interpreted as the keyword args.
See torch.onnx.export for more info.
Note that if you provide a
{arg_name}with batch sizeb, the results will be computed based on batch sizeb.约束条件 (Dict[str, str | float | Dict | None] | ConstraintsFunc | None) –
从约束名称到搜索模型必须满足的上限的字典。 目前,我们支持
flops和params作为约束条件。 约束字典通常采用以下形式:constraints = {"params": 5.0e6, "flops": 4.5e8}
注意
我们建议只提供最相关的约束条件,例如 flops:
constraints = {"flops": 4.5e8}
请注意,您也可以提供百分比值而不是绝对值,例如:
# 搜索一个 flops 小于等于原始模型 60% 的模型 constraints = {"flops": "60%"}
strict (bool) – 如果约束条件不可满足,则引发错误。
verbose (bool) – 打印详细的性能分析结果。
use_centroid (bool) – 默认情况下,配置文件报告从随机采样子网络评估结果的中位数(而不是确定性中心子网络的评估结果)。将 use_centroid 设置为 True 以使用确定性中心进行性能分析。
部署 (字典[字符串, 字符串] | 无) –
- Return type:
元组[布尔, 字典[字符串, 字典]]
- Returns: A tuple (is_all_sat, stats) where
is_all_sat 是一个布尔值,表示是否所有约束都可以被满足。 stats 是一个字典,包含模型转换后的搜索空间的统计信息, 例如,最小、中心、最大子网的FLOPs和参数及其最大/最小比率, 搜索空间的大小,可配置的超参数的数量。
- search(model, constraints, dummy_input, config=None)
搜索给定的可修剪模型以找到最佳子网并返回搜索模型。
最佳子网络在满足
constraints的同时,最大化由score_func给出的分数。- Parameters:
model (Module) – 要搜索的转换后的模型。
约束条件 (字典[字符串, 字符串 | 浮点数 | 字典 | 无]) –
从约束名称到搜索模型必须满足的上限的字典。 目前,我们支持
flops和params作为约束条件。 约束字典通常采用以下形式:constraints = {"params": 5.0e6, "flops": 4.5e8}
我们建议只提供最相关的约束条件,例如 flops:
constraints = {"flops": 4.5e8}
你也可以提供百分比值而不是绝对值,例如,
# 搜索一个 flops 小于等于原始模型 60% 的模型 constraints = {"flops": "60%"}
dummy_input (Any | Tuple) –
Arguments of
model.forward(). This is used for exporting and calculating inference-based metrics, such as latency/FLOPs. The format ofdummy_inputsfollows the convention of theargsargument in torch.onnx.export. Specifically,dummy_inputcan be:a single argument (
type(dummy_input) != tuple) corresponding tomodel.forward(dummy_input)
a tuple of arguments corresponding to
model.forward(*dummy_input)
a tuple of arguments such that
type(dummy_input[-1]) == dictcorresponding tomodel.forward(*dummy_input[:-1], **dummy_input[-1])
Warning
In this case the model’s
forward()method cannot contain keyword-only arguments (e.g.forward(..., *, kw_only_args)) or variable keyword arguments (e.g.forward(..., **kwargs)) since these cannot be sorted into positional arguments.
Note
In order to pass a dict as last non-keyword argument, you need to use a tuple as
dummy_inputand add an empty dict as the last element, e.g.,dummy_input = (x, {"y": y, "z": z}, {})
The empty dict at the end will then be interpreted as the keyword args.
See torch.onnx.export for more info.
Note that if you provide a
{arg_name}with batch sizeb, the results will be computed based on batch sizeb.config (Dict[str, Any] | None) –
Additional optional arguments to configure the search. Currently, we support:
checkpoint: Path to save/restore checkpoint with dictionary containing intermediate search state. If provided, the intermediate search state will be automatically restored before search (if exists) and stored/saved during search.verbose: Whether to print detailed search space profiling and search stats during search.forward_loop: ACallablethat takes a model as input and runs a forward loop on it. It is recommended to choose the data loader used inside the forward loop carefully to reduce the runtime. Cannot be provided at the same time asdata_loaderandcollect_func.data_loader: An iterator yielding batches of data for calibrating the normalization layers in the model or compute gradient scores. It is recommended to use the same data loader as for training but with significantly fewer iterations. Cannot be provided at the same time asforward_loop.collect_func: ACallablethat takes a batch of data from the data loader as input and returns the input tomodel.forward()as described inrun_forward_loop. Cannot be provided at the same time asforward_loop.max_iter_data_loader: Maximum number of iterations to run the data loader.score_func: A callable taking the model as input and returning a single accuracy/score metric (float). This metric will be maximized during search.Note
The
score_funcis required forautonasandfastnasmodes. It will be evaluated on models in eval mode (model.eval()).loss_func: ACallablewhich takes the model output (i.e output ofmodel.forward()) and the batch of data as its inputs and returns a scalar loss. This is a required argument if the model is converted viagradnasmode.It should be possible to run a backward pass on the loss value returned by this method.
collect_funcwill be used to gather the inputs tomodel.forward()from a batch of data yielded by``data_loader``.loss_funcshould support the following usage:for i, batch in enumerate(data_loader): if i >= max_iter_data_loader: break # Assuming collect_func returns a tuple of arguments output = model(*collect_func(batch)) loss = loss_func(output, batch) loss.backward()
Note
Additional configuration options may be added by individual algorithms. Please refer to the documentation of the individual algorithms for more information.
- Return type:
元组[模块, 字典[字符串, 任意类型]]
- Returns: A tuple (subnet, state_dict) where
subnet 是搜索到的子网 (nn.Module),可以用于后续任务,如微调,state_dict 包含搜索过程的历史和详细统计信息。
注意
给定的模型被就地修改(导出)以匹配搜索算法找到的最佳子网。因此,返回的子网是对与输入模型相同的模型实例的引用。