跳到目录

保存从 xgboost()xgb.train() 创建的 XGBoost 模型。调用 xgb.load.raw() 将模型从原始向量加载回来。

用法

xgb.save.raw(model, raw_format = "ubj")

参数

model

模型对象。

raw_format

编码 booster 的格式

  • "json":将 booster 编码为 JSON 文本文档。

  • "ubj":将 booster 编码为通用二进制 JSON。

  • "deprecated":将 booster 编码为旧的自定义二进制格式。

示例

DONTSHOW({RhpcBLASctl::omp_set_num_threads(1)})
data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")

## Keep the number of threads to 1 for examples
nthread <- 1
data.table::setDTthreads(nthread)

train <- agaricus.train
test <- agaricus.test

bst <- xgb.train(
  data = xgb.DMatrix(train$data, label = train$label, nthread = 1),
  nrounds = 2,
  params = xgb.params(
    max_depth = 2,
    nthread = nthread,
    objective = "binary:logistic"
  )
)

raw <- xgb.save.raw(bst)
bst <- xgb.load.raw(raw)