开始使用 XGBoost4J
本教程介绍 XGBoost 的 Java API。
数据接口
与 XGBoost python 模块类似,XGBoost4J 使用 DMatrix 来处理数据。支持 LIBSVM txt 格式文件、CSR/CSC 格式的稀疏矩阵以及密集矩阵。
第一步是导入 DMatrix:
import ml.dmlc.xgboost4j.java.DMatrix;
使用 DMatrix 构造函数从 libsvm 文本格式文件加载数据:
DMatrix dmat = new DMatrix("train.svm.txt");
将数组传递给 DMatrix 构造函数以从稀疏矩阵加载。
假设我们有一个稀疏矩阵
1 0 2 0 4 0 0 3 3 1 2 0
我们可以用 Compressed Sparse Row (CSR) 格式表示稀疏矩阵:
long[] rowHeaders = new long[] {0,2,4,7}; float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f}; int[] colIndex = new int[] {0,2,0,3,0,1,2}; int numColumn = 4; DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, numColumn);
… 或在 压缩稀疏列 (CSC) 格式中:
long[] colHeaders = new long[] {0,3,4,6,7}; float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f}; int[] rowIndex = new int[] {0,1,2,2,0,2,1}; int numRow = 3; DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, numRow);
您也可以从密集矩阵加载数据。假设我们有一个矩阵形式为
1 2 3 4 5 6
使用 行优先布局,我们指定密集矩阵如下:
float[] data = new float[] {1f,2f,3f,4f,5f,6f}; int nrow = 3; int ncol = 2; float missing = 0.0f; DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
设置权重:
float[] weights = new float[] {1f,2f,1f}; dmat.setWeight(weights);
设置参数
要设置参数,参数被指定为一个映射:
Map<String, Object> params = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("objective", "binary:logistic");
put("eval_metric", "logloss");
}
};
训练模型
通过参数和数据,您能够训练一个增强模型。
导入 Booster 和 XGBoost:
import ml.dmlc.xgboost4j.java.Booster; import ml.dmlc.xgboost4j.java.XGBoost;
训练
DMatrix trainMat = new DMatrix("train.svm.txt"); DMatrix validMat = new DMatrix("valid.svm.txt"); // Specify a watch list to see model accuracy on data sets Map<String, DMatrix> watches = new HashMap<String, DMatrix>() { { put("train", trainMat); put("test", testMat); } }; int nround = 2; Booster booster = XGBoost.train(trainMat, params, nround, watches, null, null);
保存模型
训练后,您可以保存模型并将其导出。
booster.saveModel("model.bin");
使用特征图生成模型转储
// dump without feature map String[] model_dump = booster.getModelDump(null, false); // dump with feature map String[] model_dump_with_feature_map = booster.getModelDump("featureMap.txt", false);
加载模型
Booster booster = XGBoost.loadModel("model.bin");
预测
在训练并加载模型后,您可以使用它对其他数据进行预测。结果将是一个二维浮点数组 (nsample, nclass);对于 predictLeaf(),结果将是形状为 (nsample, nclass*ntrees) 的数组。
DMatrix dtest = new DMatrix("test.svm.txt");
// predict
float[][] predicts = booster.predict(dtest);
// predict leaf
float[][] leafPredicts = booster.predictLeaf(dtest, 0);