网络

PyTorch模型的实用函数。

函数

compare_dict

比较两个字典并返回具有不匹配值的键。

get_model_attributes

获取PyTorch模型的关键属性。

get_module_device

获取PyTorch模块的设备。

get_same_padding

获取给定内核大小的相同填充。

init_model_from_model_like

从类似模型的对象初始化模型。

is_channels_last

检查模型是否使用通道最后的内存格式。

is_parallel

检查PyTorch模型是否已并行化。

make_divisible

函数取自原始的 tf 仓库。

model_to

将模型转换为与目标模型相同的设备、数据类型和内存布局。

param_num

获取PyTorch模型的参数数量。

param_num_from_forward

从正向传递中获取PyTorch模型的参数数量。

remove_bn

移除网络中的所有批量归一化层。

run_forward_loop

根据提供的数据加载器,使用模型运行多次前向传递。

set_submodule

与nn.Module.get_submodule()互补的set函数。

standardize_model_args

根据 torch.onnx.export 标准化模型参数。

standardize_model_like_tuple

标准化一个类似模型的元组。

standardize_named_model_args

根据torch.onnx.export标准化模型参数并为其命名。

standardize_constructor_args

标准化一个类似构造函数的元组。

unwrap_model

解包由支持的包装模块包装的模型或返回原始模型。

zero_grad

将模型参数中的任何梯度设置为None。

create_param_grad_clear_hook

创建一个钩子以清除参数的梯度。

get_unwrapped_name

获取清理后的模块名称(即在使用分片模块包装之前的名称)。

compare_dict(dict1, dict2)

比较两个字典并返回具有不匹配值的键。

Parameters:
  • dict1 (Dict[str, Any]) –

  • dict2 (Dict[str, Any]) –

Return type:

元组[字符串, …]

create_param_grad_clear_hook(param)

创建一个钩子以清除参数的梯度。

钩子将在参数的梯度累积后被触发。 重要提示:为了使此功能正常工作,accum_grad 应在此实用程序需要时保持活动状态。

get_model_attributes(model)

获取PyTorch模型的关键属性。

Parameters:

模型 (模块) –

Return type:

Dict[str, Any]

get_module_device(module)

获取PyTorch模块的设备。

Parameters:

模块 (模块) –

Return type:

设备

get_same_padding(kernel_size)

获取给定内核大小的相同填充。

Parameters:

kernel_size (int | Tuple[int, int]) –

Return type:

int | tuple

get_unwrapped_name(name)

获取清理后的模块名称(即在使用分片模块包装之前的名称)。

Parameters:

name (str) –

Return type:

str

init_model_from_model_like(model)

从类似模型的对象初始化模型。

Parameters:

model (Module | Type[Module] | Tuple | Callable) – 一个类似模型的对象。可以是一个 nn.Module(直接返回),一个模型类或可调用对象,或者一个元组。 如果是元组,它必须是 (model_cls_or_callable,) 或 (model_cls_or_callable, args) 或 (model_cls_or_callable, args, kwargs) 的形式。 模型将被初始化为 model_cls_or_callable(*args, **kwargs)

Return type:

模块

is_channels_last(model)

检查模型是否使用通道最后的内存格式。

Parameters:

模型 (模块) –

is_parallel(model)

检查PyTorch模型是否已并行化。

Parameters:

模型 (模块) –

Return type:

bool

make_divisible(v, divisor, min_val=None)

函数取自原始的 tf 仓库。

它确保所有层的通道数都能被8整除 可以在这里看到: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

Parameters:
  • v (int | float) –

  • divisor (int | None) –

Return type:

int | float

model_to(model, target_model)

将模型转换为与目标模型相同的设备、数据类型和内存布局。

Parameters:
  • 模型 (模块) –

  • target_model (Module) –

param_num(network, trainable_only=False, unit=1000000.0)

获取PyTorch模型的参数数量。

Parameters:
  • network (Module) – PyTorch 模型。

  • trainable_only (bool) – 是否仅计算可训练参数。默认为 False。

  • unit – 返回参数数量的单位。默认值为1e6(百万)。

Returns:

给定单位中模型中的参数数量。

Return type:

浮点数

param_num_from_forward(model, trainable_only=False, args=None, unit=1000000.0)

从正向传递中获取PyTorch模型的参数数量。

Parameters:
  • network – PyTorch模型。

  • trainable_only (bool) – 是否仅计算可训练参数。默认值为 False。

  • 单位 (浮点数) – 返回参数数量的单位。默认值为1e6(百万)。

  • 模型 (模块) –

  • args (Tensor | Tuple | None) –

Returns:

给定单元中模型前向传递的参数数量。

这对于动态模块可能很有帮助,其中状态字典可能包含模型中未主动使用的额外参数,例如,由于在前向传递中停用了DynamicModule。我们通过仅计算出现在前向传递中的模块的参数来规避此问题。

remove_bn(model)

移除网络中的所有批量归一化层。

