XGBoost4J 入门
本教程介绍了 XGBoost 的 Java API。
数据接口
与 XGBoost Python 模块类似,XGBoost4J 使用 DMatrix 处理数据。支持 LIBSVM txt 格式文件、CSR/CSC 格式的稀疏矩阵以及密集矩阵。
第一步是导入 DMatrix
import ml.dmlc.xgboost4j.java.DMatrix;
使用 DMatrix 构造函数从 libsvm 文本格式文件加载数据
DMatrix dmat = new DMatrix("train.svm.txt");
将数组传递给 DMatrix 构造函数以从稀疏矩阵加载。
假设我们有一个稀疏矩阵
1 0 2 0 4 0 0 3 3 1 2 0
我们可以用 压缩稀疏行 (CSR) 格式表示稀疏矩阵
long[] rowHeaders = new long[] {0,2,4,7}; float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f}; int[] colIndex = new int[] {0,2,0,3,0,1,2}; int numColumn = 4; DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR, numColumn);
... 或用 压缩稀疏列 (CSC) 格式表示
long[] colHeaders = new long[] {0,3,4,6,7}; float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f}; int[] rowIndex = new int[] {0,1,2,2,0,2,1}; int numRow = 3; DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC, numRow);
您也可以从密集矩阵加载数据。假设我们有如下形式的矩阵
1 2 3 4 5 6
使用 行主序布局,我们指定密集矩阵如下
float[] data = new float[] {1f,2f,3f,4f,5f,6f}; int nrow = 3; int ncol = 2; float missing = 0.0f; DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
设置权重
float[] weights = new float[] {1f,2f,1f}; dmat.setWeight(weights);
设置参数
设置参数时,参数以 Map 形式指定
Map<String, Object> params = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("objective", "binary:logistic");
put("eval_metric", "logloss");
}
};
训练模型
有了参数和数据,您就可以训练一个 booster 模型。
导入 Booster 和 XGBoost
import ml.dmlc.xgboost4j.java.Booster; import ml.dmlc.xgboost4j.java.XGBoost;
训练
DMatrix trainMat = new DMatrix("train.svm.txt"); DMatrix validMat = new DMatrix("valid.svm.txt"); // Specify a watch list to see model accuracy on data sets Map<String, DMatrix> watches = new HashMap<String, DMatrix>() { { put("train", trainMat); put("test", testMat); } }; int nround = 2; Booster booster = XGBoost.train(trainMat, params, nround, watches, null, null);
保存模型
训练完成后,您可以保存模型并将其导出。
booster.saveModel("model.json");
生成带有特征映射的模型导出
// dump without feature map String[] model_dump = booster.getModelDump(null, false); // dump with feature map String[] model_dump_with_feature_map = booster.getModelDump("featureMap.txt", false);
加载模型
Booster booster = XGBoost.loadModel("model.json");
预测
训练和加载模型后,您可以使用它对其他数据进行预测。结果将是一个二维浮点数组 (nsample, nclass)
;对于 predictLeaf()
,结果的形状将是 (nsample, nclass*ntrees)
。
DMatrix dtest = new DMatrix("test.svm.txt");
// predict
float[][] predicts = booster.predict(dtest);
// predict leaf
float[][] leafPredicts = booster.predictLeaf(dtest, 0);