XGBoost4J 入门

本教程介绍了 XGBoost 的 Java API。

数据接口

与 XGBoost Python 模块类似,XGBoost4J 使用 DMatrix 处理数据。支持 LIBSVM txt 格式文件、CSR/CSC 格式的稀疏矩阵以及密集矩阵。

  • 第一步是导入 DMatrix

    import ml.dmlc.xgboost4j.java.DMatrix;
    
  • 使用 DMatrix 构造函数从 libsvm 文本格式文件加载数据

    DMatrix dmat = new DMatrix("train.svm.txt");
    
  • 将数组传递给 DMatrix 构造函数以从稀疏矩阵加载。

    假设我们有一个稀疏矩阵

    1 0 2 0
    4 0 0 3
    3 1 2 0
    

    我们可以用 压缩稀疏行 (CSR) 格式表示稀疏矩阵

    long[] rowHeaders = new long[] {0,2,4,7};
    float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f};
    int[] colIndex = new int[] {0,2,0,3,0,1,2};
    int numColumn = 4;
    DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, numColumn);
    

    ... 或用 压缩稀疏列 (CSC) 格式表示

    long[] colHeaders = new long[] {0,3,4,6,7};
    float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f};
    int[] rowIndex = new int[] {0,1,2,2,0,2,1};
    int numRow = 3;
    DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, numRow);
    
  • 您也可以从密集矩阵加载数据。假设我们有如下形式的矩阵

    1    2
    3    4
    5    6
    

    使用 行主序布局,我们指定密集矩阵如下

    float[] data = new float[] {1f,2f,3f,4f,5f,6f};
    int nrow = 3;
    int ncol = 2;
    float missing = 0.0f;
    DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
    
  • 设置权重

    float[] weights = new float[] {1f,2f,1f};
    dmat.setWeight(weights);
    

设置参数

设置参数时,参数以 Map 形式指定

Map<String, Object> params = new HashMap<String, Object>() {
  {
    put("eta", 1.0);
    put("max_depth", 2);
    put("objective", "binary:logistic");
    put("eval_metric", "logloss");
  }
};

训练模型

有了参数和数据,您就可以训练一个 booster 模型。

  • 导入 Booster 和 XGBoost

    import ml.dmlc.xgboost4j.java.Booster;
    import ml.dmlc.xgboost4j.java.XGBoost;
    
  • 训练

    DMatrix trainMat = new DMatrix("train.svm.txt");
    DMatrix validMat = new DMatrix("valid.svm.txt");
    // Specify a watch list to see model accuracy on data sets
    Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
      {
        put("train", trainMat);
        put("test", testMat);
      }
    };
    int nround = 2;
    Booster booster = XGBoost.train(trainMat, params, nround, watches, null, null);
    
  • 保存模型

    训练完成后,您可以保存模型并将其导出。

    booster.saveModel("model.json");
    
  • 生成带有特征映射的模型导出

    // dump without feature map
    String[] model_dump = booster.getModelDump(null, false);
    // dump with feature map
    String[] model_dump_with_feature_map = booster.getModelDump("featureMap.txt", false);
    
  • 加载模型

    Booster booster = XGBoost.loadModel("model.json");
    

预测

训练和加载模型后,您可以使用它对其他数据进行预测。结果将是一个二维浮点数组 (nsample, nclass);对于 predictLeaf(),结果的形状将是 (nsample, nclass*ntrees)

DMatrix dtest = new DMatrix("test.svm.txt");
// predict
float[][] predicts = booster.predict(dtest);
// predict leaf
float[][] leafPredicts = booster.predictLeaf(dtest, 0);