模型IO简介
自 2.1.0 版本以来,XGBoost 的默认模型格式是 UBJSON 格式,此选项已启用,用于将模型序列化到文件、将模型序列化到缓冲区,以及用于内存快照(pickle 等类似方式)。
在 XGBoost 1.0.0 版本中,我们引入了使用 JSON 格式保存/加载 XGBoost 模型和相关的训练超参数的支持,旨在用一种易于重用的开放格式取代旧的内部二进制格式。随后在 XGBoost 1.6.0 版本中,添加了对 通用二进制 JSON (Universal Binary JSON) 的额外支持,作为一种优化手段,以提高模型IO效率,此格式在 2.1 版本中被设置为默认格式。
JSON 和 UBJSON 具有相同的文档结构,但表示形式不同,我们将它们统称为 JSON 格式。本教程旨在分享一些关于 XGBoost 中使用的 JSON 序列化方法的基本见解。除非另有明确说明,以下部分均假定您正在使用这两种输出格式之一,您可以在保存/加载模型时通过提供以 .json
(或二进制 JSON 的 .ubj
)为文件扩展名的文件名来启用此格式:booster.save_model('model.json')
。更多详情如下。
在我们开始之前,XGBoost 是一个专注于树模型的梯度提升库,这意味着在 XGBoost 内部,有两个不同的部分:
由树组成的模型,以及
用于构建模型的超参数和配置。
如果您来自深度学习社区,那么您应该清楚,由具有固定张量运算的权重组成的神经网络结构与用于训练它们的优化器(例如 RMSprop)之间存在差异。
因此,当调用 booster.save_model
(R 中为 xgb.save
)时,XGBoost 会保存树、一些模型参数(例如训练好的树中的输入列数)以及目标函数,这些共同构成了 XGBoost 中的“模型”概念。至于为何我们将目标函数作为模型的一部分保存,那是因为目标函数控制着全局偏差(在 XGBoost 中称为 base_score
)的转换以及任务特定的信息。用户可以将此模型分享给其他人进行预测、评估,或者使用不同的超参数集继续训练等。
然而,这并非全部。在某些情况下,我们需要保存的不仅仅是模型本身。例如,在分布式训练中,XGBoost 会执行检查点操作。或者出于某种原因,您喜欢的分布式计算框架决定将模型从一个工作节点复制到另一个节点并在那里继续训练。在这种情况下,序列化输出需要包含足够的信息以继续先前的训练,而无需用户再次提供任何参数。我们将这种情况视为内存快照(或基于内存的序列化方法),并将其与普通模型IO操作区分开来。目前,内存快照用于以下几个方面:
Python 包:当使用内置的
pickle
模块对Booster
对象进行 pickle 操作时。R 包:当使用内置函数
saveRDS
或save
持久化xgb.Booster
对象时。JVM 包:当使用内置函数
saveModel
序列化Booster
对象时。
其他语言绑定仍在开发中。
注意
旧的二进制格式没有区分模型和原始内存序列化格式,它是所有内容的混合体,这也是我们希望用更稳健的序列化方法取代它的部分原因。JVM 包有自己的基于内存的序列化方法。
要启用模型 IO(仅保存树和目标函数)的 JSON 格式支持,请提供文件名时带上 .json
或 .ubj
作为文件扩展名,后者是 通用二进制 JSON (Universal Binary JSON) 的扩展名。
bst.save_model('model_file_name.json')
xgb.save(bst, 'model_file_name.json')
val format = "json" // or val format = "ubj"
model.write.option("format", format).save("model_directory_path")
注意
仅加载由 XGBoost 生成的 JSON 文件中的模型。尝试加载由外部源生成的 JSON 文件可能导致未定义的行为和崩溃。
而对于内存快照,从 xgboost 1.6 开始,UBJSON 是默认格式。当重新加载模型时,XGBoost 可以识别文件扩展名 .json
和 .ubj
,并进行相应的分派。如果未指定扩展名,XGBoost 会尝试猜测正确的扩展名。
关于模型和内存快照向后兼容性的说明
我们保证模型的向后兼容性,但不保证内存快照的向后兼容性。
模型(树和目标函数)使用稳定的表示形式,因此在早期版本的 XGBoost 中生成的模型可以在后续版本的 XGBoost 中访问。如果您想长期存储或归档您的模型,请使用 save_model
(Python) 和 xgb.save
(R)。
另一方面,内存快照(序列化)捕获了 XGBoost 内部的许多内容,其格式不稳定且经常变化。因此,内存快照仅适用于检查点,您可以持久化训练配置的完整快照,以便在可能的故障发生时能够稳健地恢复并恢复训练过程。加载由早期版本的 XGBoost 生成的内存快照可能会导致错误或未定义的行为。如果模型使用 pickle.dump
(Python) 或 saveRDS
(R) 持久化,则该模型可能无法在后续版本的 XGBoost 中访问。
自定义目标和指标
XGBoost 接受用户提供的目标函数和指标函数作为扩展。这些函数不保存在模型文件中,因为它们是依赖于语言的特性。在 Python 中,用户可以对模型进行 pickle 操作,以便将这些函数包含在保存的二进制文件中。一个缺点是,pickle 的输出不是稳定的序列化格式,并且在不同的 Python 版本或 XGBoost 版本之间不兼容,更不用说不同的语言环境了。解决此限制的另一种方法是在加载模型后再次提供这些函数。如果自定义函数很有用,请考虑提交 PR 以在 XGBoost 内部实现它,这样您的函数就可以与不同的语言绑定一起工作了。
从不同版本的 XGBoost 加载 pickle 文件
如前所述,pickled 模型既不可移植也不稳定,但在某些情况下,pickled 模型很有价值。将来恢复它的一种方法是使用特定版本的 Python 和 XGBoost 加载它,然后通过调用 save_model 导出模型。
类似的过程也可用于恢复旧 RDS 文件中持久化的模型。在 R 中,您可以使用 remotes
包安装旧版本的 XGBoost
library(remotes)
remotes::install_version("xgboost", "0.90.0.1") # Install version 0.90.0.1
安装所需版本后,您可以使用 readRDS
加载 RDS 文件并恢复 xgb.Booster
对象。然后调用 xgb.save
使用稳定的表示形式导出模型。现在您应该能够在最新版本的 XGBoost 中使用该模型了。
保存和加载内部参数配置
XGBoost 的 C API
、Python API
和 R API
支持直接将内部配置保存和加载为 JSON 字符串。在 Python 包中
bst = xgboost.train(...)
config = bst.save_config()
print(config)
或在 R 中
config <- xgb.config(bst)
print(config)
将打印出类似以下内容(非实际输出,因其过长无法演示)
{
"Learner": {
"generic_parameter": {
"device": "cuda:0",
"gpu_page_size": "0",
"n_jobs": "0",
"random_state": "0",
"seed": "0",
"seed_per_iteration": "0"
},
"gradient_booster": {
"gbtree_train_param": {
"num_parallel_tree": "1",
"process_type": "default",
"tree_method": "hist",
"updater": "grow_gpu_hist",
"updater_seq": "grow_gpu_hist"
},
"name": "gbtree",
"updater": {
"grow_gpu_hist": {
"gpu_hist_train_param": {
"debug_synchronize": "0",
},
"train_param": {
"alpha": "0",
"cache_opt": "1",
"colsample_bylevel": "1",
"colsample_bynode": "1",
"colsample_bytree": "1",
"default_direction": "learn",
...
"subsample": "1"
}
}
}
},
"learner_train_param": {
"booster": "gbtree",
"disable_default_eval_metric": "0",
"objective": "reg:squarederror"
},
"metrics": [],
"objective": {
"name": "reg:squarederror",
"reg_loss_param": {
"scale_pos_weight": "1"
}
}
},
"version": [1, 0, 0]
}
您可以通过以下方式将其加载回由相同版本 XGBoost 生成的模型中
bst.load_config(config)
这样用户可以更仔细地研究内部表示形式。请注意,一些 JSON 生成器使用依赖于区域设置的浮点序列化方法,这不受 XGBoost 支持。
保存模型和转储模型的区别
XGBoost 在 Booster 对象中有一个名为 dump_model
的函数,它允许您以可读格式导出模型,例如 text
、json
或 dot
(graphviz)。它的主要用途是用于模型解释或可视化,不应将其加载回 XGBoost。JSON 版本有一个模式。更多信息请参见下一节。
JSON 模式
JSON 格式的另一个重要特性是其文档化的模式,基于此模式可以轻松重用 XGBoost 的输出模型。此处是输出模型的 JSON 模式(非序列化,如前所述,序列化不稳定)。有关解析 XGBoost 树模型的示例,请参阅 /demo/json-model
。请注意在“dart”booster 中使用的“weight_drop”字段。XGBoost 不直接缩放树叶节点值,而是将权重保存为一个单独的数组。
{
"$schema": "https://json-schema.fullstack.org.cn/draft-07/schema#",
"definitions": {
"gbtree": {
"type": "object",
"properties": {
"name": {
"const": "gbtree"
},
"model": {
"type": "object",
"properties": {
"gbtree_model_param": {
"$ref": "#/definitions/gbtree_model_param"
},
"trees": {
"type": "array",
"items": {
"type": "object",
"properties": {
"tree_param": {
"$ref": "#/definitions/tree_param"
},
"id": {
"type": "integer"
},
"loss_changes": {
"type": "array",
"items": {
"type": "number"
}
},
"sum_hessian": {
"type": "array",
"items": {
"type": "number"
}
},
"base_weights": {
"type": "array",
"items": {
"type": "number"
}
},
"left_children": {
"type": "array",
"items": {
"type": "integer"
}
},
"right_children": {
"type": "array",
"items": {
"type": "integer"
}
},
"parents": {
"type": "array",
"items": {
"type": "integer"
}
},
"split_indices": {
"type": "array",
"items": {
"type": "integer"
}
},
"split_conditions": {
"type": "array",
"items": {
"type": "number"
}
},
"split_type": {
"type": "array",
"items": {
"type": "integer"
}
},
"default_left": {
"type": "array",
"items": {
"type": "integer"
}
},
"categories": {
"type": "array",
"items": {
"type": "integer"
}
},
"categories_nodes": {
"type": "array",
"items": {
"type": "integer"
}
},
"categories_segments": {
"type": "array",
"items": {
"type": "integer"
}
},
"categories_sizes": {
"type": "array",
"items": {
"type": "integer"
}
}
},
"required": [
"tree_param",
"loss_changes",
"sum_hessian",
"base_weights",
"left_children",
"right_children",
"parents",
"split_indices",
"split_conditions",
"default_left",
"categories",
"categories_nodes",
"categories_segments",
"categories_sizes"
]
}
},
"tree_info": {
"type": "array",
"items": {
"type": "integer"
}
}
},
"required": [
"gbtree_model_param",
"trees",
"tree_info"
]
}
},
"required": [
"name",
"model"
]
},
"gbtree_model_param": {
"type": "object",
"properties": {
"num_trees": {
"type": "string"
},
"num_parallel_tree": {
"type": "string"
}
},
"required": [
"num_trees",
"num_parallel_tree"
]
},
"tree_param": {
"type": "object",
"properties": {
"num_nodes": {
"type": "string"
},
"size_leaf_vector": {
"type": "string"
},
"num_feature": {
"type": "string"
}
},
"required": [
"num_nodes",
"num_feature",
"size_leaf_vector"
]
},
"reg_loss_param": {
"type": "object",
"properties": {
"scale_pos_weight": {
"type": "string"
}
}
},
"pseudo_huber_param": {
"type": "object",
"properties": {
"huber_slope": {
"type": "string"
}
}
},
"aft_loss_param": {
"type": "object",
"properties": {
"aft_loss_distribution": {
"type": "string"
},
"aft_loss_distribution_scale": {
"type": "string"
}
}
},
"softmax_multiclass_param": {
"type": "object",
"properties": {
"num_class": { "type": "string" }
}
},
"lambda_rank_param": {
"type": "object",
"properties": {
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
},
"lambdarank_param": {
"type": "object",
"properties": {
"lambdarank_num_pair_per_sample": { "type": "string" },
"lambdarank_pair_method": { "type": "string" },
"lambdarank_unbiased": {"type": "string" },
"lambdarank_bias_norm": {"type": "string" },
"ndcg_exp_gain": {"type": "string"}
}
}
},
"type": "object",
"properties": {
"version": {
"type": "array",
"items": [
{
"type": "number",
"minimum": 1
},
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"minimum": 0
}
],
"minItems": 3,
"maxItems": 3
},
"learner": {
"type": "object",
"properties": {
"feature_names": {
"type": "array",
"items": {
"type": "string"
}
},
"feature_types": {
"type": "array",
"items": {
"type": "string"
}
},
"gradient_booster": {
"oneOf": [
{
"$ref": "#/definitions/gbtree"
},
{
"type": "object",
"properties": {
"name": { "const": "gblinear" },
"model": {
"type": "object",
"properties": {
"weights": {
"type": "array",
"items": {
"type": "number"
}
}
}
}
}
},
{
"type": "object",
"properties": {
"name": { "const": "dart" },
"gbtree": {
"$ref": "#/definitions/gbtree"
},
"weight_drop": {
"type": "array",
"items": {
"type": "number"
}
}
},
"required": [
"name",
"gbtree",
"weight_drop"
]
}
]
},
"objective": {
"oneOf": [
{
"type": "object",
"properties": {
"name": { "const": "reg:squarederror" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:pseudohubererror" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:squaredlogerror" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:linear" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:logistic" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "binary:logistic" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "binary:logitraw" },
"reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
},
"required": [
"name",
"reg_loss_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "count:poisson" },
"poisson_regression_param": {
"type": "object",
"properties": {
"max_delta_step": { "type": "string" }
}
}
},
"required": [
"name",
"poisson_regression_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:tweedie" },
"tweedie_regression_param": {
"type": "object",
"properties": {
"tweedie_variance_power": { "type": "string" }
}
}
},
"required": [
"name",
"tweedie_regression_param"
]
},
{
"properties": {
"name": {
"const": "reg:absoluteerror"
}
},
"type": "object"
},
{
"properties": {
"name": {
"const": "reg:quantileerror"
},
"quantile_loss_param": {
"type": "object",
"properties": {
"quantle_alpha": {"type": "array"}
}
}
},
"type": "object"
},
{
"type": "object",
"properties": {
"name": { "const": "survival:cox" }
},
"required": [ "name" ]
},
{
"type": "object",
"properties": {
"name": { "const": "reg:gamma" }
},
"required": [ "name" ]
},
{
"type": "object",
"properties": {
"name": { "const": "multi:softprob" },
"softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}
},
"required": [
"name",
"softmax_multiclass_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "multi:softmax" },
"softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}
},
"required": [
"name",
"softmax_multiclass_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:pairwise" },
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambdarank_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:ndcg" },
"lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
"lambdarank_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:map" },
"lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
},
"required": [
"name",
"lambda_rank_param"
]
},
{
"type": "object",
"properties": {
"name": {"const": "survival:aft"},
"aft_loss_param": { "$ref": "#/definitions/aft_loss_param"}
}
},
{
"type": "object",
"properties": {
"name": {"const": "binary:hinge"}
}
}
]
},
"learner_model_param": {
"type": "object",
"properties": {
"base_score": { "type": "string" },
"num_class": { "type": "string" },
"num_feature": { "type": "string" },
"num_target": { "type": "string" }
}
}
},
"required": [
"gradient_booster",
"objective"
]
}
},
"required": [
"version",
"learner"
]
}