多损失#
- class pytorch_forecasting.metrics.base_metrics.MultiLoss(metrics: List[Metric], weights: List[float] = None)[来源]#
基础:
Metric
可以与多个指标一起使用的指标。
- Parameters:
metrics (列表[LightningMetric], 可选) – 要组合的指标列表。
weights (List[float], 可选) – 权重 / 权重的倍数列表。默认对所有指标为 1.0。
方法
compute
()获取指标
forward
(y_pred, y_actual, **kwargs)计算综合指标
persistent
([mode])如果指标状态应该保存到其 state_dict,请更改 post-init。
reset
()重置指标状态变量为其默认值。
to_prediction
(y_pred, **kwargs)将网络预测转换为点预测。
to_quantiles
(y_pred, **kwargs)将网络预测转换为分位数预测。
update
(y_pred, y_actual, **kwargs)更新复合指标
- forward(y_pred: Tensor, y_actual: Tensor, **kwargs)[来源]#
计算综合指标
- Parameters:
y_pred – 网络输出
y_actual – 实际值
**kwargs – 更新函数的参数
- Returns:
可以应用反向传播的度量值
- Return type:
torch.Tensor
- to_prediction(y_pred: Tensor, **kwargs) Tensor [来源]#
将网络预测转换为点预测。
将使用
metrics
属性中的第一个指标来计算结果。- Parameters:
y_pred – 网络的预测输出
**kwargs – 用于指标的参数
- Returns:
点预测
- Return type:
torch.Tensor