if not IN_COLAB: load_dotenv() colab_badge('docs/tutorials/06_finetuning')微调
::: {#02134a5e .cell 0=‘隐’ 1=‘藏’}
!pip install -Uqq nixtla:::
::: {#c6d8f223 .cell 0=‘隐’ 1=‘藏’}
from nixtla.utils import in_colab:::
::: {#3c6c0333 .cell 0=‘隐’ 1=‘藏’}
IN_COLAB = in_colab():::
::: {#ce98fab5 .cell 0=‘隐’ 1=‘藏’}
if not IN_COLAB: from nixtla.utils import colab_badge from dotenv import load_dotenv:::
微调是更加有效利用 TimeGPT 的一种强大过程。像 TimeGPT 这样的基础模型是在大量数据上预训练的,捕捉了广泛的特征和模式。这些模型可以针对特定的上下文或领域进行专业化。通过微调,模型的参数被优化以预测新的任务,使其能够将先前积累的丰富知识调整到新数据的需求上。因此,微调作为一个关键的桥梁,将 TimeGPT 的广泛能力与您的任务特性连接起来。具体来说,微调的过程包括在您的输入数据上进行一定次数的训练迭代,以最小化预测误差。然后将使用更新后的模型来生成预测结果。要控制迭代次数,请使用 forecast 方法的 finetune_steps 参数。
1. 导入包首先,我们导入所需的包并初始化Nixtla客户端,
import pandas as pdfrom nixtla import NixtlaClientfrom utilsforecast.losses import mae, msefrom utilsforecast.evaluation import evaluatenixtla_client = NixtlaClient( # defaults to os.environ.get("NIXTLA_API_KEY") api_key = 'my_api_key_provided_by_nixtla')👍 使用Azure AI端点> > 要使用Azure AI端点,请记得还要设置
base_url参数:> >nixtla_client = NixtlaClient(base_url="你的Azure AI端点", api_key="你的api_key")
::: {#5cd61549-0b00-4a42-a98e-239fa4fae5e5 .cell 0=‘隐’ 1=‘藏’}
if not IN_COLAB: nixtla_client = NixtlaClient():::
2. 加载数据
df = pd.read_csv('https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv')df.head()| timestamp | value | |
|---|---|---|
| 0 | 1949-01-01 | 112 |
| 1 | 1949-02-01 | 118 |
| 2 | 1949-03-01 | 132 |
| 3 | 1949-04-01 | 129 |
| 4 | 1949-05-01 | 121 |
3. 微调
在这里,finetune_steps=10 意味着模型将对您的时间序列数据进行10次训练迭代。
timegpt_fcst_finetune_df = nixtla_client.forecast( df=df, h=12, finetune_steps=10, time_col='timestamp', target_col='value',)INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
INFO:nixtla.nixtla_client:Querying model metadata...
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
📘 Azure AI中的可用模型>> 如果您使用的是Azure AI端点,请确保设置
model="azureai":>>nixtla_client.forecast(..., model="azureai")>> 对于公共API,我们支持两个模型:timegpt-1和timegpt-1-long-horizon。> > 默认情况下,使用timegpt-1。请参阅此教程了解如何以及何时使用timegpt-1-long-horizon。
nixtla_client.plot( df, timegpt_fcst_finetune_df, time_col='timestamp', target_col='value',)
3.1 使用 finetune_depth 控制微调的深度还可以通过 finetune_depth 参数控制微调的深度。finetune_depth 的取值范围为 [1, 2, 3, 4, 5]。默认值为 1,这意味着只调整模型参数的一小部分,而值为 5 则可以微调最多的参数。增加 finetune_depth 还会增加生成预测所需的时间。
让我们进行一个小实验,看看 finetune_depth 如何影响性能。
train = df[:-24]test = df[-24:]depths = [1, 2, 3, 4, 5]test = test.copy()for depth in depths: preds_df = nixtla_client.forecast( df=train, h=24, finetune_steps=5, finetune_depth=depth, time_col='timestamp', target_col='value') preds = preds_df['TimeGPT'].values test.loc[:,f'TimeGPT_depth{depth}'] = predsINFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
INFO:nixtla.nixtla_client:Validating inputs...
INFO:nixtla.nixtla_client:Inferred freq: MS
WARNING:nixtla.nixtla_client:The specified horizon "h" exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.
INFO:nixtla.nixtla_client:Preprocessing dataframes...
INFO:nixtla.nixtla_client:Calling Forecast Endpoint...
test['unique_id'] = 0evaluation = evaluate(test, metrics=[mae, mse], time_col="timestamp", target_col="value")evaluation| unique_id | metric | TimeGPT_depth1 | TimeGPT_depth2 | TimeGPT_depth3 | TimeGPT_depth4 | TimeGPT_depth5 | |
|---|---|---|---|---|---|---|---|
| 0 | 0 | mae | 22.805146 | 17.929682 | 21.320125 | 24.944233 | 28.735563 |
| 1 | 0 | mse | 683.303778 | 462.133945 | 678.182747 | 1003.023709 | 1119.906759 |
正如您所看到的,增加微调的深度可以提高模型的性能,但也可能由于过拟合而使性能下降。因此,请记住,微调可能需要一些尝试和错误。您可能需要根据您的具体需求和数据的复杂性调整 finetune_steps 的数量和 finetune_depth 的水平。通常,对于大型数据集,更高的 finetune_depth 效果更好。在本教程中,由于我们是在一个非常短的数据集上预测单个序列,因此增加深度导致了过拟合。建议在微调过程中监控模型的性能并根据需要进行调整。请注意,更多的 finetune_steps 和更大的 finetune_depth 值可能会导致更长的训练时间,如果没有适当管理,可能会导致过拟合。请记住,微调是一个强大的功能,但应谨慎和细心地使用。
有关使用特定损失函数进行微调的详细指南,请查看使用特定损失函数进行微调教程。
Give us a ⭐ on Github