Shortcuts

torch.nn.utils.prune.custom_from_mask

torch.nn.utils.prune.custom_from_mask(module, name, mask)[源代码]

修剪与module中名为name的参数对应的张量,通过应用预先计算的掩码mask

就地修改模块(并返回修改后的模块),通过:

  1. 添加一个名为 name+'_mask' 的命名缓冲区,对应于修剪方法应用于参数 name 的二进制掩码。

  2. 将参数 name 替换为其修剪后的版本,而原始(未修剪)参数存储在一个名为 name+'_orig' 的新参数中。

Parameters
  • 模块 (nn.Module) – 包含要剪枝的张量的模块

  • 名称 (字符串) – 在 模块 中进行剪枝操作的参数名称。

  • mask (张量) – 应用于参数的二进制掩码。

Returns

修改后的(即修剪过的)输入模块版本

Return type

模块 (nn.Module)

示例

>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
优云智算