蒸馏

用于将模型转换为modelopt.torch.distill.DistillationModel的API,可直接用于训练。

函数

convert

主要转换函数,用于将学生模型转换为适合蒸馏的模型。

export

将蒸馏元模型导出到原始学生模型。

convert(model, mode)

主要转换函数,用于将学生模型转换为适合蒸馏的模型。

Parameters:
  • model (Module) – 用作学生的基础模型。

  • mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) –

    一个(或多个)字符串或模式,或包含模式和其配置的元组列表,用于指示转换过程中所需的模式(和配置)。模式为模型优化设置不同的算法。以下模式可用:

    • "kd_loss": model 将被转换为封装了教师和学生的元模型。该模式的配置在 KDLossConfig 中描述。

    如果模式参数被指定为字典,则键应指示模式,值应指定每个模式的配置。

Returns:

DistillationModel 的一个实例。

Return type:

模块

export(model)

将蒸馏元模型导出到原始学生模型。

Parameters:

model (Module) – 从蒸馏模式中导出仅用于学生的模型。

Returns:

内部学生模型。

Return type:

模块