TorchMetricWrapper#
- class pytorch_forecasting.metrics.base_metrics.TorchMetricWrapper(torchmetric: Metric, reduction: str = None, **kwargs)[来源]#
基础:
Metric
包装一个torchmetric以与PyTorch Forecasting一起使用。
不支持误差加权,仅支持点预测的指标。
- Parameters:
torchmetric (LightningMetric) – 用于封装的Torchmetric。
reduction (str, 可选) – 直接与torchmetric一起使用reduction。默认为None。
方法
compute
()计算指标的抽象方法
forward
(y_pred, target, **kwargs)直接聚合和评估批量输入。
persistent
([mode])如果指标状态应该保存到其 state_dict,请更改 post-init。
reset
()重置指标状态变量为其默认值。
update
(y_pred, target, **kwargs)重写此方法以更新您度量类的状态变量。
- compute()[来源]#
计算指标的抽象方法
应该在派生类中被重写
- Parameters:
y_pred – 网络输出
y_actual – 实际值
- Returns:
可以应用反向传播的度量值
- Return type:
torch.Tensor