Quantile forecasts

!pip install -Uqq nixtla
from nixtla.utils import in_colab
IN_COLAB = in_colab()
if not IN_COLAB:
    from nixtla.utils import colab_badge
    from dotenv import load_dotenv

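In forecasting, we are often interested in the distribution of predictions rather than a single point forecast, because we want a sense of the uncertainty around the forecast.

To this end, we can create _quantile forecasts_.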

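Quantile forecasts have an intuitive interpretation, since each one represents a specific percentile of the forecast distribution. This lets us make statements such as "we expect 90% of our air passenger observations to be above 100." This is useful for planning under uncertainty: it provides a range of possible future values and helps users make better-informed decisions by considering the full spectrum of potential outcomes.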
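To make the percentile reading concrete, here is a minimal sketch that uses only the Python standard library. The simulated values are a toy stand-in for a forecast distribution (the mean and spread are arbitrary assumptions, not TimeGPT output); the point is only that roughly 90% of the values lie above the 10th percentile.

```python
import random
import statistics

# Toy stand-in for a forecast distribution: 1,000 simulated passenger counts.
# The mean (120) and spread (15) are arbitrary illustration values.
rng = random.Random(0)
samples = [rng.gauss(120, 15) for _ in range(1000)]

# With n=10, statistics.quantiles returns the 9 cut points for the
# 10th, 20th, ..., 90th percentiles.
deciles = statistics.quantiles(samples, n=10)
q10, q50, q90 = deciles[0], deciles[4], deciles[8]

# Share of simulated values that exceed the 10th percentile.
share_above_q10 = sum(x > q10 for x in samples) / len(samples)

print(f"10th percentile: {q10:.1f}")
print(f"median: {q50:.1f}")
print(f"90th percentile: {q90:.1f}")
print(f"share of values above the 10th percentile: {share_above_q10:.2f}")
```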

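With TimeGPT, we can create a distribution of forecasts and extract the quantile forecasts for specified percentiles. For example, the 25th and 75th quantiles give insight into the lower and upper quartiles of expected outcomes, respectively, while the 50th quantile, or median, provides a central estimate.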

TimeGPT uses conformal prediction to produce the quantiles.
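For intuition only, the sketch below shows one simple form of conformal prediction: absolute errors on past data are turned into offsets around a point forecast. This is an illustrative assumption, not a description of TimeGPT's internals; the helper names and the toy numbers are hypothetical.

```python
import math

def conformal_offset(abs_residuals, level):
    """Finite-sample conformal quantile of absolute residuals for a coverage level."""
    if level <= 0:
        return 0.0
    ordered = sorted(abs_residuals)
    n = len(ordered)
    # 1-based rank ceil((n + 1) * level), clipped to the largest residual.
    rank = min(math.ceil((n + 1) * level), n)
    return ordered[rank - 1]

def quantile_forecast(point, abs_residuals, q):
    """Turn a point forecast into a quantile forecast using symmetric offsets."""
    # Quantiles q and 1 - q bound a central interval with coverage |2q - 1|.
    level = round(abs(2 * q - 1), 10)
    offset = conformal_offset(abs_residuals, level)
    return point + offset if q >= 0.5 else point - offset

# Toy calibration data (hypothetical): past actuals and past point forecasts.
past_actuals = [101, 98, 103, 96, 105, 94, 107, 92, 109]
past_forecasts = [100] * 9
abs_residuals = [abs(a - f) for a, f in zip(past_actuals, past_forecasts)]

new_point_forecast = 100.0
q10 = quantile_forecast(new_point_forecast, abs_residuals, 0.1)
q50 = quantile_forecast(new_point_forecast, abs_residuals, 0.5)
q90 = quantile_forecast(new_point_forecast, abs_residuals, 0.9)
print(q10, q50, q90)
```

With this construction the median coincides with the point forecast and the other quantiles are symmetric around it.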

if not IN_COLAB:
    load_dotenv()    
    colab_badge('docs/tutorials/10_uncertainty_quantification_with_quantile_forecasts')

1. Import packages

First, we import the required packages and initialize the Nixtla client.

import pandas as pd
from nixtla import NixtlaClient

from IPython.display import display
nixtla_client = NixtlaClient(
    # defaults to os.environ.get("NIXTLA_API_KEY")
    api_key = 'my_api_key_provided_by_nixtla'
)

👍 Use an Azure AI endpoint

To use an Azure AI endpoint, set the base_url argument:

nixtla_client = NixtlaClient(base_url="your azure ai endpoint", api_key="your api_key")

if not IN_COLAB:
    nixtla_client = NixtlaClient()

2. Load data

df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv')
df.head()
timestamp value
0 1949-01-01 112
1 1949-02-01 118
2 1949-03-01 132
3 1949-04-01 129
4 1949-05-01 121

3. Forecast with quantiles

When forecasting a time series with TimeGPT, you can set the quantiles you want to predict. Here is how to do it:

quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
timegpt_quantile_fcst_df = nixtla_client.forecast(
    df=df, h=12, 
    quantiles=quantiles, 
    time_col='timestamp', target_col='value',
)
timegpt_quantile_fcst_df.head()
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
timestamp TimeGPT TimeGPT-q-10 TimeGPT-q-20 TimeGPT-q-30 TimeGPT-q-40 TimeGPT-q-50 TimeGPT-q-60 TimeGPT-q-70 TimeGPT-q-80 TimeGPT-q-90
0 1961-01-01 437.837952 431.987091 435.043799 435.384363 436.402155 437.837952 439.273749 440.291541 440.632104 443.688812
1 1961-02-01 426.062744 412.704956 414.832837 416.042432 421.719196 426.062744 430.406293 436.083057 437.292651 439.420532
2 1961-03-01 463.116577 437.412564 444.234985 446.420233 450.705762 463.116577 475.527393 479.812921 481.998169 488.820590
3 1961-04-01 478.244507 448.726837 455.428375 465.570038 469.879114 478.244507 486.609900 490.918976 501.060638 507.762177
4 1961-05-01 505.646484 478.409872 493.154315 497.990848 499.138708 505.646484 512.154260 513.302121 518.138654 532.883096

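📘 Available models in Azure AI

If you are using an Azure AI endpoint, make sure to set model="azureai":

nixtla_client.forecast(..., model="azureai")

For the public API, we support two models: timegpt-1 and timegpt-1-long-horizon.

By default, timegpt-1 is used. See this tutorial for details on how and when to use timegpt-1-long-horizon.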

TimeGPT returns the forecast for each quantile q in a column named TimeGPT-q-{int(100 * q)}.
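The naming rule above can be reproduced in plain Python. The helper name `quantile_column` is hypothetical and exists only for illustration.

```python
quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

def quantile_column(q: float) -> str:
    """Build the column name for quantile q using the format stated above."""
    return f"TimeGPT-q-{int(100 * q)}"

quantile_columns = [quantile_column(q) for q in quantiles]
print(quantile_columns)
```

In practice, reading the names from the returned DataFrame's columns is more robust than rebuilding them by hand.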

nixtla_client.plot(
    df, timegpt_quantile_fcst_df, 
    time_col='timestamp', target_col='value',
)

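Note that the choice of quantile (or quantiles) depends on your specific use case. For high-stakes predictions, you might lean towards more conservative quantiles, such as the 10th or 20th percentile, to make sure you are prepared for worst-case scenarios. On the other hand, if the cost of over-preparing is high, you might choose a quantile closer to the median, such as the 50th percentile, to balance caution and efficiency.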
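As a rough sketch of how such a choice might look in code, the snippet below rebuilds the first two forecast rows shown earlier (three of the quantile columns only) and derives two planning quantities: a conservative value taken from the 10th percentile, and the width of the central 80% interval between the 10th and 90th percentiles. The names `conservative_plan` and `interval_80_width` are hypothetical.

```python
import pandas as pd

# First two forecast rows from the output above, restricted to a few columns.
fcst = pd.DataFrame(
    {
        "timestamp": ["1961-01-01", "1961-02-01"],
        "TimeGPT": [437.837952, 426.062744],
        "TimeGPT-q-10": [431.987091, 412.704956],
        "TimeGPT-q-50": [437.837952, 426.062744],
        "TimeGPT-q-90": [443.688812, 439.420532],
    }
)

# Conservative planning value: the 10th percentile forecast.
conservative_plan = fcst["TimeGPT-q-10"]

# Width of the central 80% interval (between the 10th and 90th percentiles).
fcst["interval_80_width"] = fcst["TimeGPT-q-90"] - fcst["TimeGPT-q-10"]

print(fcst[["timestamp", "TimeGPT-q-10", "TimeGPT-q-90", "interval_80_width"]])
```

A wider interval signals more uncertainty for that month, which may justify a larger safety margin.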

For instance, if you are managing a retail business during a big sale event

Historical forecast

You can also compute quantile forecasts for historical predictions by adding the add_history=True parameter, as follows:

timegpt_quantile_fcst_df = nixtla_client.forecast(
    df=df, h=12, 
    quantiles=quantiles, 
    time_col='timestamp', target_col='value',
    add_history=True,
)
timegpt_quantile_fcst_df.head()
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Calling Historical Forecast Endpoint...
timestamp TimeGPT TimeGPT-q-10 TimeGPT-q-20 TimeGPT-q-30 TimeGPT-q-40 TimeGPT-q-50 TimeGPT-q-60 TimeGPT-q-70 TimeGPT-q-80 TimeGPT-q-90
0 1951-01-01 135.483673 111.937768 120.020593 125.848879 130.828935 135.483673 140.138411 145.118467 150.946753 159.029579
1 1951-02-01 144.442398 120.896493 128.979318 134.807604 139.787660 144.442398 149.097136 154.077192 159.905478 167.988304
2 1951-03-01 157.191910 133.646004 141.728830 147.557116 152.537172 157.191910 161.846648 166.826703 172.654990 180.737815
3 1951-04-01 148.769363 125.223458 133.306284 139.134570 144.114625 148.769363 153.424102 158.404157 164.232443 172.315269
4 1951-05-01 140.472946 116.927041 125.009866 130.838152 135.818208 140.472946 145.127684 150.107740 155.936026 164.018852
nixtla_client.plot(
    df, timegpt_quantile_fcst_df, 
    time_col='timestamp', target_col='value',
)
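One possible use of historical quantile forecasts is a calibration check: for each quantile, count how often the actual value fell at or below the forecast. For a well-calibrated 10th percentile this share should be close to 0.1. The sketch below uses small made-up frames (hypothetical values, not the air passengers data) with the same column layout as the outputs above.

```python
import pandas as pd

# Toy actuals (hypothetical values) in the same layout as df.
actuals = pd.DataFrame(
    {
        "timestamp": ["2000-01-01", "2000-02-01", "2000-03-01", "2000-04-01"],
        "value": [100, 110, 120, 130],
    }
)

# Toy historical quantile forecasts (hypothetical values).
hist_fcst = pd.DataFrame(
    {
        "timestamp": ["2000-01-01", "2000-02-01", "2000-03-01", "2000-04-01"],
        "TimeGPT-q-10": [95, 112, 110, 120],
        "TimeGPT-q-90": [110, 125, 125, 128],
    }
)

# Align forecasts with actuals; both timestamp columns must share a dtype.
merged = actuals.merge(hist_fcst, on="timestamp", how="inner")

# Share of actuals at or below each quantile forecast.
quantile_cols = [c for c in merged.columns if c.startswith("TimeGPT-q-")]
coverage = {
    col: float((merged["value"] <= merged[col]).mean()) for col in quantile_cols
}
print(coverage)
```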

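Cross validation

The quantiles argument can also be passed to the cross_validation method, which lets you compare TimeGPT's performance across different windows and different quantiles.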

timegpt_cv_quantile_fcst_df = nixtla_client.cross_validation(
    df=df, 
    h=12, 
    n_windows=5,
    quantiles=quantiles, 
    time_col='timestamp', 
    target_col='value',
)
timegpt_cv_quantile_fcst_df.head()
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
cutoffs = timegpt_cv_quantile_fcst_df['cutoff'].unique()
for cutoff in cutoffs:
    fig = nixtla_client.plot(
        df.tail(100), 
        timegpt_cv_quantile_fcst_df.query('cutoff == @cutoff').drop(columns=['cutoff', 'value']),
        time_col='timestamp', 
        target_col='value'
    )
    display(fig)
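To compare windows and quantiles numerically rather than visually, one common choice is the pinball (quantile) loss. The sketch below is an assumption about how such a comparison could be set up, not part of the Nixtla API. It uses a made-up frame (hypothetical values) that mimics the cross-validation output, with cutoff, value, and TimeGPT-q-* columns.

```python
import numpy as np
import pandas as pd

def pinball_loss(y, y_hat, q):
    """Mean pinball loss of quantile forecasts y_hat for quantile level q."""
    diff = y - y_hat
    return float(np.maximum(q * diff, (q - 1) * diff).mean())

# Toy cross-validation output (hypothetical values), two windows of two rows.
cv_df = pd.DataFrame(
    {
        "cutoff": ["1959-12-01", "1959-12-01", "1960-06-01", "1960-06-01"],
        "value": [100.0, 110.0, 120.0, 130.0],
        "TimeGPT-q-10": [90.0, 115.0, 118.0, 131.0],
        "TimeGPT-q-90": [120.0, 112.0, 125.0, 140.0],
    }
)

quantile_cols = [c for c in cv_df.columns if c.startswith("TimeGPT-q-")]

# One score per (cutoff, quantile) pair; lower is better.
rows = []
for cutoff, window in cv_df.groupby("cutoff"):
    for col in quantile_cols:
        q = int(col.rsplit("-", 1)[1]) / 100  # "TimeGPT-q-10" -> 0.1
        rows.append(
            {
                "cutoff": cutoff,
                "quantile": q,
                "pinball_loss": pinball_loss(window["value"], window[col], q),
            }
        )
scores = pd.DataFrame(rows)
print(scores)
```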
