优化超参数#

pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters(train_dataloaders: DataLoader, val_dataloaders: DataLoader, model_path: str, max_epochs: int = 20, n_trials: int = 100, timeout: float = 28800.0, gradient_clip_val_range: Tuple[float, float] = (0.01, 100.0), hidden_size_range: Tuple[int, int] = (16, 265), hidden_continuous_size_range: Tuple[int, int] = (8, 64), attention_head_size_range: Tuple[int, int] = (1, 4), dropout_range: Tuple[float, float] = (0.1, 0.3), learning_rate_range: Tuple[float, float] = (1e-05, 1.0), use_learning_rate_finder: bool = True, trainer_kwargs: Dict[str, Any] = {}, log_dir: str = 'lightning_logs', study=None, verbose: int | bool = None, pruner=None, **kwargs)[来源]#

优化时间融合变压器超参数。

运行超参数优化。学习率是通过PyTorch Lightning学习率查找器确定的。

Parameters:
  • train_dataloaders (DataLoader) – 用于训练模型的数据加载器

  • val_dataloaders (DataLoader) – 用于验证模型的dataloader

  • model_path (str) – 模型检查点保存的文件夹

  • max_epochs (int, 可选) – 训练的最大轮数。默认为20。

  • n_trials (int, 可选) – 运行的超参数试验次数。默认为100。

  • timeout (float, 可选) – 训练停止的时间(以秒为单位),无论经过多少个epoch或验证指标,都会在此时间后停止。默认值为 3600*8.0。

  • hidden_size_range (元组[整数, 整数], 可选) – hidden_size 超参数的最小值和最大值。默认值为 (16, 265)。

  • hidden_continuous_size_range (元组[int, int], 可选) – hidden_continuous_size 超参数的最小值和最大值。默认为 (8, 64)。

  • attention_head_size_range (Tuple[int, int], optional) – attention_head_size 超参数的最小值和最大值。默认为 (1, 4)。

  • dropout_range (元组[浮点数, 浮点数], 可选) – dropout 超参数的最小值和最大值。默认值为 (0.1, 0.3)。

  • learning_rate_range (元组[浮点数, 浮点数], 可选) – 学习率范围。默认为 (1e-5, 1.0)。

  • use_learning_rate_finder (bool) – 是否使用学习率查找器或作为超参数的一部分进行优化。默认为 True。

  • trainer_kwargs (字典[字符串, 任意], 可选) – 额外的参数给 PyTorch Lightning trainer,比如 limit_train_batches。默认为 {}。

  • log_dir (str, 可选) – 用于记录tensorboard结果的文件夹。默认为“lightning_logs”。

  • 研究 (optuna.Study, 可选) – 要恢复的研究。默认情况下将创建新的研究。

  • verbose (Union[int, bool]) – 详细信息级别。 * None: 详细信息级别没有变化(相当于默认的 verbose=1)。 * 0 或 False: 仅记录警告。 * 1 或 True: 记录修剪事件。 * 2: optuna 日志记录级别为调试级别。 默认为 None。

  • pruner (optuna.pruners.BasePruner, 可选) – 要使用的optuna剪枝器。默认为optuna.pruners.SuccessiveHalvingPruner()。

  • **kwargsTemporalFusionTransformer 的其他参数。

Returns:

optuna 研究结果

Return type:

optuna.Study