!pip install -Uqq nixtla分位数预测
from nixtla.utils import in_colabIN_COLAB = in_colab()if not IN_COLAB:
from nixtla.utils import colab_badge
from dotenv import load_dotenv在预测中,我们常常对预测的分布感兴趣,而不仅仅是一个点预测,因为我们希望对预测的周围不确定性有一个概念。
为此,我们可以创建_分位数预测_。
分位数预测具有直观的解释,因为它们表示预测分布的特定百分位。这使我们能够做出这样的陈述:“我们预计90%的航空乘客观察值会超过100。”这种方法对于在不确定性下进行规划非常有帮助,提供了一系列可能的未来值,并通过考虑潜在结果的全范围来帮助用户做出更明智的决策。
使用TimeGPT,我们可以创建一个预测的分布,并提取特定百分位的分位数预测。例如,第25和第75分位数分别提供了预期结果的下四分位数和上四分位数的洞察,而第50分位数或中位数则提供了一个中心估计。
TimeGPT利用符合预测来生成分位数。
if not IN_COLAB:
load_dotenv()
colab_badge('docs/tutorials/10_uncertainty_quantification_with_quantile_forecasts')1. 导入包
首先,我们导入所需的包并初始化Nixtla客户端
import pandas as pd
from nixtla import NixtlaClient
from IPython.display import displaynixtla_client = NixtlaClient(
# defaults to os.environ.get("NIXTLA_API_KEY")
api_key = 'my_api_key_provided_by_nixtla'
)👍 使用 Azure AI 终端
要使用 Azure AI 终端,请设置
base_url参数:
nixtla_client = NixtlaClient(base_url="你的 Azure AI 终端", api_key="你的 api_key")
if not IN_COLAB:
nixtla_client = NixtlaClient()2. 加载数据
df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv')
df.head()| timestamp | value | |
|---|---|---|
| 0 | 1949-01-01 | 112 |
| 1 | 1949-02-01 | 118 |
| 2 | 1949-03-01 | 132 |
| 3 | 1949-04-01 | 129 |
| 4 | 1949-05-01 | 121 |
3. 使用分位数进行预测
在使用TimeGPT进行时间序列预测时,您可以设置想要预测的分位数。以下是您可以做到的方法:
quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
timegpt_quantile_fcst_df = nixtla_client.forecast(
df=df, h=12,
quantiles=quantiles,
time_col='timestamp', target_col='value',
)
timegpt_quantile_fcst_df.head()INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
| timestamp | TimeGPT | TimeGPT-q-10 | TimeGPT-q-20 | TimeGPT-q-30 | TimeGPT-q-40 | TimeGPT-q-50 | TimeGPT-q-60 | TimeGPT-q-70 | TimeGPT-q-80 | TimeGPT-q-90 | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1961-01-01 | 437.837952 | 431.987091 | 435.043799 | 435.384363 | 436.402155 | 437.837952 | 439.273749 | 440.291541 | 440.632104 | 443.688812 |
| 1 | 1961-02-01 | 426.062744 | 412.704956 | 414.832837 | 416.042432 | 421.719196 | 426.062744 | 430.406293 | 436.083057 | 437.292651 | 439.420532 |
| 2 | 1961-03-01 | 463.116577 | 437.412564 | 444.234985 | 446.420233 | 450.705762 | 463.116577 | 475.527393 | 479.812921 | 481.998169 | 488.820590 |
| 3 | 1961-04-01 | 478.244507 | 448.726837 | 455.428375 | 465.570038 | 469.879114 | 478.244507 | 486.609900 | 490.918976 | 501.060638 | 507.762177 |
| 4 | 1961-05-01 | 505.646484 | 478.409872 | 493.154315 | 497.990848 | 499.138708 | 505.646484 | 512.154260 | 513.302121 | 518.138654 | 532.883096 |
📘 Azure AI 中可用的模型
如果您使用 Azure AI 端点,请确保设置
model="azureai":
nixtla_client.forecast(..., model="azureai")对于公共 API,我们支持两种模型:
timegpt-1和timegpt-1-long-horizon。默认情况下,使用
timegpt-1。有关如何以及何时使用timegpt-1-long-horizon的详细信息,请参见 本教程。
TimeGPT 将以 TimeGPT-q-{int(100 * q)} 的格式返回每个分位数 q 的预测。
nixtla_client.plot(
df, timegpt_quantile_fcst_df,
time_col='timestamp', target_col='value',
)
需要注意的是,量化(或量化值)的选择取决于您的具体用例。对于高风险的预测,您可能会倾向于选择更保守的量化值,例如第10或第20百分位,以确保您为最坏情况做好准备。另一方面,如果您处于一个过度准备成本高的情况,您可能会选择接近中位数的量化值,如第50百分位,以在谨慎和效率之间找到平衡。
例如,如果您在大型促销活动期间为零售业务管理
历史预测
您还可以通过添加add_history=True参数来计算历史预测的分位数预测,如下所示:
timegpt_quantile_fcst_df = nixtla_client.forecast(
df=df, h=12,
quantiles=quantiles,
time_col='timestamp', target_col='value',
add_history=True,
)
timegpt_quantile_fcst_df.head()INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Calling Historical Forecast Endpoint...
| timestamp | TimeGPT | TimeGPT-q-10 | TimeGPT-q-20 | TimeGPT-q-30 | TimeGPT-q-40 | TimeGPT-q-50 | TimeGPT-q-60 | TimeGPT-q-70 | TimeGPT-q-80 | TimeGPT-q-90 | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1951-01-01 | 135.483673 | 111.937768 | 120.020593 | 125.848879 | 130.828935 | 135.483673 | 140.138411 | 145.118467 | 150.946753 | 159.029579 |
| 1 | 1951-02-01 | 144.442398 | 120.896493 | 128.979318 | 134.807604 | 139.787660 | 144.442398 | 149.097136 | 154.077192 | 159.905478 | 167.988304 |
| 2 | 1951-03-01 | 157.191910 | 133.646004 | 141.728830 | 147.557116 | 152.537172 | 157.191910 | 161.846648 | 166.826703 | 172.654990 | 180.737815 |
| 3 | 1951-04-01 | 148.769363 | 125.223458 | 133.306284 | 139.134570 | 144.114625 | 148.769363 | 153.424102 | 158.404157 | 164.232443 | 172.315269 |
| 4 | 1951-05-01 | 140.472946 | 116.927041 | 125.009866 | 130.838152 | 135.818208 | 140.472946 | 145.127684 | 150.107740 | 155.936026 | 164.018852 |
nixtla_client.plot(
df, timegpt_quantile_fcst_df,
time_col='timestamp', target_col='value',
)
交叉验证
quantiles 参数也可以包括在 cross_validation 方法中,从而允许在不同的窗口和不同的分位数之间比较 TimeGPT 的性能。
timegpt_cv_quantile_fcst_df = nixtla_client.cross_validation(
df=df,
h=12,
n_windows=5,
quantiles=quantiles,
time_col='timestamp',
target_col='value',
)
timegpt_quantile_fcst_df.head()INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Restricting input...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
| timestamp | TimeGPT | TimeGPT-q-10 | TimeGPT-q-20 | TimeGPT-q-30 | TimeGPT-q-40 | TimeGPT-q-50 | TimeGPT-q-60 | TimeGPT-q-70 | TimeGPT-q-80 | TimeGPT-q-90 | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1951-01-01 | 135.483673 | 111.937768 | 120.020593 | 125.848879 | 130.828935 | 135.483673 | 140.138411 | 145.118467 | 150.946753 | 159.029579 |
| 1 | 1951-02-01 | 144.442398 | 120.896493 | 128.979318 | 134.807604 | 139.787660 | 144.442398 | 149.097136 | 154.077192 | 159.905478 | 167.988304 |
| 2 | 1951-03-01 | 157.191910 | 133.646004 | 141.728830 | 147.557116 | 152.537172 | 157.191910 | 161.846648 | 166.826703 | 172.654990 | 180.737815 |
| 3 | 1951-04-01 | 148.769363 | 125.223458 | 133.306284 | 139.134570 | 144.114625 | 148.769363 | 153.424102 | 158.404157 | 164.232443 | 172.315269 |
| 4 | 1951-05-01 | 140.472946 | 116.927041 | 125.009866 | 130.838152 | 135.818208 | 140.472946 | 145.127684 | 150.107740 | 155.936026 | 164.018852 |
cutoffs = timegpt_cv_quantile_fcst_df['cutoff'].unique()
for cutoff in cutoffs:
fig = nixtla_client.plot(
df.tail(100),
timegpt_cv_quantile_fcst_df.query('cutoff == @cutoff').drop(columns=['cutoff', 'value']),
time_col='timestamp',
target_col='value'
)
display(fig)




Give us a ⭐ on Github