蒸馏模型
元模型包装器,用于支持知识蒸馏学习。
类
将多个教师和学生模型封装为一个单一模型的类。 |
- class DistillationModel
-
将多个教师和学生模型封装为单个模型的类。
- compute_kd_loss(student_loss=None, loss_reduction_fn=None, skip_balancer=False)
计算蒸馏反向传播的总损失。
- Parameters:
student_loss (Tensor | None) – 从学生输出计算出的原始损失。
loss_reduction_fn (Callable) – 在每个损失张量平衡之前调用的可调用对象。在每次迭代中可调用对象改变参数的情况下,这对于损失屏蔽情况非常有用。
skip_balancer (bool) – 是否使用损失平衡器将损失字典减少为标量。
- Returns:
如果 reduce 为 True,则在
student_loss和蒸馏损失之间加权的标量总损失。 如果 reduce 为 False,则返回学生模型输出损失和逐层蒸馏损失的字典。- Return type:
张量 | 字典[字符串, 张量]
- forward(*args, **kwargs)
实现前向传播。
- Parameters:
*args – 学生和教师模型的位置输入。
**kwargs – 学生和教师模型的命名输入。
- Returns:
学生模型的输出。
- Return type:
任何
- hide_loss_modules(enable=True)
上下文管理器,用于暂时从模型中隐藏教师模型。
- hide_teacher_model(enable=True)
上下文管理器,用于暂时从模型中隐藏教师模型。
- load_state_dict(state_dict, *args, **kwargs)
重写以可能加载没有教师或损失模块的状态。
- Return type:
任何
- property loss_balancer: DistillationLossBalancer | None
获取损失平衡器(如果有的话)。
- property loss_modules: ModuleList
获取损失模块列表。
- modify(teacher_model, criterion, loss_balancer=None, expose_minimal_state_dict=True)
构造函数。
- Parameters:
teacher_model (Module) – 一个教师模型,该类将封装此模型。
criterion (Dict[Tuple[str, str], _Loss]) – 一个字典,将学生和教师模型层名称的元组映射到应用于该层对的损失函数。
loss_balancer (DistillationLossBalancer | None) –
DistillationLossBalancer的实例,它使用某种加权方案将蒸馏和非蒸馏损失减少为单个值。expose_minimal_state_dict (bool) – 如果为True,当在此类上调用
state_dict时,将隐藏教师的状态字典。这可以避免在检查点期间不必要地保存教师状态。 .. 注意: 如果使用FSDP,请设置为False
- only_student_forward(enable=True)
上下文管理器,用于暂时禁用学生模型的前向传递。
- only_teacher_forward(enable=True)
上下文管理器,用于暂时禁用学生模型的前向传递。
- state_dict(*args, **kwargs)
重写以可能返回没有教师的状态。
- Return type:
Dict[str, Any]
- property teacher_model: ModuleList
获取教师模型。
- train(mode=True)
重写以防止在未来的前向传播中存储中间输出的警告。
- Parameters:
mode (bool) –