用于定义回调函数的结构,这些函数可以在模型训练的不同阶段执行(训练前/后,每次提升迭代前/后)。
用法
xgb.Callback(
cb_name = "custom_callback",
env = new.env(),
f_before_training = function(env, model, data, evals, begin_iteration, end_iteration)
NULL,
f_before_iter = function(env, model, data, evals, iteration) NULL,
f_after_iter = function(env, model, data, evals, iteration, iter_feval) NULL,
f_after_training = function(env, model, data, evals, iteration, final_feval,
prev_cb_res) NULL
)
参数
- cb_name
回调的名称。
如果回调产生非 NULL 结果(来自执行
f_after_training
下传递的函数),该结果将作为 R 属性添加到最终的 booster 中(或作为 CV 结果中的命名元素),属性名称由此指定。回调的名称必须是唯一的 - 即不能有两个名称相同的回调。
- env
一个环境对象,将传递给回调中的不同函数。请注意,此环境不会与其他回调共享。
- f_before_training
一个函数,将在训练开始前执行。
如果对此或对其他函数输入传递
NULL
,则不会执行任何函数。如果传递一个函数,它将使用作为非命名参数提供的参数调用,这些参数与每个函数参数默认值中所示的函数签名匹配。
- f_before_iter
一个函数,将在每次提升轮次前执行。
该函数可以通过输出评估为
TRUE
的值来指示训练是否应终止 - 即,如果在此提供的函数在给定轮次的输出为TRUE
,则训练将在当前迭代发生之前停止。返回值为
NULL
将被解释为FALSE
。- f_after_iter
一个函数,将在每次提升轮次后执行。
该函数可以通过输出评估为
TRUE
的值来指示训练是否应终止 - 即,如果在此提供的函数在给定轮次的输出为TRUE
,则训练将在该轮次停止。返回值为
NULL
将被解释为FALSE
。- f_after_training
一个函数,将在训练结束后执行。
该函数可以选择性地输出一些非 NULL 的内容,这些内容将成为 booster 的 R 属性的一部分(假设对
xgb.train()
传递了keep_extra_attributes=TRUE
),其名称为参数cb_name
在xgb.train()
情况下的提供值;或者成为xgb.cv()
结果中命名元素的一部分。
返回值
一个 xgb.Callback
对象,可以传递给 xgb.train()
或 xgb.cv()
。
详情
将传递给提供的函数的参数如下:
env 与参数
env
下传递的环境相同。它可以被函数修改,例如用于跟踪迭代过程中发生的事情或类似目的。
此环境仅由提供给回调的函数使用,并在模型拟合函数终止后不保留(参见参数
f_after_training
)。model 使用
xgb.train()
时的 booster 对象,或使用xgb.cv()
时的折叠 (folds)。对于
xgb.cv()
,折叠是一个列表,其结构如下:dtrain
: 该折叠的训练数据(作为xgb.DMatrix
对象)。bst
: 该折叠的xgb.Booster
对象。evals
: 一个包含两个 DMatrix 的列表,名称分别为train
和test
(test
是该折叠的保留数据)。index
: 该折叠保留数据的索引(base-1 索引),evals
中的test
条目就是从这些索引获取的。
这个对象不应以与训练冲突的方式原地修改(例如,以将轮次数重置为零来覆盖轮次的方式重置训练更新的参数)。
请注意,在回调函数期间分配给 booster 的任何 R 属性之后都不会保留,因为在训练期间 booster 对象变量未被重新赋值。然而,可以通过
xgb.attr()
或xgb.attributes()
设置 booster 的 C 级别属性,这些属性在其余迭代和训练完成后应保持可用。为了跨迭代保留变量,建议改用
env
。data 模型拟合的数据,作为
xgb.DMatrix
对象。请注意,对于
xgb.cv()
,这将是完整数据,而特定折叠的数据可以在model
对象中找到。evals 评估数据,作为参数
evals
传递给xgb.train()
。对于
xgb.cv()
,这将始终是NULL
。begin_iteration 将执行的第一个提升迭代的索引(base-1 索引)。
这通常是 '1',但当使用训练续接时,取决于更新参数,提升轮次将从之前模型结束的地方继续,在这种情况下,这将大于 1。
end_iteration 将执行的最后一个提升迭代的索引(base-1 索引,包括此结束)。
它应与传递给
xgb.train()
或xgb.cv()
的参数nrounds
匹配。请注意,提升可能会在达到最后一个迭代之前中断,例如通过使用早停回调
xgb.cb.early.stop()
。iteration 正在执行的迭代号的索引(第一次迭代将与参数
begin_iteration
相同,然后下一次将加 1,以此类推)。iter_feval 为提供的
evals
计算的评估指标,由目标函数或参数custom_metric
确定。对于
xgb.train()
,这将是一个命名向量,其中evals
中的每个元素对应一个条目,名称的确定方式是 'evals 名称' + '-' + '指标名称' - 例如,如果evals
包含一个名为 "tr" 的条目且指标是 "rmse",这将是一个名称为 "tr-rmse" 的单元素向量。对于
xgb.cv()
,这将是一个二维矩阵,维度为[length(evals), nfolds]
,行名称将遵循与在xgb.train()
中传递的一维向量相同的命名逻辑。请注意,在内部,诸如 xgb.cb.print.evaluation 之类的内置回调通过计算行均值和标准差来汇总此表。
final_feval 执行最后一次提升轮次后的评估结果(与
iter_feval
格式相同,并且与模型拟合期间执行的最后一轮传递给iter_feval
的输入完全相同)。prev_cb_res 在进行训练续接时,如果 booster R 属性中存在任何具有相同名称(由参数
cb_name
提供)的回调之前运行的结果。有时,可能需要将新结果附加到之前的结果中,诸如 xgb.cb.evaluation.log 之类的内置回调会自动执行此操作,将新行附加到之前的表中。
如果没有此类之前的回调结果(从头开始拟合模型而不是更新现有模型时永远不会有),则此值为
NULL
。对于不支持训练续接的
xgb.cv()
,此值始终为NULL
。
以下名称(cb_name
值)保留用于内部回调:
print_evaluation
evaluation_log
reset_parameters
early_stop
save_model
cv_predict
gblinear_history
以下名称保留用于其他非回调属性:
names
class
call
params
niter
nfeatures
folds
当使用内置的早停回调 (xgb.cb.early.stop) 时,该回调将始终在其他回调之前执行,因为它设置了一些其他回调可能也会使用的 booster C 级别属性。否则,执行顺序将与回调传递给模型拟合函数的顺序匹配。
示例
# Example constructing a custom callback that calculates
# squared error on the training data (no separate test set),
# and outputs the per-iteration results.
ssq_callback <- xgb.Callback(
cb_name = "ssq",
f_before_training = function(env, model, data, evals,
begin_iteration, end_iteration) {
# A vector to keep track of a number at each iteration
env$logs <- rep(NA_real_, end_iteration - begin_iteration + 1)
},
f_after_iter = function(env, model, data, evals, iteration, iter_feval) {
# This calculates the sum of squared errors on the training data.
# Note that this can be better done by passing an 'evals' entry,
# but this demonstrates a way in which callbacks can be structured.
pred <- predict(model, data)
err <- pred - getinfo(data, "label")
sq_err <- sum(err^2)
env$logs[iteration] <- sq_err
cat(
sprintf(
"Squared error at iteration %d: %.2f\n",
iteration, sq_err
)
)
# A return value of 'TRUE' here would signal to finalize the training
return(FALSE)
},
f_after_training = function(env, model, data, evals, iteration,
final_feval, prev_cb_res) {
return(env$logs)
}
)
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
dm <- xgb.DMatrix(x, label = y, nthread = 1)
model <- xgb.train(
data = dm,
params = xgb.params(objective = "reg:squarederror", nthread = 1),
nrounds = 5,
callbacks = list(ssq_callback)
)
# Result from 'f_after_iter' will be available as an attribute
attributes(model)$ssq