torch.nn.utils.prune 的源代码
r"""剪枝方法。"""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Tuple
import torch
[docs]class BasePruningMethod(ABC):
r"""用于创建新剪枝技术的抽象基类。"""
提供了需要重写的方法框架,例如 :meth:`compute_mask` 和 :meth:`apply`。"""
_tensor_name: str
def __call__(self, module, inputs):
r"""将掩码乘以原始张量并存储结果。"""
将掩码(存储在 ``module[name + '_mask']`` 中)乘以原始张量(存储在 ``module[name + '_orig']`` 中),并将结果存储在 ``module[name]`` 中,通过 :meth:`apply_mask` 实现。"""
参数:
module (nn.Module): 包含要剪枝的张量的模块
inputs: 未使用。"""
setattr(module, self._tensor_name, self.apply_mask(module))
[docs] @abstractmethod
def compute_mask(self, t, default_mask):
r"""计算并返回输入张量 ``t`` 的掩码。"""
从基础 ``default_mask``(如果张量尚未剪枝,则应为全1的掩码)开始,生成一个随机掩码,根据特定的剪枝方法配方将其应用于 ``default_mask`` 之上。"""
参数:
t (torch.Tensor): 表示要剪枝的参数的重要性分数的张量。
default_mask (torch.Tensor): 先前剪枝迭代的基掩码,需要在应用新掩码后保留。与 ``t`` 具有相同的维度。"""
返回:
mask (torch.Tensor): 应用于 ``t`` 的掩码,与 ``t`` 具有相同的维度。"""
pass
[docs] def apply_mask(self, module):
r"""简单处理被剪枝参数与生成的掩码的乘法。"""
从模块中获取掩码和原始张量,并返回剪枝后的张量版本。"""
参数:
module (nn.Module): 包含要剪枝的张量的模块"""
返回:
pruned_tensor (torch.Tensor): 输入张量的剪枝版本"""
# 为了进行乘法,掩码需要已经计算出来,因此剪枝方法必须知道它正在操作的张量
assert self._tensor_name is not None, f"Module {module} 必须被剪枝" # 这会在 apply() 中设置
mask = getattr(module, self._tensor_name + "_mask")
orig = getattr(module, self._tensor_name + "_orig")
pruned_tensor = mask.to(dtype=orig.dtype) * orig
return pruned_tensor
[docs] @classmethod
def apply(cls, module, name, *args, importance_scores=None, **kwargs):
r"""动态添加剪枝并重新参数化张量。"""
添加前向预钩子,以实现动态剪枝和基于原始张量和剪枝掩码的张量重新参数化。"""
参数:
module (nn.Module): 包含要剪枝的张量的模块
name (str): 模块中剪枝将作用的参数名称。
args: 传递给 :class:`BasePruningMethod` 子类的参数
importance_scores (torch.Tensor): 用于计算剪枝掩码的重要性分数张量(与模块参数形状相同)。
如果未指定或为 None,则将使用参数本身。
kwargs: 传递给 :class:`BasePruningMethod` 子类的关键字参数"""
def _get_composite_method(cls, module, name, *args, **kwargs):
# 检查是否已经对 `module[name]` 应用了剪枝方法。如果是,则将其存储在 `old_method` 中。
old_method = None
found = 0
# 技术上应该只有一个钩子,hook.name == name
# 使用 `found` 断言这一点
hooks_to_remove = []
for k, hook in module._forward_pre_hooks.items():
# 如果存在,则获取现有内容,删除钩子,然后继续正常操作
if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
old_method = hook
hooks_to_remove.append(k)
found += 1
assert (
found <= 1
), f"避免向同一个张量 {name} 添加多个剪枝钩子。使用 PruningContainer。"
for k in hooks_to_remove:
del module._forward_pre_hooks[k]
# 应用新的剪枝方法,无论是从头开始还是在之前的方法之上。
method = cls(*args, **kwargs) # 新的剪枝
# 让剪枝方法记住它被应用到的张量名称
method._tensor_name = name
# 如果 `old_method` 存在,则将其与 `method` 结合
if old_method is not None: # 意味着有钩子
# 如果钩子已经是一个剪枝容器,只需将新的剪枝方法添加到容器中
if isinstance(old_method, PruningContainer):
old_method.add_pruning_method(method)
method <span class="o