XGBoost 从 JSON
简介
此教程的目的是向您展示如何正确加载和使用已转储为 JSON 格式的 XGBoost 模型。XGBoost 内部将所有数据转换为32 位浮点数,并且转储到 JSON 的值是这些值的十进制表示。当使用从 JSON 文件解析的模型时,必须注意正确处理
输入数据,应将其转换为 32 位浮点数
以十进制表示形式存储在 JSON 中的任何 32 位浮点数
任何计算都必须使用 32 位数学运算符完成
设置
为了本教程的目的,我们将加载 xgboost、jsonlite 和 float 包。我们还将选项中的 digits=22
,以防我们想检查结果的许多位数。
require(xgboost)
## Loading required package: xgboost
require(jsonlite)
## Loading required package: jsonlite
require(float)
## Loading required package: float
options(digits = 22)
我们将基于此处首次提供的示例创建一个玩具二元逻辑模型,以便我们可以轻松理解转储的 JSON 模型对象的结构。这将使我们能够理解差异可能发生的位置以及应如何处理它们。
dates <- c(20180130, 20180130, 20180130,
20180130, 20180130, 20180130,
20180131, 20180131, 20180131,
20180131, 20180131, 20180131,
20180131, 20180131, 20180131,
20180134, 20180134, 20180134)
labels <- c(1, 1, 1,
1, 1, 1,
0, 0, 0,
0, 0, 0,
0, 0, 0,
0, 0, 0)
data <- data.frame(dates = dates, labels = labels)
bst <- xgb.train(
data = xgb.DMatrix(as.matrix(data$dates), label = labels, missing = NA, nthread = 1),
nrounds = 1,
params = xgb.params(
objective = "binary:logistic",
nthread = 2,
max_depth = 1
)
)
比较结果
我们现在将模型转储到 JSON 并尝试说明可能出现的各种问题,以及如何正确处理它们。
首先,我们将模型转储到 JSON
bst_json <- xgb.dump(bst, with_stats = FALSE, dump_format = 'json')
bst_from_json <- fromJSON(bst_json, simplifyDataFrame = FALSE)
node <- bst_from_json[[1]]
cat(bst_json)
## [
## { "nodeid": 0, "depth": 0, "split": "f0", "split_condition": 20180132, "yes": 1, "no": 2, "missing": 2 , "children": [
## { "nodeid": 1, "leaf": 0.514285684 },
## { "nodeid": 2, "leaf": -0.327272743 }
## ]}
## ]
上述代码块显示的树 JSON 告诉我们,如果数据小于 20180132,则树将输出第一个叶子中的值。否则,它将输出第二个叶子中的值。让我们尝试使用我们拥有的数据手动重现这一点,并确认它与我们已经计算出的模型预测相匹配。
bst_preds_logodds <- predict(bst, as.matrix(data$dates), outputmargin = TRUE)
# calculate the logodds values using the JSON representation
bst_from_json_logodds <- ifelse(data$dates < node$split_condition,
node$children[[1]]$leaf,
node$children[[2]]$leaf)
bst_preds_logodds
## [1] -0.1788614988327026367188 -0.1788614988327026367188
## [3] -0.1788614988327026367188 -0.1788614988327026367188
## [5] -0.1788614988327026367188 -0.1788614988327026367188
## [7] -1.0204199552536010742188 -1.0204199552536010742188
## [9] -1.0204199552536010742188 -1.0204199552536010742188
## [11] -1.0204199552536010742188 -1.0204199552536010742188
## [13] -1.0204199552536010742188 -1.0204199552536010742188
## [15] -1.0204199552536010742188 -1.0204199552536010742188
## [17] -1.0204199552536010742188 -1.0204199552536010742188
bst_from_json_logodds
## [1] 0.5142856839999999651880 0.5142856839999999651880
## [3] 0.5142856839999999651880 0.5142856839999999651880
## [5] 0.5142856839999999651880 0.5142856839999999651880
## [7] 0.5142856839999999651880 0.5142856839999999651880
## [9] 0.5142856839999999651880 0.5142856839999999651880
## [11] 0.5142856839999999651880 0.5142856839999999651880
## [13] 0.5142856839999999651880 0.5142856839999999651880
## [15] 0.5142856839999999651880 -0.3272727429999999770871
## [17] -0.3272727429999999770871 -0.3272727429999999770871
# test that values are equal
bst_preds_logodds == bst_from_json_logodds
## [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE
都不相等。发生了什么?
在此阶段发生了两件事
输入数据未转换为 32 位浮点数
JSON 变量未转换为 32 位浮点数
教训 1:所有数据都是 32 位浮点数
处理导入的 JSON 时,所有数据都必须转换为 32 位浮点数
为了解释这一点,让我们重复比较并四舍五入到两位小数
round(bst_preds_logodds, 2) == round(bst_from_json_logodds, 2)
## [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE
如果我们四舍五入到两位小数,我们会看到只有与数据值 20180131
相关的元素不一致。如果我们把数据转换为浮点数,它们就会一致
# now convert the dates to floats first
bst_from_json_logodds <- ifelse(fl(data$dates) < node$split_condition,
node$children[[1]]$leaf,
node$children[[2]]$leaf)
# test that values are equal
round(bst_preds_logodds, 2) == round(bst_from_json_logodds, 2)
## [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE
教训是什么?如果我们要使用导入的 JSON 模型,任何数据都必须首先转换为浮点数。在这种情况下,由于“20180131”不能表示为 32 位浮点数,它被向上舍入到 20180132,如这里所示
fl(20180131)
## # A float32 vector: 1
## [1] 20180132
教训 2:JSON 参数是 32 位浮点数
所有存储为浮点数的 JSON 参数都必须转换为浮点数。
现在让我们假设我们确实关心小数点后两位之后的数字。
# test that values are equal
bst_preds_logodds == bst_from_json_logodds
## [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE
都不完全相等。发生了什么?尽管我们已经将数据转换为 32 位浮点数,但我们还需要将 JSON 参数转换为 32 位浮点数。我们来做这个
# now convert the dates to floats first
bst_from_json_logodds <- ifelse(fl(data$dates) < fl(node$split_condition),
as.numeric(fl(node$children[[1]]$leaf)),
as.numeric(fl(node$children[[2]]$leaf)))
# test that values are equal
bst_preds_logodds == bst_from_json_logodds
## [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE
全部相等。教训是什么?如果我们要使用导入的 JSON 模型,任何存储为浮点数的 JSON 参数也必须首先转换为浮点数。
教训 3:使用 32 位数学运算
始终使用 32 位数字和运算符
我们能够使对数几率一致,所以现在让我们手动计算对数几率的 sigmoid。这应该与 xgboost 预测一致。
bst_preds <- predict(bst, as.matrix(data$dates))
# calculate the predictions casting doubles to floats
bst_from_json_preds <- ifelse(
fl(data$dates) < fl(node$split_condition)
, as.numeric(1 / (1 + exp(-1 * fl(node$children[[1]]$leaf))))
, as.numeric(1 / (1 + exp(-1 * fl(node$children[[2]]$leaf))))
)
# test that values are equal
bst_preds == bst_from_json_preds
## [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE
又都不完全相等。这里发生了什么?嗯,由于我们在计算中使用了值 1
,我们引入了一个双精度数到计算中。因此,所有浮点值都被提升为 64 位双精度数,并且也使用了指数运算符 exp
的 64 位版本。另一方面,xgboost 在其 sigmoid 函数中使用了指数运算符的 32 位版本。
我们如何解决这个问题?我们必须确保在所有地方都使用正确的数据类型和正确的运算符。如果我们只使用浮点数,我们加载的浮点库将确保应用 32 位浮点指数运算符。
# calculate the predictions casting doubles to floats
bst_from_json_preds <- ifelse(
fl(data$dates) < fl(node$split_condition)
, as.numeric(fl(1) / (fl(1) + exp(fl(-1) * fl(node$children[[1]]$leaf))))
, as.numeric(fl(1) / (fl(1) + exp(fl(-1) * fl(node$children[[2]]$leaf))))
)
# test that values are equal
bst_preds == bst_from_json_preds
## [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE
全部相等。教训是什么?如果我们想要重现我们用 xgboost 看到的结果,我们必须确保所有计算都使用 32 位浮点运算符完成。