跳到内容

对于 XGBoost 模型的序列化,可以使用 R 的序列化工具,例如 save()saveRDS() 来序列化 XGBoost 模型对象,但 XGBoost 也提供了自己的序列化工具,具有更好的兼容性保证,允许在 XGBoost 的其他语言绑定中加载这些模型。

请注意,一个 xgb.Booster 对象(xgb.train() 生成,有关 xgboost() 生成的对象的详细信息请参阅文档其余部分),除了其核心组件外,可能还包含

  • 额外的模型配置(可通过 xgb.config() 访问),其中包括模型拟合参数,如 max_depth,以及运行时参数,如 nthread。这些对于预测/重要性/绘图来说不一定有用。

  • 额外的 R 特有属性 - 例如回调结果,如评估日志,它们以 data.table 对象形式保存,如果存在,可通过 attributes(model)$evaluation_log 访问。

第一种(配置)与模型本身不具有相同的兼容性保证,包括通过 xgb.attributes() 设置和访问的属性 - 也就是说,无论使用何种序列化工具,在不同版本的 XGBoost 中加载增强器后,此类配置可能会丢失。使用 saveRDS() 时会保存这些配置,但如果加载到不兼容的 XGBoost 版本中,它们将被丢弃。在使用 XGBoost 公共接口中的序列化工具(包括 xgb.save()xgb.save.raw())时,不会保存这些配置。

第二种(R 属性)不是标准 XGBoost 模型结构的一部分,因此在使用 XGBoost 自己的序列化工具时不会保存。这些属性仅用于提供信息,例如跟踪模型拟合时的评估指标,或保存生成模型的 R 调用,但除此之外,它们不用于预测/重要性/绘图等。这些 R 属性仅在使用 R 的序列化工具时才会保留。

除了 xgb.train() 生成的常规 xgb.Booster 对象外,函数 xgboost() 生成的对象具有不同的子类 xgboost(继承自 xgb.Booster),该子类将其他额外元数据作为 R 属性保存,例如分类问题中的类名,并且具有一个专用的 predict 方法,该方法使用不同的默认值并接受不同的参数名称。XGBoost 自己的序列化工具可以处理这个 xgboost 类,但由于它们不保留 R 属性,因此反序列化时,结果对象会向下转型为常规的 xgb.Booster 类(即它会丢失元数据,并且结果对象将使用 predict.xgb.Booster() 而非 predict.xgboost())- 对于这些 xgboost 对象,如果需要额外的功能,使用 saveRDS 可能是一个更好的选择。

请注意,从版本 2.1.0 及 onwards 的 XGBoost R 模型与版本 2.1.0 之前的 XGBoost 模型具有非常不同的 R 对象结构,并且相互不兼容。因此,在版本 2.1.0 之前使用 R 序列化工具(如 saveRDS()save())保存的模型将无法在后续的 xgboost 版本中使用,反之亦然。请注意,理论上 R 模型对象的结构将来可能再次改变,因此对于长期存储,应优先使用 XGBoost 的序列化工具。

此外,请注意,XGBoost 的模型对象可能无法使用第三方 R 包(如 qsqs2)进行序列化。

详细信息

使用 xgb.save() 将 XGBoost 模型保存为独立文件。可以通过指定 JSON 扩展名来选择 JSON 格式。要读回模型,请使用 xgb.load()

使用 xgb.save.raw() 以未来兼容的方式将 XGBoost 模型保存为原始字节序列(向量)。XGBoost 的未来版本将能够读取原始字节并重新构建相应的模型。要读回模型,请使用 xgb.load.raw()。如果您希望将 XGBoost 模型作为另一个 R 对象的一部分进行持久化,则 xgb.save.raw() 函数很有用。

如果您需要增强器可能拥有的 R 特有属性,例如评估日志或模型类为 xgboost 而非 xgb.Booster,则使用 saveRDS(),但请注意,此类对象的未来兼容性不受 XGBoost 控制,因为它依赖于 R 的序列化格式(例如,请参阅 base R 中 serializesave() 的详细信息部分)。

有关模型持久化和存档的更多详细信息和解释,请查阅页面 https://docs.xgboost.com.cn/en/release_3.0.0/tutorials/saving_model.html

示例

data(agaricus.train, package = "xgboost")

bst <- xgb.train(
  data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label, nthread = 1),
  nrounds = 2,
  params = xgb.params(
    max_depth = 2,
    nthread = 2,
    objective = "binary:logistic"
  )
)

# Save as a stand-alone file; load it with xgb.load()
fname <- file.path(tempdir(), "xgb_model.ubj")
xgb.save(bst, fname)
bst2 <- xgb.load(fname)

# Save as a stand-alone file (JSON); load it with xgb.load()
fname <- file.path(tempdir(), "xgb_model.json")
xgb.save(bst, fname)
bst2 <- xgb.load(fname)

# Save as a raw byte vector; load it with xgb.load.raw()
xgb_bytes <- xgb.save.raw(bst)
bst2 <- xgb.load.raw(xgb_bytes)

# Persist XGBoost model as part of another R object
obj <- list(xgb_model_bytes = xgb.save.raw(bst), description = "My first XGBoost model")
# Persist the R object. Here, saveRDS() is okay, since it doesn't persist
# xgb.Booster directly. What's being persisted is the future-proof byte representation
# as given by xgb.save.raw().
fname <- file.path(tempdir(), "my_object.Rds")
saveRDS(obj, fname)
# Read back the R object
obj2 <- readRDS(fname)
# Re-construct xgb.Booster object from the bytes
bst2 <- xgb.load.raw(obj2$xgb_model_bytes)