Parameters:

模型 (模块) –

run_forward_loop(model, data_loader, max_iters=None, collect_func=None, progress_bar=None, post_process=None)

根据提供的数据加载器,使用模型运行多次前向传递。

Parameters:
  • model – 用于运行前向传播的模型。

  • data_loader (Iterable) – 一个包含数据样本的迭代器。

  • max_iters (int | None) – 运行的批次数;默认情况下是无限的,或者直到 data_loader 耗尽为止。

  • collect_func (Callable[[Any], Any | Tuple] | None) –

    A Callable that takes a batch of data from the data_loader as input and returns the input to model.forward() such that the return value (input) is either:

    1. a single argument (type(input) != tuple) corresponding to

      model.forward(input)
      
    2. a tuple of arguments corresponding to

      model.forward(*input)
      
    3. a tuple of arguments such that type(input[-1]) == dict corresponding to

      model.forward(*input[:-1], **input[-1])
      

    Note

    In order to pass a dict as last non-keyword argument, you need to use a tuple as input and add an empty dict as the last element, e.g.,

    input = (x, {"y": y, "z": z}, {})
    

    The empty dict at the end will then be interpreted as the keyword args.

    See the args argument of torch.onnx.export for more info on the format of the return value of collect_func (input).

    The default collect_func assumes that the data loader returns a tuple, e.g., (images, labels, ...), and returns the first element of the tuple.

  • progress_bar (str | None) – 设置为描述字符串以查看进度条。

  • post_process (Callable | None) – 一个可调用对象,它接收模型输出和数据作为输入,并可用于运行任何后处理或操作,例如反向传播。

set_submodule(model, target, target_submodule)

与nn.Module.get_submodule()互补的set函数。

Parameters:
  • 模型 (模块) –

  • target (str) –

  • target_submodule (模块) –

standardize_constructor_args(constructor_args)

标准化一个类似构造函数的元组。

Parameters:

constructor_args (Callable | Tuple) –

Return type:

元组[可调用, 元组, 字典]

standardize_model_args(model_or_fw_or_sig, args, use_kwargs=False)

根据 torch.onnx.export 标准化模型参数。

Parameters:
  • model_or_fw_or_sig (Module | Callable | Signature) – 一个 nn.Module,它的前向方法,或者它的前向方法的签名。

  • args (Any | Tuple) – 请参考 dummy_input 参数在 mtn.profile() 中的描述。

  • use_kwargs – 影响返回值,见下文。对于 use_kwargs==False,返回的 参数也与 torch.onnx.export 兼容。

Returns:

可以在model.forward()中以相同标准化方式使用的标准化模型参数,无论它们是如何提供的,更多信息请参见下文。

Return type:

元组

  • 如果 use_kwargs == False,返回的 args 可以用作

    args = standardize_model_args(model, args, use_kwargs=False)
    model(*args)
    
  • 如果 use_kwargs == True,返回的 args 可以用作

    args = standardize_model_args(model, args, use_kwargs=True)
    model.forward(*args[:-1], **args[-1])
    

警告

如果 use_kwargs == False,模型的 forward() 方法 不能 包含仅关键字参数(例如 forward(..., *, kw_only_args))而没有默认值,并且你不得在 args 中提供它们。

警告

如果 use_kwargs == False,你不得在 args 中提供通过模型 forward() 方法中的可变关键字参数处理的可变关键字参数 (例如 forward(..., **kwargs))。

standardize_model_like_tuple(model)

标准化一个类似模型的元组。

Parameters:

model (Module | Type[Module] | Tuple | Callable) –

Return type:

元组[类型[模块], 元组, 字典]

standardize_named_model_args(model_or_fw_or_sig, args)

根据torch.onnx.export标准化模型参数并为其命名。

Parameters:
  • model_or_fw_or_sig (Module | Callable | Signature) – 一个 nn.Module,它的前向方法,或者它的前向方法的签名。

  • args (Any | Tuple) – 一个包含args/kwargs或torch.Tensor的元组,输入到模型的forward()方法中。

Return type:

元组[字典[字符串, 任意类型], 集合[字符串]]

Returns: A tuple (args_normalized, args_with_default) where
args_normalized is a dictionary of ordered model args where the key represents a unique

基于函数签名中参数名称的序列化字符串,值包含实际参数,

args_with_default is a set indicating whether the argument was retrieved from the default

模型forward()方法函数签名中的值,或者参数是否完全对应于默认值。

注意

更多信息请参见 standardize_model_args()

unwrap_model(model, warn=False, raise_error=False, msg='', force_unwrap=False)

解包由支持的包装模块包装的模型或返回原始模型。

Parameters:
  • 模型 (模块) –

  • 警告 (布尔值) –

  • raise_error (bool) –

  • msg (str) –

  • force_unwrap (bool) –

Return type:

模块

zero_grad(model)

将模型参数中的任何梯度设置为None。

Parameters:

模型 (模块) –

Return type: