微调

::: {#02134a5e .cell 0=‘隐’ 1=‘藏’}

!pip install -Uqq nixtla

:::

::: {#c6d8f223 .cell 0=‘隐’ 1=‘藏’}

from nixtla.utils import in_colab

:::

::: {#3c6c0333 .cell 0=‘隐’ 1=‘藏’}

IN_COLAB = in_colab()

:::

::: {#ce98fab5 .cell 0=‘隐’ 1=‘藏’}

if not IN_COLAB:    from nixtla.utils import colab_badge    from dotenv import load_dotenv

:::

微调是更加有效利用 TimeGPT 的一种强大过程。像 TimeGPT 这样的基础模型是在大量数据上预训练的,捕捉了广泛的特征和模式。这些模型可以针对特定的上下文或领域进行专业化。通过微调,模型的参数被优化以预测新的任务,使其能够将先前积累的丰富知识调整到新数据的需求上。因此,微调作为一个关键的桥梁,将 TimeGPT 的广泛能力与您的任务特性连接起来。具体来说,微调的过程包括在您的输入数据上进行一定次数的训练迭代,以最小化预测误差。然后将使用更新后的模型来生成预测结果。要控制迭代次数,请使用 forecast 方法的 finetune_steps 参数。

if not IN_COLAB:    load_dotenv()        colab_badge('docs/tutorials/06_finetuning')

1. 导入包首先,我们导入所需的包并初始化Nixtla客户端,

import pandas as pdfrom nixtla import NixtlaClientfrom utilsforecast.losses import mae, msefrom utilsforecast.evaluation import evaluate
nixtla_client = NixtlaClient(    # defaults to os.environ.get("NIXTLA_API_KEY")    api_key = 'my_api_key_provided_by_nixtla')

👍 使用Azure AI端点> > 要使用Azure AI端点,请记得还要设置base_url参数:> > nixtla_client = NixtlaClient(base_url="你的Azure AI端点", api_key="你的api_key")

::: {#5cd61549-0b00-4a42-a98e-239fa4fae5e5 .cell 0=‘隐’ 1=‘藏’}

if not IN_COLAB:    nixtla_client = NixtlaClient()

:::

2. 加载数据

df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv')df.head()
timestamp value
0 1949-01-01 112
1 1949-02-01 118
2 1949-03-01 132
3 1949-04-01 129
4 1949-05-01 121

3. 微调

在这里,finetune_steps=10 意味着模型将对您的时间序列数据进行10次训练迭代。

timegpt_fcst_finetune_df = nixtla_client.forecast(    df=df, h=12, finetune_steps=10,    time_col='timestamp', target_col='value',)
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Querying model metadata...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...

📘 Azure AI中的可用模型>> 如果您使用的是Azure AI端点,请确保设置model="azureai":>> nixtla_client.forecast(..., model="azureai")>> 对于公共API,我们支持两个模型:timegpt-1timegpt-1-long-horizon。> > 默认情况下,使用timegpt-1。请参阅此教程了解如何以及何时使用timegpt-1-long-horizon

nixtla_client.plot(    df, timegpt_fcst_finetune_df,     time_col='timestamp', target_col='value',)

3.1 使用 finetune_depth 控制微调的深度还可以通过 finetune_depth 参数控制微调的深度。finetune_depth 的取值范围为 [1, 2, 3, 4, 5]。默认值为 1,这意味着只调整模型参数的一小部分,而值为 5 则可以微调最多的参数。增加 finetune_depth 还会增加生成预测所需的时间。

让我们进行一个小实验,看看 finetune_depth 如何影响性能。

train = df[:-24]test = df[-24:]
depths = [1, 2, 3, 4, 5]test = test.copy()for depth in depths:    preds_df = nixtla_client.forecast(    df=train,     h=24,     finetune_steps=5,    finetune_depth=depth,    time_col='timestamp',     target_col='value')    preds = preds_df['TimeGPT'].values    test.loc[:,f'TimeGPT_depth{depth}'] = preds
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
test['unique_id'] = 0evaluation = evaluate(test, metrics=[mae, mse], time_col="timestamp", target_col="value")evaluation
unique_id metric TimeGPT_depth1 TimeGPT_depth2 TimeGPT_depth3 TimeGPT_depth4 TimeGPT_depth5
0 0 mae 22.805146 17.929682 21.320125 24.944233 28.735563
1 0 mse 683.303778 462.133945 678.182747 1003.023709 1119.906759

正如您所看到的,增加微调的深度可以提高模型的性能,但也可能由于过拟合而使性能下降。因此,请记住,微调可能需要一些尝试和错误。您可能需要根据您的具体需求和数据的复杂性调整 finetune_steps 的数量和 finetune_depth 的水平。通常,对于大型数据集,更高的 finetune_depth 效果更好。在本教程中,由于我们是在一个非常短的数据集上预测单个序列,因此增加深度导致了过拟合。建议在微调过程中监控模型的性能并根据需要进行调整。请注意,更多的 finetune_steps 和更大的 finetune_depth 值可能会导致更长的训练时间,如果没有适当管理,可能会导致过拟合。请记住,微调是一个强大的功能,但应谨慎和细心地使用。

有关使用特定损失函数进行微调的详细指南,请查看使用特定损失函数进行微调教程。

Give us a ⭐ on Github