蒸馏
用于将模型转换为modelopt.torch.distill.DistillationModel的API,可直接用于训练。
函数
主要转换函数,用于将学生模型转换为适合蒸馏的模型。 |
|
将蒸馏元模型导出到原始学生模型。 |
- convert(model, mode)
主要转换函数,用于将学生模型转换为适合蒸馏的模型。
- Parameters:
model (Module) – 用作学生的基础模型。
mode (_ModeDescriptor | str | List[_ModeDescriptor | str] | List[Tuple[str, Dict[str, Any]]]) –
一个(或多个)字符串或模式,或包含模式和其配置的元组列表,用于指示转换过程中所需的模式(和配置)。模式为模型优化设置不同的算法。以下模式可用:
"kd_loss":model将被转换为封装了教师和学生的元模型。该模式的配置在KDLossConfig中描述。
如果模式参数被指定为字典,则键应指示模式,值应指定每个模式的配置。
- Returns:
DistillationModel的一个实例。- Return type:
模块
- export(model)
将蒸馏元模型导出到原始学生模型。
- Parameters:
model (Module) – 从蒸馏模式中导出仅用于学生的模型。
- Returns:
内部学生模型。
- Return type:
模块