TorchMetricWrapper#

class pytorch_forecasting.metrics.base_metrics.TorchMetricWrapper(torchmetric: Metric, reduction: str = None, **kwargs)[来源]#

基础: Metric

包装一个torchmetric以与PyTorch Forecasting一起使用。

不支持误差加权,仅支持点预测的指标。

Parameters:
  • torchmetric (LightningMetric) – 用于封装的Torchmetric。

  • reduction (str, 可选) – 直接与torchmetric一起使用reduction。默认为None。

方法

compute()

计算指标的抽象方法

forward(y_pred, target, **kwargs)

直接聚合和评估批量输入。

persistent([mode])

如果指标状态应该保存到其 state_dict,请更改 post-init。

reset()

重置指标状态变量为其默认值。

update(y_pred, target, **kwargs)

重写此方法以更新您度量类的状态变量。

compute()[来源]#

计算指标的抽象方法

应该在派生类中被重写

Parameters:
  • y_pred – 网络输出

  • y_actual – 实际值

Returns:

可以应用反向传播的度量值

Return type:

torch.Tensor

forward(y_pred, target, **kwargs)[来源]#

直接聚合和评估批量输入。

服务于两个目的:既计算当前批次输入的度量,又将批次统计信息添加到整体累积度量状态中。输入参数与相应的 update 方法完全相同。返回的输出与 compute 的输出完全相同。

Parameters:
  • args – 任何指标 update 方法所需的参数。

  • kwargs – 该指标update方法所需的任何关键字参数。

Returns:

当前批次上评估的compute方法的输出。

Raises:

TorchMetricsUserError – 如果指标已经同步并且再次调用forward

persistent(mode: bool = False) None[来源]#

如果指标状态应该保存到其 state_dict,请更改 post-init。

reset() None[来源]#

重置指标状态变量为其默认值。

update(y_pred: Tensor, target: Tensor, **kwargs) Tensor[来源]#

重写此方法以更新您的指标类的状态变量。