优化超参数#
- pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters(train_dataloaders: DataLoader, val_dataloaders: DataLoader, model_path: str, max_epochs: int = 20, n_trials: int = 100, timeout: float = 28800.0, gradient_clip_val_range: Tuple[float, float] = (0.01, 100.0), hidden_size_range: Tuple[int, int] = (16, 265), hidden_continuous_size_range: Tuple[int, int] = (8, 64), attention_head_size_range: Tuple[int, int] = (1, 4), dropout_range: Tuple[float, float] = (0.1, 0.3), learning_rate_range: Tuple[float, float] = (1e-05, 1.0), use_learning_rate_finder: bool = True, trainer_kwargs: Dict[str, Any] = {}, log_dir: str = 'lightning_logs', study=None, verbose: int | bool = None, pruner=None, **kwargs)[来源]#
优化时间融合变压器超参数。
运行超参数优化。学习率是通过PyTorch Lightning学习率查找器确定的。
- Parameters:
train_dataloaders (DataLoader) – 用于训练模型的数据加载器
val_dataloaders (DataLoader) – 用于验证模型的dataloader
model_path (str) – 模型检查点保存的文件夹
max_epochs (int, 可选) – 训练的最大轮数。默认为20。
n_trials (int, 可选) – 运行的超参数试验次数。默认为100。
timeout (float, 可选) – 训练停止的时间(以秒为单位),无论经过多少个epoch或验证指标,都会在此时间后停止。默认值为 3600*8.0。
hidden_size_range (元组[整数, 整数], 可选) –
hidden_size
超参数的最小值和最大值。默认值为 (16, 265)。hidden_continuous_size_range (元组[int, int], 可选) –
hidden_continuous_size
超参数的最小值和最大值。默认为 (8, 64)。attention_head_size_range (Tuple[int, int], optional) –
attention_head_size
超参数的最小值和最大值。默认为 (1, 4)。dropout_range (元组[浮点数, 浮点数], 可选) –
dropout
超参数的最小值和最大值。默认值为 (0.1, 0.3)。learning_rate_range (元组[浮点数, 浮点数], 可选) – 学习率范围。默认为 (1e-5, 1.0)。
use_learning_rate_finder (bool) – 是否使用学习率查找器或作为超参数的一部分进行优化。默认为 True。
trainer_kwargs (字典[字符串, 任意], 可选) – 额外的参数给 PyTorch Lightning trainer,比如
limit_train_batches
。默认为 {}。log_dir (str, 可选) – 用于记录tensorboard结果的文件夹。默认为“lightning_logs”。
研究 (optuna.Study, 可选) – 要恢复的研究。默认情况下将创建新的研究。
verbose (Union[int, bool]) – 详细信息级别。 * None: 详细信息级别没有变化(相当于默认的 verbose=1)。 * 0 或 False: 仅记录警告。 * 1 或 True: 记录修剪事件。 * 2: optuna 日志记录级别为调试级别。 默认为 None。
pruner (optuna.pruners.BasePruner, 可选) – 要使用的optuna剪枝器。默认为optuna.pruners.SuccessiveHalvingPruner()。
**kwargs –
TemporalFusionTransformer
的其他参数。
- Returns:
optuna 研究结果
- Return type:
optuna.Study