指标#
- class pytorch_forecasting.metrics.base_metrics.Metric(name: str = None, quantiles: List[float] = None, reduction='mean', **kwargs)[来源]#
基础:
Metric
基本指标类,具有处理预测分位数和在对数空间中操作的基本功能。有关如何实现新指标的细节,请参见Lightning documentation
其他指标应继承自这个基类
初始化指标
- Parameters:
name (str) – 指标名称。默认为类名称。
分位数 (列表[浮点数], 可选) – 概率范围的分位数。默认为 None。
reduction (str, 可选) – 减少方式,“none”,“mean”或“sqrt-mean”。默认为“mean”。
方法
compute
()计算指标的抽象方法
返回模块的额外表示。
rescale_parameters
(parameters, target_scale, ...)将归一化参数缩放到输出所需的尺度。
to_prediction
(y_pred)将网络预测转换为点预测。
to_quantiles
(y_pred[, quantiles])将网络预测转换为分位数预测。
update
(y_pred, y_actual)重写此方法以更新您的指标类的状态变量。
- compute() Tensor [来源]#
计算指标的抽象方法
应该在派生类中被重写
- Parameters:
y_pred – 网络输出
y_actual – 实际值
- Returns:
可以应用反向传播的度量值
- Return type:
torch.Tensor
- rescale_parameters(parameters: Tensor, target_scale: Tensor, encoder: 基础估计器) Tensor [来源]#
将归一化参数缩放到输出所需的尺度。
- Parameters:
参数 (torch.Tensor) – 规范化参数(按最后一个维度索引)
target_scale (torch.Tensor) – 参数的规模 (n_batch_samples x (中心, 缩放))
encoder (BaseEstimator) – 原始编码器,最初对目标进行了标准化
- Returns:
真实/未标准化空间中的参数
- Return type:
torch.Tensor
- to_prediction(y_pred: Tensor) Tensor [来源]#
将网络预测转换为点预测。
- Parameters:
y_pred – 网络的预测输出
- Returns:
点预测
- Return type:
torch.Tensor