交叉熵#
- class pytorch_forecasting.metrics.point.CrossEntropy(reduction: str = 'mean', **kwargs)[来源]#
-
分类的交叉熵损失。
初始化指标
- Parameters:
name (str) – 指标名称。默认为类名称。
分位数 (列表[浮点数], 可选) – 概率范围的分位数。默认为 None。
reduction (str, 可选) – 减少方式,“none”,“mean”或“sqrt-mean”。默认为“mean”。
方法
loss
(y_pred, target)计算未减少的损失。
to_prediction
(y_pred)将网络预测转换为点预测。
to_quantiles
(y_pred[, quantiles])将网络预测转换为分位数预测。
- loss(y_pred, target)[来源]#
计算不进行减小的损失。在派生类中覆盖
- Parameters:
y_pred – 网络输出
y_actual – 实际值
- Returns:
作为反向传播的单个数字的损失/指标
- Return type:
torch.Tensor