损失

不同类型的蒸馏损失。

LogitsDistillationLoss

输出logits上的KL散度损失。

MGDLoss

PyTorch 版本的掩码生成蒸馏。

class LogitsDistillationLoss

基础类:_Loss

输出logits上的KL散度损失。

此函数实现了论文中的蒸馏损失:https://arxiv.org/abs/1503.02531

__init__(temperature=1.0, reduction='batchmean')

构造函数。

Parameters:
  • temperature (float) – 用于在计算损失之前软化logits_t和logits_s的值。

  • reduction (str) – 在返回之前如何减少最终的点损失。传递 "none" 以 之后使用您自己的减少函数,例如使用损失掩码。

forward(logits_s, logits_t)

在学生和教师的logits上计算KD损失。

Parameters:
  • logits_s (Tensor) – 学生的logits,被视为预测。

  • logits_t (Tensor) – 教师的logits,被视为标签。

Return type:

张量

注意

假设类别logits维度是最后一个。

class MGDLoss

基础类:_Loss

PyTorch 版本的掩码生成蒸馏。

此函数实现了论文中的蒸馏损失:https://arxiv.org/abs/2205.01529

__init__(num_student_channels, num_teacher_channels, alpha_mgd=1.0, lambda_mgd=0.65)

构造函数。

Parameters:
  • num_student_channels (int) – 学生特征图中的通道数。

  • num_teacher_channels (int) – 教师特征图中的通道数。

  • alpha_mgd (float) – 标量,最终损失乘以该值。默认为1.0。

  • lambda_mgd (float) – 掩码比率。默认为0.65。

forward(out_s, out_t)

前向函数。

Parameters:
  • out_s (Tensor) – 学生的特征图(形状为 BxCxHxW)。

  • out_t (Tensor) – 教师的特征图(形状为 BxCxHxW)。