跳到内容

从JSON加载XGBoost

简介

本 Vignette 的目的是向您展示如何正确加载和使用已转储为JSON格式的 XGBoost 模型。 XGBoost 内部将所有数据转换为32位浮点数,转储到JSON的值是这些值的十进制表示。在处理从JSON文件解析的模型时,必须注意正确处理

  • 输入数据,它应该转换为32位浮点数
  • 存储在JSON中作为十进制表示的任何32位浮点数
  • 任何计算都必须使用32位数学运算符

设置

为了本教程的目的,我们将加载xgboost、jsonlite和float包。我们还将在选项中设置digits=22,以防我们想要检查结果的许多位数。

## Loading required package: xgboost
## Loading required package: jsonlite
## Loading required package: float
options(digits = 22)

我们将根据此处首次提供的示例创建一个玩具二元逻辑模型,以便我们能够轻松理解转储的JSON模型对象的结构。这将使我们能够理解可能出现差异的地方以及如何处理它们。

dates <- c(20180130, 20180130, 20180130,
           20180130, 20180130, 20180130,
           20180131, 20180131, 20180131,
           20180131, 20180131, 20180131,
           20180131, 20180131, 20180131,
           20180134, 20180134, 20180134)

labels <- c(1, 1, 1,
            1, 1, 1,
            0, 0, 0,
            0, 0, 0,
            0, 0, 0,
            0, 0, 0)

data <- data.frame(dates = dates, labels = labels)

bst <- xgb.train(
  data = xgb.DMatrix(as.matrix(data$dates), label = labels, missing = NA, nthread = 1),
  nrounds = 1,
  params = xgb.params(
    objective = "binary:logistic",
    nthread = 2,
    max_depth = 1
  )
)

比较结果

我们现在将模型转储为JSON,并尝试说明可能出现的各种问题以及如何正确处理它们。

首先,我们将模型转储为JSON

bst_json <- xgb.dump(bst, with_stats = FALSE, dump_format = 'json')
bst_from_json <- fromJSON(bst_json, simplifyDataFrame = FALSE)
node <- bst_from_json[[1]]
cat(bst_json)
## [
##   { "nodeid": 0, "depth": 0, "split": "f0", "split_condition": 20180132, "yes": 1, "no": 2, "missing": 2 , "children": [
##     { "nodeid": 1, "leaf": 0.514285684 }, 
##     { "nodeid": 2, "leaf": -0.327272743 }
##   ]}
## ]

上面代码块显示的树JSON告诉我们,如果数据小于20180132,树将输出第一个叶子中的值。否则,它将输出第二个叶子中的值。让我们尝试手动使用我们拥有的数据重现这一点,并确认它与我们已经计算出的模型预测匹配。

bst_preds_logodds <- predict(bst, as.matrix(data$dates), outputmargin = TRUE)

# calculate the logodds values using the JSON representation
bst_from_json_logodds <- ifelse(data$dates < node$split_condition,
                                node$children[[1]]$leaf,
                                node$children[[2]]$leaf)

bst_preds_logodds
##  [1] -0.1788614988327026367188 -0.1788614988327026367188
##  [3] -0.1788614988327026367188 -0.1788614988327026367188
##  [5] -0.1788614988327026367188 -0.1788614988327026367188
##  [7] -1.0204199552536010742188 -1.0204199552536010742188
##  [9] -1.0204199552536010742188 -1.0204199552536010742188
## [11] -1.0204199552536010742188 -1.0204199552536010742188
## [13] -1.0204199552536010742188 -1.0204199552536010742188
## [15] -1.0204199552536010742188 -1.0204199552536010742188
## [17] -1.0204199552536010742188 -1.0204199552536010742188
bst_from_json_logodds
##  [1]  0.5142856839999999651880  0.5142856839999999651880
##  [3]  0.5142856839999999651880  0.5142856839999999651880
##  [5]  0.5142856839999999651880  0.5142856839999999651880
##  [7]  0.5142856839999999651880  0.5142856839999999651880
##  [9]  0.5142856839999999651880  0.5142856839999999651880
## [11]  0.5142856839999999651880  0.5142856839999999651880
## [13]  0.5142856839999999651880  0.5142856839999999651880
## [15]  0.5142856839999999651880 -0.3272727429999999770871
## [17] -0.3272727429999999770871 -0.3272727429999999770871
# test that values are equal
bst_preds_logodds == bst_from_json_logodds
##  [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE

