开始使用 XGBoost4J

本教程介绍 XGBoost 的 Java API。

数据接口

与 XGBoost python 模块类似,XGBoost4J 使用 DMatrix 来处理数据。支持 LIBSVM txt 格式文件、CSR/CSC 格式的稀疏矩阵以及密集矩阵。

  • 第一步是导入 DMatrix:

    import ml.dmlc.xgboost4j.java.DMatrix;
    
  • 使用 DMatrix 构造函数从 libsvm 文本格式文件加载数据:

    DMatrix dmat = new DMatrix("train.svm.txt");
    
  • 将数组传递给 DMatrix 构造函数以从稀疏矩阵加载。

    假设我们有一个稀疏矩阵

    1 0 2 0
    4 0 0 3
    3 1 2 0
    

    我们可以用 Compressed Sparse Row (CSR) 格式表示稀疏矩阵:

    long[] rowHeaders = new long[] {0,2,4,7};
    float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f};
    int[] colIndex = new int[] {0,2,0,3,0,1,2};
    int numColumn = 4;
    DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, numColumn);
    

    … 或在 压缩稀疏列 (CSC) 格式中:

    long[] colHeaders = new long[] {0,3,4,6,7};
    float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f};
    int[] rowIndex = new int[] {0,1,2,2,0,2,1};
    int numRow = 3;
    DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, numRow);
    
  • 您也可以从密集矩阵加载数据。假设我们有一个矩阵形式为

    1    2
    3    4
    5    6
    

    使用 行优先布局,我们指定密集矩阵如下:

    float[] data = new float[] {1f,2f,3f,4f,5f,6f};
    int nrow = 3;
    int ncol = 2;
    float missing = 0.0f;
    DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
    
  • 设置权重:

    float[] weights = new float[] {1f,2f,1f};
    dmat.setWeight(weights);
    

设置参数

要设置参数,参数被指定为一个映射:

Map<String, Object> params = new HashMap<String, Object>() {
  {
    put("eta", 1.0);
    put("max_depth", 2);
    put("objective", "binary:logistic");
    put("eval_metric", "logloss");
  }
};

训练模型

通过参数和数据,您能够训练一个增强模型。

  • 导入 Booster 和 XGBoost:

    import ml.dmlc.xgboost4j.java.Booster;
    import ml.dmlc.xgboost4j.java.XGBoost;
    
  • 训练

    DMatrix trainMat = new DMatrix("train.svm.txt");
    DMatrix validMat = new DMatrix("valid.svm.txt");
    // Specify a watch list to see model accuracy on data sets
    Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
      {
        put("train", trainMat);
        put("test", testMat);
      }
    };
    int nround = 2;
    Booster booster = XGBoost.train(trainMat, params, nround, watches, null, null);
    
  • 保存模型

    训练后,您可以保存模型并将其导出。

    booster.saveModel("model.bin");
    
  • 使用特征图生成模型转储

    // dump without feature map
    String[] model_dump = booster.getModelDump(null, false);
    // dump with feature map
    String[] model_dump_with_feature_map = booster.getModelDump("featureMap.txt", false);
    
  • 加载模型

    Booster booster = XGBoost.loadModel("model.bin");
    

预测

在训练并加载模型后,您可以使用它对其他数据进行预测。结果将是一个二维浮点数组 (nsample, nclass);对于 predictLeaf(),结果将是形状为 (nsample, nclass*ntrees) 的数组。

DMatrix dtest = new DMatrix("test.svm.txt");
// predict
float[][] predicts = booster.predict(dtest);
// predict leaf
float[][] leafPredicts = booster.predictLeaf(dtest, 0);