多损失#

class pytorch_forecasting.metrics.base_metrics.MultiLoss(metrics: List[Metric], weights: List[float] = None)[来源]#

基础: Metric

可以与多个指标一起使用的指标。

Parameters:
  • metrics (列表[LightningMetric], 可选) – 要组合的指标列表。

  • weights (List[float], 可选) – 权重 / 权重的倍数列表。默认对所有指标为 1.0。

方法

compute()

获取指标

forward(y_pred, y_actual, **kwargs)

计算综合指标

persistent([mode])

如果指标状态应该保存到其 state_dict,请更改 post-init。

reset()

重置指标状态变量为其默认值。

to_prediction(y_pred, **kwargs)

将网络预测转换为点预测。

to_quantiles(y_pred, **kwargs)

将网络预测转换为分位数预测。

update(y_pred, y_actual, **kwargs)

更新复合指标

compute() Tensor[来源]#

获取指标

Returns:

指标

Return type:

torch.Tensor

forward(y_pred: Tensor, y_actual: Tensor, **kwargs)[来源]#

计算综合指标

Parameters:
  • y_pred – 网络输出

  • y_actual – 实际值

  • **kwargs – 更新函数的参数

Returns:

可以应用反向传播的度量值

Return type:

torch.Tensor

persistent(mode: bool = False) None[来源]#

如果指标状态应该保存到其 state_dict,请更改 post-init。

reset() None[来源]#

重置指标状态变量为其默认值。

to_prediction(y_pred: Tensor, **kwargs) Tensor[来源]#

将网络预测转换为点预测。

将使用metrics属性中的第一个指标来计算结果。

Parameters:
  • y_pred – 网络的预测输出

  • **kwargs – 用于指标的参数

Returns:

点预测

Return type:

torch.Tensor

to_quantiles(y_pred: Tensor, **kwargs) Tensor[来源]#

将网络预测转换为分位数预测。

将使用metrics属性中的第一个指标来计算结果。

Parameters:
  • y_pred – 网络的预测输出

  • **kwargs – 每个指标的 to_quantiles() 方法的参数

Returns:

预测分位数

Return type:

torch.Tensor

update(y_pred: Tensor, y_actual: Tensor, **kwargs) None[来源]#

更新复合指标

Parameters:
  • y_pred – 网络输出

  • y_actual – 实际值

  • **kwargs – 更新函数的参数