optuna.integration
The integration 模块包含用于将 Optuna 与外部机器学习框架集成的类。
对于Optuna支持的大多数机器学习框架,相应的Optuna集成类仅用于实现一个回调对象和函数,这些对象和函数符合框架特定的回调API,并在模型训练的每个中间步骤中被调用。这些回调在不同机器学习框架中实现的功能包括:
使用
optuna.trial.Trial.report()将中间模型分数报告回 Optuna 试验。根据
optuna.trial.Trial.should_prune()的结果,通过引发optuna.TrialPruned()来修剪当前模型,并且将Optuna的中间数据(如当前试验编号)报告回框架,如在
MLflowCallback中所做的那样。
对于 scikit-learn,提供了一个集成的 OptunaSearchCV 估计器,它结合了 scikit-learn BaseEstimator 功能,并可以访问类级别的 Study 对象。
每个集成的依赖项
我们总结了每个集成的必要依赖项。
集成 |
依赖项 |
|---|---|
allennlp, torch, psutil, jsonnet |
|
botorch, gpytorch, torch |
|
catboost |
|
chainermn |
|
chainer |
|
cma |
|
分布式 |
|
fastai |
|
keras |
|
lightgbm, scikit-learn |
|
lightgbm |
|
mlflow |
|
mxnet |
|
PyTorch Distributed |
torch |
PyTorch (Ignite) |
pytorch-ignite |
PyTorch (Lightning) |
pytorch-lightning |
scikit-learn, shap |
|
pandas, scipy, scikit-learn |
|
skorch |
|
tensorboard, tensorflow |
|
tensorflow, tensorflow-estimator |
|
tensorflow |
|
wandb |
|
xgboost |