NBeats#
- class pytorch_forecasting.models.nbeats._nbeats.NBeats(stack_types: List[str] | None = None, num_blocks: List[int] | None = None, num_block_layers: List[int] | None = None, widths: List[int] | None = None, sharing: List[bool] | None = None, expansion_coefficient_lengths: List[int] | None = None, prediction_length: int = 1, context_length: int = 1, dropout: float = 0.1, learning_rate: float = 0.01, log_interval: int = -1, log_gradient_flow: bool = False, log_val_interval: int = None, weight_decay: float = 0.001, loss: 多视角指标 = None, reduce_on_plateau_patience: int = 1000, backcast_loss_ratio: float = 0.0, logging_metrics: ModuleList = None, **kwargs)[来源]#
基础:
BaseModel初始化 NBeats 模型 - 如果可能,请使用它的
from_dataset()方法。基于文章 N-BEATS: Neural basis expansion analysis for interpretable time series forecasting。该网络(如果作为集成使用)在M4竞赛中表现超过了所有其他方法,包括传统统计方法的集成。M4竞赛可以说是单变量时间序列预测最重要的基准。
近年来,
NHiTS网络的表现一直优于 N-BEATS。- Parameters:
stack_types – 以下值之一:“generic”、“seasonality”或“trend”。 一个长度为1或‘num_stacks’的字符串列表。 通用模式的默认和推荐值:[“generic”] 可解释模式的推荐值:[“trend”,“seasonality”]
num_blocks – 每个堆栈中的块数。长度为1或“num_stacks”的整数列表。通用模式的默认和推荐值:[1] 可解释模式的推荐值:[3]
num_block_layers – 每个块的带有ReLu激活的全连接层数。长度为1或‘num_stacks’的整数列表。 通用模式的默认和推荐值:[4] 可解释模式的推荐值:[4]
width – 块中具有ReLu激活的全连接层的宽度。 一个长度为1或“num_stacks”的整数列表。 通用模式的默认和推荐值:[512] 可解释模式的推荐值:[256, 2048]
共享 – 权重是否与每个堆栈的其他块共享。 长度为 1 或 ‘num_stacks’ 的整数列表。 通用模式的默认和推荐值: [False] 可解释模式的推荐值: [True]
expansion_coefficient_length – 如果类型是 “G” (通用),那么扩展系数的长度。 如果类型是 “T” (趋势),那么它对应多项式的度。 如果类型是 “S” (季节性),那么这是允许的最小周期,例如每个时间步变化的最小周期为2。 整数列表长度为1或‘num_stacks’。 通用模式的默认值:[32] 可解释模式的推荐值:[3]
prediction_length – 预测长度。也称为“视野”。
context_length – 条件预测的时间单位数。也称为“回溯期”。应该在预测长度的1-10倍之间。
backcast_loss_ratio – 在计算损失时,回溯与预测的权重比较。权重为1.0意味着回溯和预测损失的权重相同(与回溯和预测的长度无关)。默认为0.0,即无权重。
loss – 要优化的损失。默认为 MASE()。
log_gradient_flow – 如果记录梯度流,这会耗费时间,仅应在诊断训练故障时进行
reduce_on_plateau_patience (int) – 学习率减少10倍后的耐心时间
logging_metrics (nn.ModuleList[MultiHorizonMetric]) – 在训练期间记录的指标列表。默认为 nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
**kwargs – 额外的参数用于
BaseModel。
方法
forward(x)网络的前向传递。
from_dataset(数据集, **kwargs)方便的函数,用于从 :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` 创建网络。
log_interpretation(x, out, batch_idx)在tensorboard中对网络预测的日志进行解释。
plot_interpretation(x, output, idx[, ax, ...])图表解释。
step(x, y, batch_idx)进行训练/验证步骤。
- forward(x: Dict[str, Tensor]) Dict[str, Tensor][来源]#
网络的前向传递。
- Parameters:
x (字典[字符串, torch.Tensor]) – 来自于
TimeSeriesDataSet的数据加载器的输入。- Returns:
模型的输出
- Return type:
字典[str, torch.Tensor]
- classmethod from_dataset(dataset: 时间序列数据集, **kwargs)[来源]#
方便的函数,用于从 :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` 创建网络。
- Parameters:
数据集 (TimeSeriesDataSet) – 唯一预测变量是目标的 数据集。
**kwargs – 传递给
__init__方法的额外参数。
- Returns:
NBeats
- plot_interpretation(x: Dict[str, Tensor], output: Dict[str, Tensor], idx: int, ax=None, plot_seasonality_and_generic_on_secondary_axis: bool = False)[来源]#
图表解释。
绘制两个面板:预测和回溯与实际值的对比,以及将预测分解为趋势、季节性和一般预测。
- Parameters:
x (字典[字符串, torch.Tensor]) – 网络输入
输出 (字典[字符串, torch.Tensor]) – 网络输出
idx (int) – 要绘制解释的样本索引。
ax (列表[matplotlib坐标轴], 可选) – 要绘制解释的两个matplotlib坐标轴的列表。默认为None。
plot_seasonality_and_generic_on_secondary_axis (bool, optional) – 是否在第二个面板的次要轴上绘制季节性和通用预测。默认为 False。
- Returns:
matplotlib 图形
- Return type:
plt.Figure