损失
不同类型的蒸馏损失。
类
输出logits上的KL散度损失。 |
|
PyTorch 版本的掩码生成蒸馏。 |
- class LogitsDistillationLoss
基础类:
_Loss输出logits上的KL散度损失。
此函数实现了论文中的蒸馏损失:https://arxiv.org/abs/1503.02531。
- __init__(temperature=1.0, reduction='batchmean')
构造函数。
- Parameters:
temperature (float) – 用于在计算损失之前软化logits_t和logits_s的值。
reduction (str) – 在返回之前如何减少最终的点损失。传递
"none"以 之后使用您自己的减少函数,例如使用损失掩码。
- forward(logits_s, logits_t)
在学生和教师的logits上计算KD损失。
- Parameters:
logits_s (Tensor) – 学生的logits,被视为预测。
logits_t (Tensor) – 教师的logits,被视为标签。
- Return type:
张量
注意
假设类别logits维度是最后一个。
- class MGDLoss
基础类:
_LossPyTorch 版本的掩码生成蒸馏。
此函数实现了论文中的蒸馏损失:https://arxiv.org/abs/2205.01529。
- __init__(num_student_channels, num_teacher_channels, alpha_mgd=1.0, lambda_mgd=0.65)
构造函数。
- Parameters:
num_student_channels (int) – 学生特征图中的通道数。
num_teacher_channels (int) – 教师特征图中的通道数。
alpha_mgd (float) – 标量,最终损失乘以该值。默认为1.0。
lambda_mgd (float) – 掩码比率。默认为0.65。
- forward(out_s, out_t)
前向函数。
- Parameters:
out_s (Tensor) – 学生的特征图(形状为 BxCxHxW)。
out_t (Tensor) – 教师的特征图(形状为 BxCxHxW)。