Shortcuts

torch.nn.utils.prune.identity

torch.nn.utils.prune.identity(module, name)[源代码]

应用剪枝重参数化而不剪枝任何单元。

对名为 name 的参数在 module 中对应的张量应用修剪重参数化,但实际上并不修剪任何单元。通过以下方式就地修改模块(并返回修改后的模块):

  1. 添加一个名为 name+'_mask' 的命名缓冲区,对应于修剪方法应用于参数 name 的二进制掩码。

  2. 将参数 name 替换为其修剪后的版本,而原始(未修剪)参数存储在一个名为 name+'_orig' 的新参数中。

注意

掩码是一个由1组成的张量。

Parameters
  • 模块 (nn.Module) – 包含要剪枝的张量的模块。

  • 名称 (字符串) – 在 模块 中进行剪枝操作的参数名称。

Returns

修改后的(即修剪过的)输入模块版本

Return type

模块 (nn.Module)

示例

>>> m = prune.identity(nn.Linear(2, 3), 'bias')
>>> print(m.bias_mask)
张量([1., 1., 1.])
优云智算