ray.train.xgboost.XGBoostTrainer#
- class ray.train.xgboost.XGBoostTrainer(*args, **kwargs)[源代码]#
基类:
XGBoostTrainer用于数据并行XGBoost训练的训练器。
这个训练器使用多个 Ray 角色以分布式方式运行 XGBoost 训练循环。
备注
XGBoostTrainer不会修改或改变 XGBoost 分布式训练算法的运行方式。Ray 仅提供编排、数据摄取和容错功能。有关 XGBoost 分布式训练的更多信息,请参阅 XGBoost 文档。示例
import ray from ray.train.xgboost import XGBoostTrainer from ray.train import ScalingConfig train_dataset = ray.data.from_items( [{"x": x, "y": x + 1} for x in range(32)]) trainer = XGBoostTrainer( label_column="y", params={"objective": "reg:squarederror"}, scaling_config=ScalingConfig(num_workers=3), datasets={"train": train_dataset}, ) result = trainer.fit()
- 参数:
datasets – 用于训练和验证的 Ray 数据集。必须包含一个表示训练数据集的“train”键。所有非训练数据集将作为单独的验证集使用,每个验证集报告一个单独的指标。
label_column – 标签列的名称。训练数据集中必须存在具有此名称的列。
params – XGBoost 训练参数。有关可能参数的列表,请参阅 XGBoost 文档。
num_boost_round – 目标提升迭代次数(模型中的树)。请注意,与
xgboost.train不同,这是目标树的数量,这意味着如果您设置num_boost_round=10并传递一个已经训练了5次迭代的模型,它将再训练5次迭代,而不是再训练10次。scaling_config – 数据并行训练的配置方式。
run_config – 训练运行的配置。
dataset_config – 用于摄取输入
datasets的配置。默认情况下,所有 Ray Datasets 在各个工作节点之间平均分配。更多详情请参见DataConfig。resume_from_checkpoint – 用于从中恢复训练的检查点。
metadata – 应在此 Trainer 保存的检查点中通过
checkpoint.get_metadata()提供的字典。必须是 JSON 可序列化的。**train_kwargs – 传递给
xgboost.train()函数的额外关键字参数。
PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。
方法
将自身转换为
tune.Trainable类。检查给定目录是否包含一个可恢复的 Train 实验。
运行训练。
返回此训练器的最终数据集配置的副本。
检索存储在此检查点中的 XGBoost 模型。
已弃用。
从之前中断/失败的运行中恢复一个 DataParallelTrainer。
在调用 fit() 时执行初始设置的 Trainer。