torch_frame.gbdt.CatBoost
- class CatBoost(task_type: TaskType, num_classes: int | None = None, metric: Metric | None = None)[来源]
基础类:
GBDT使用Optuna进行超参数调优的CatBoost模型实现。
此实现扩展了GBDT,旨在通过优化给定的目标函数来找到最佳超参数。
- objective(trial: Any, train_x: DataFrame, train_y: ndarray, val_x: DataFrame, val_y: ndarray, cat_features: ndarray, num_boost_round: int) float[source]
需要优化的目标函数。
- Parameters:
试验 (optuna.trial.Trial) – Optuna 试验对象。
train_x (DataFrame) – 训练数据。
train_y (numpy.ndarray) – 训练标签。
val_x (DataFrame) – 验证数据。
val_y (numpy.ndarray) – 验证标签。
cat_features (numpy.ndarray) – 包含分类特征索引的数组。
num_boost_round (int) – 提升轮数。
- Returns:
最佳目标值。回归任务的均方根误差和分类任务的准确率。
- Return type: