网络
PyTorch模型的实用函数。
函数
比较两个字典并返回具有不匹配值的键。 |
|
获取PyTorch模型的关键属性。 |
|
获取PyTorch模块的设备。 |
|
获取给定内核大小的相同填充。 |
|
从类似模型的对象初始化模型。 |
|
检查模型是否使用通道最后的内存格式。 |
|
检查PyTorch模型是否已并行化。 |
|
函数取自原始的 tf 仓库。 |
|
将模型转换为与目标模型相同的设备、数据类型和内存布局。 |
|
获取PyTorch模型的参数数量。 |
|
从正向传递中获取PyTorch模型的参数数量。 |
|
移除网络中的所有批量归一化层。 |
|
根据提供的数据加载器,使用模型运行多次前向传递。 |
|
与nn.Module.get_submodule()互补的set函数。 |
|
根据 torch.onnx.export 标准化模型参数。 |
|
标准化一个类似模型的元组。 |
|
根据torch.onnx.export标准化模型参数并为其命名。 |
|
标准化一个类似构造函数的元组。 |
|
解包由支持的包装模块包装的模型或返回原始模型。 |
|
将模型参数中的任何梯度设置为None。 |
|
创建一个钩子以清除参数的梯度。 |
|
获取清理后的模块名称(即在使用分片模块包装之前的名称)。 |
- compare_dict(dict1, dict2)
比较两个字典并返回具有不匹配值的键。
- Parameters:
dict1 (Dict[str, Any]) –
dict2 (Dict[str, Any]) –
- Return type:
元组[字符串, …]
- create_param_grad_clear_hook(param)
创建一个钩子以清除参数的梯度。
钩子将在参数的梯度累积后被触发。 重要提示:为了使此功能正常工作,
accum_grad应在此实用程序需要时保持活动状态。
- get_model_attributes(model)
获取PyTorch模型的关键属性。
- Parameters:
模型 (模块) –
- Return type:
Dict[str, Any]
- get_module_device(module)
获取PyTorch模块的设备。
- Parameters:
模块 (模块) –
- Return type:
设备
- get_same_padding(kernel_size)
获取给定内核大小的相同填充。
- Parameters:
kernel_size (int | Tuple[int, int]) –
- Return type:
int | tuple
- get_unwrapped_name(name)
获取清理后的模块名称(即在使用分片模块包装之前的名称)。
- Parameters:
name (str) –
- Return type:
str
- init_model_from_model_like(model)
从类似模型的对象初始化模型。
- Parameters:
model (Module | Type[Module] | Tuple | Callable) – 一个类似模型的对象。可以是一个 nn.Module(直接返回),一个模型类或可调用对象,或者一个元组。 如果是元组,它必须是 (model_cls_or_callable,) 或 (model_cls_or_callable, args) 或 (model_cls_or_callable, args, kwargs) 的形式。 模型将被初始化为
model_cls_or_callable(*args, **kwargs)。- Return type:
模块
- is_channels_last(model)
检查模型是否使用通道最后的内存格式。
- Parameters:
模型 (模块) –
- is_parallel(model)
检查PyTorch模型是否已并行化。
- Parameters:
模型 (模块) –
- Return type:
bool
- make_divisible(v, divisor, min_val=None)
函数取自原始的 tf 仓库。
它确保所有层的通道数都能被8整除 可以在这里看到: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
- Parameters:
v (int | float) –
divisor (int | None) –
- Return type:
int | float
- model_to(model, target_model)
将模型转换为与目标模型相同的设备、数据类型和内存布局。
- Parameters:
模型 (模块) –
target_model (Module) –
- param_num(network, trainable_only=False, unit=1000000.0)
获取PyTorch模型的参数数量。
- Parameters:
network (Module) – PyTorch 模型。
trainable_only (bool) – 是否仅计算可训练参数。默认为 False。
unit – 返回参数数量的单位。默认值为1e6(百万)。
- Returns:
给定单位中模型中的参数数量。
- Return type:
浮点数
- param_num_from_forward(model, trainable_only=False, args=None, unit=1000000.0)
从正向传递中获取PyTorch模型的参数数量。
- Parameters:
network – PyTorch模型。
trainable_only (bool) – 是否仅计算可训练参数。默认值为 False。
单位 (浮点数) – 返回参数数量的单位。默认值为1e6(百万)。
模型 (模块) –
args (Tensor | Tuple | None) –
- Returns:
给定单元中模型前向传递的参数数量。
这对于动态模块可能很有帮助,其中状态字典可能包含模型中未主动使用的额外参数,例如,由于在前向传递中停用了DynamicModule。我们通过仅计算出现在前向传递中的模块的参数来规避此问题。
- remove_bn(model)
移除网络中的所有批量归一化层。
- Parameters:
模型 (模块) –
- run_forward_loop(model, data_loader, max_iters=None, collect_func=None, progress_bar=None, post_process=None)
根据提供的数据加载器,使用模型运行多次前向传递。
- Parameters:
model – 用于运行前向传播的模型。
data_loader (Iterable) – 一个包含数据样本的迭代器。
max_iters (int | None) – 运行的批次数;默认情况下是无限的,或者直到
data_loader耗尽为止。collect_func (Callable[[Any], Any | Tuple] | None) –
A
Callablethat takes a batch of data from thedata_loaderas input and returns the input tomodel.forward()such that the return value (input) is either:a single argument (
type(input) != tuple) corresponding tomodel.forward(input)
a tuple of arguments corresponding to
model.forward(*input)
a tuple of arguments such that
type(input[-1]) == dictcorresponding tomodel.forward(*input[:-1], **input[-1])
Note
In order to pass a dict as last non-keyword argument, you need to use a tuple as
inputand add an empty dict as the last element, e.g.,input = (x, {"y": y, "z": z}, {})
The empty dict at the end will then be interpreted as the keyword args.
See the
argsargument of torch.onnx.export for more info on the format of the return value ofcollect_func(input).The default
collect_funcassumes that the data loader returns a tuple, e.g.,(images, labels, ...), and returns the first element of the tuple.progress_bar (str | None) – 设置为描述字符串以查看进度条。
post_process (Callable | None) – 一个可调用对象,它接收模型输出和数据作为输入,并可用于运行任何后处理或操作,例如反向传播。
- set_submodule(model, target, target_submodule)
与nn.Module.get_submodule()互补的set函数。
- Parameters:
模型 (模块) –
target (str) –
target_submodule (模块) –
- standardize_constructor_args(constructor_args)
标准化一个类似构造函数的元组。
- Parameters:
constructor_args (Callable | Tuple) –
- Return type:
元组[可调用, 元组, 字典]
- standardize_model_args(model_or_fw_or_sig, args, use_kwargs=False)
根据 torch.onnx.export 标准化模型参数。
- Parameters:
model_or_fw_or_sig (Module | Callable | Signature) – 一个 nn.Module,它的前向方法,或者它的前向方法的签名。
args (Any | Tuple) – 请参考
dummy_input参数在mtn.profile()中的描述。use_kwargs – 影响返回值,见下文。对于
use_kwargs==False,返回的 参数也与torch.onnx.export兼容。
- Returns:
可以在
model.forward()中以相同标准化方式使用的标准化模型参数,无论它们是如何提供的,更多信息请参见下文。- Return type:
元组
如果
use_kwargs == False,返回的 args 可以用作args = standardize_model_args(model, args, use_kwargs=False) model(*args)
如果
use_kwargs == True,返回的 args 可以用作args = standardize_model_args(model, args, use_kwargs=True) model.forward(*args[:-1], **args[-1])
警告
如果
use_kwargs == False,模型的forward()方法 不能 包含仅关键字参数(例如forward(..., *, kw_only_args))而没有默认值,并且你不得在args中提供它们。警告
如果
use_kwargs == False,你不得在args中提供通过模型forward()方法中的可变关键字参数处理的可变关键字参数 (例如forward(..., **kwargs))。
- standardize_model_like_tuple(model)
标准化一个类似模型的元组。
- Parameters:
model (Module | Type[Module] | Tuple | Callable) –
- Return type:
元组[类型[模块], 元组, 字典]
- standardize_named_model_args(model_or_fw_or_sig, args)
根据torch.onnx.export标准化模型参数并为其命名。
- Parameters:
model_or_fw_or_sig (Module | Callable | Signature) – 一个 nn.Module,它的前向方法,或者它的前向方法的签名。
args (Any | Tuple) – 一个包含args/kwargs或torch.Tensor的元组,输入到模型的
forward()方法中。
- Return type:
元组[字典[字符串, 任意类型], 集合[字符串]]
- Returns: A tuple (args_normalized, args_with_default) where
- args_normalized is a dictionary of ordered model args where the key represents a unique
基于函数签名中参数名称的序列化字符串,值包含实际参数,
- args_with_default is a set indicating whether the argument was retrieved from the default
模型
forward()方法函数签名中的值,或者参数是否完全对应于默认值。
注意
更多信息请参见
standardize_model_args()。
- unwrap_model(model, warn=False, raise_error=False, msg='', force_unwrap=False)
解包由支持的包装模块包装的模型或返回原始模型。
- Parameters:
模型 (模块) –
警告 (布尔值) –
raise_error (bool) –
msg (str) –
force_unwrap (bool) –
- Return type:
模块
- zero_grad(model)
将模型参数中的任何梯度设置为None。
- Parameters:
模型 (模块) –
- Return type:
无