分位数损失#

class pytorch_forecasting.metrics.quantile.QuantileLoss(quantiles: List[float] | None = None, **kwargs)[来源]#

基类: MultiHorizonMetric

分位损失,即q=0.5的分位数将给出一半的平均绝对误差,因为它是按照以下方式计算的

定义为 max(q * (y-y_pred), (1-q) * (y_pred-y))

分位数损失

Parameters:

分位数 – 指标的分位数

方法

loss(y_pred, target)

计算未减少的损失。

to_prediction(y_pred)

将网络预测转换为点预测。

to_quantiles(y_pred)

将网络预测转换为分位数预测。

loss(y_pred: Tensor, target: Tensor) Tensor[来源]#

计算不进行减小的损失。在派生类中覆盖

Parameters:
  • y_pred – 网络输出

  • y_actual – 实际值

Returns:

作为反向传播的单个数字的损失/指标

Return type:

torch.Tensor

to_prediction(y_pred: Tensor) Tensor[来源]#

将网络预测转换为点预测。

Parameters:

y_pred – 网络的预测输出

Returns:

点预测

Return type:

torch.Tensor

to_quantiles(y_pred: Tensor) Tensor[来源]#

将网络预测转换为分位数预测。

Parameters:

y_pred – 网络的预测输出

Returns:

预测分位数

Return type:

torch.Tensor