Shortcuts

torch.nn.utils.prune 的源代码

r"""剪枝方法。"""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Tuple

import torch


[docs]class BasePruningMethod(ABC): r"""用于创建新剪枝技术的抽象基类。""" 提供了需要重写的方法框架,例如 :meth:`compute_mask` 和 :meth:`apply`。""" _tensor_name: str def __call__(self, module, inputs): r"""将掩码乘以原始张量并存储结果。""" 将掩码(存储在 ``module[name + '_mask']`` 中)乘以原始张量(存储在 ``module[name + '_orig']`` 中),并将结果存储在 ``module[name]`` 中,通过 :meth:`apply_mask` 实现。""" 参数: module (nn.Module): 包含要剪枝的张量的模块 inputs: 未使用。""" setattr(module, self._tensor_name, self.apply_mask(module))
[docs] @abstractmethod def compute_mask(self, t, default_mask): r"""计算并返回输入张量 ``t`` 的掩码。""" 从基础 ``default_mask``(如果张量尚未剪枝,则应为全1的掩码)开始,生成一个随机掩码,根据特定的剪枝方法配方将其应用于 ``default_mask`` 之上。""" 参数: t (torch.Tensor): 表示要剪枝的参数的重要性分数的张量。 default_mask (torch.Tensor): 先前剪枝迭代的基掩码,需要在应用新掩码后保留。与 ``t`` 具有相同的维度。""" 返回: mask (torch.Tensor): 应用于 ``t`` 的掩码,与 ``t`` 具有相同的维度。""" pass
[docs] def apply_mask(self, module): r"""简单处理被剪枝参数与生成的掩码的乘法。""" 从模块中获取掩码和原始张量,并返回剪枝后的张量版本。""" 参数: module (nn.Module): 包含要剪枝的张量的模块""" 返回: pruned_tensor (torch.Tensor): 输入张量的剪枝版本""" # 为了进行乘法,掩码需要已经计算出来,因此剪枝方法必须知道它正在操作的张量 assert self._tensor_name is not None, f"Module {module} 必须被剪枝" # 这会在 apply() 中设置 mask = getattr(module, self._tensor_name + "_mask") orig = getattr(module, self._tensor_name + "_orig") pruned_tensor = mask.to(dtype=orig.dtype) * orig return pruned_tensor
[docs] @classmethod def apply(cls, module, name, *args, importance_scores=None, **kwargs): r"""动态添加剪枝并重新参数化张量。""" 添加前向预钩子,以实现动态剪枝和基于原始张量和剪枝掩码的张量重新参数化。""" 参数: module (nn.Module): 包含要剪枝的张量的模块 name (str): 模块中剪枝将作用的参数名称。 args: 传递给 :class:`BasePruningMethod` 子类的参数 importance_scores (torch.Tensor): 用于计算剪枝掩码的重要性分数张量(与模块参数形状相同)。 如果未指定或为 None,则将使用参数本身。 kwargs: 传递给 :class:`BasePruningMethod` 子类的关键字参数""" def _get_composite_method(cls, module, name, *args, **kwargs): # 检查是否已经对 `module[name]` 应用了剪枝方法。如果是,则将其存储在 `old_method` 中。 old_method = None found = 0 # 技术上应该只有一个钩子,hook.name == name # 使用 `found` 断言这一点 hooks_to_remove = [] for k, hook in module._forward_pre_hooks.items(): # 如果存在,则获取现有内容,删除钩子,然后继续正常操作 if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: old_method = hook hooks_to_remove.append(k) found += 1 assert ( found <= 1 ), f"避免向同一个张量 {name} 添加多个剪枝钩子。使用 PruningContainer。" for k in hooks_to_remove: del module._forward_pre_hooks[k] # 应用新的剪枝方法,无论是从头开始还是在之前的方法之上。 method = cls(*args, **kwargs) # 新的剪枝 # 让剪枝方法记住它被应用到的张量名称 method._tensor_name = name # 如果 `old_method` 存在,则将其与 `method` 结合 if old_method is not None: # 意味着有钩子 # 如果钩子已经是一个剪枝容器,只需将新的剪枝方法添加到容器中 if isinstance(old_method, PruningContainer): old_method.add_pruning_method(method) method <span class="o
优云智算