torch_frame.gbdt.GBDT
- class GBDT(task_type: TaskType, num_classes: int | None = None, metric: Metric | None = None)[source]
基础类:
objectGBDT(梯度提升决策树)模型的基类,用作强基线。
- Parameters:
- tune(tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int, *args, **kwargs)[source]
通过使用Optuna进行超参数调优来拟合模型。试验次数由num_trials指定。
- Parameters:
tf_train (TensorFrame) – 训练数据在
TensorFrame中。tf_val (TensorFrame) – 验证数据在
TensorFrame中。num_trials (int) – 执行超参数搜索的试验次数。
*args – 可变长度参数列表。
**kwargs – 任意关键字参数。
- predict(tf_test: TensorFrame) Tensor[来源]
在拟合模型上预测测试数据的标签/值,并返回其预测结果。
TaskType.REGRESSION: 返回原始数值。TaskType.BINARY_CLASSIFICATION: 返回为正类的概率。TaskType.MULTICLASS_CLASSIFICATION: 返回类标签预测。