torch.nn.utils.prune.custom_from_mask¶
- torch.nn.utils.prune.custom_from_mask(module, name, mask)[源代码]¶
修剪与
module中名为name的参数对应的张量,通过应用预先计算的掩码mask。就地修改模块(并返回修改后的模块),通过:
添加一个名为
name+'_mask'的命名缓冲区,对应于修剪方法应用于参数name的二进制掩码。将参数
name替换为其修剪后的版本,而原始(未修剪)参数存储在一个名为name+'_orig'的新参数中。
- Parameters
- Returns
修改后的(即修剪过的)输入模块版本
- Return type
模块 (nn.Module)
示例
>>> from torch.nn.utils import prune >>> m = prune.custom_from_mask( ... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0]) ... ) >>> print(m.bias_mask) tensor([0., 1., 0.])