没有一个相等。发生了什么?

在此阶段发生了两件事

  • 输入数据未转换为32位浮点数
  • JSON变量未转换为32位浮点数

第一课:所有数据都是32位浮点数

处理导入的JSON时,所有数据都必须转换为32位浮点数

为了解释这一点,让我们重复比较并四舍五入到两位小数

round(bst_preds_logodds, 2) == round(bst_from_json_logodds, 2)
##  [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE

如果我们四舍五入到两位小数,我们会发现只有与数据值20180131相关的元素不一致。如果我们将数据转换为浮点数,它们就一致了

# now convert the dates to floats first
bst_from_json_logodds <- ifelse(fl(data$dates) < node$split_condition,
                                node$children[[1]]$leaf,
                                node$children[[2]]$leaf)

# test that values are equal
round(bst_preds_logodds, 2) == round(bst_from_json_logodds, 2)
##  [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE

有什么教训?如果我们要使用导入的JSON模型,任何数据都必须首先转换为浮点数。在这种情况下,由于“20180131”不能表示为32位浮点数,它会被向上舍入到20180132,如下所示

fl(20180131)
## # A float32 vector: 1
## [1] 20180132

第二课:JSON参数是32位浮点数

所有存储为浮点数的JSON参数都必须转换为浮点数。

现在我们假设我们关心两位小数之外的数字。

# test that values are equal
bst_preds_logodds == bst_from_json_logodds
##  [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE

没有一个完全相等。发生了什么?尽管我们已经将数据转换为32位浮点数,但我们还需要将JSON参数转换为32位浮点数。让我们这样做

# now convert the dates to floats first
bst_from_json_logodds <- ifelse(fl(data$dates) < fl(node$split_condition),
                                as.numeric(fl(node$children[[1]]$leaf)),
                                as.numeric(fl(node$children[[2]]$leaf)))

# test that values are equal
bst_preds_logodds == bst_from_json_logodds
##  [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE

都相等。有什么教训?如果我们要使用导入的JSON模型,任何存储为浮点数的JSON参数也必须首先转换为浮点数。

第三课:使用32位数学运算

始终使用32位数字和运算符

我们能够使对数几率一致,所以现在让我们手动计算对数几率的Sigmoid函数。这应该与xgboost预测一致。

bst_preds <- predict(bst, as.matrix(data$dates))

# calculate the predictions casting doubles to floats
bst_from_json_preds <- ifelse(
  fl(data$dates) < fl(node$split_condition)
  , as.numeric(1 / (1 + exp(-1 * fl(node$children[[1]]$leaf))))
  , as.numeric(1 / (1 + exp(-1 * fl(node$children[[2]]$leaf))))
)

# test that values are equal
bst_preds == bst_from_json_preds
##  [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE

没有一个再次完全相等。这是怎么回事?好吧,由于我们在计算中使用了值1,我们在计算中引入了一个双精度浮点数。因此,所有浮点值都被提升为64位双精度浮点数,并且还使用了指数运算符exp的64位版本。另一方面,xgboost在其sigmoid函数中使用了指数运算符的32位版本。

我们如何解决这个问题?我们必须确保在所有地方都使用正确的数据类型和正确的运算符。如果我们只使用浮点数,我们加载的浮点库将确保应用32位浮点指数运算符。

# calculate the predictions casting doubles to floats
bst_from_json_preds <- ifelse(
  fl(data$dates) < fl(node$split_condition)
  , as.numeric(fl(1) / (fl(1) + exp(fl(-1) * fl(node$children[[1]]$leaf))))
  , as.numeric(fl(1) / (fl(1) + exp(fl(-1) * fl(node$children[[2]]$leaf))))
)

# test that values are equal
bst_preds == bst_from_json_preds
##  [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## [13] FALSE FALSE FALSE FALSE FALSE FALSE

都相等。有什么教训?如果我们要重现xgboost的结果,我们必须确保所有计算都使用32位浮点运算符完成。