跳至内容

XGBoost 的交叉验证函数。

用法

xgb.cv(
  params = xgb.params(),
  data,
  nrounds,
  nfold,
  prediction = FALSE,
  showsd = TRUE,
  metrics = list(),
  objective = NULL,
  custom_metric = NULL,
  stratified = "auto",
  folds = NULL,
  train_folds = NULL,
  verbose = TRUE,
  print_every_n = 1L,
  early_stopping_rounds = NULL,
  maximize = NULL,
  callbacks = list(),
  ...
)

参数

params

控制模型构建过程的 XGBoost 参数列表。详见在线文档xgb.params() 的文档。

应作为带有命名条目的列表传递。此列表中未指定的参数将使用其默认值。

可以通过函数 xgb.params() 创建一个命名参数列表,该函数接受所有有效参数作为函数参数。

data

一个 xgb.DMatrix 对象,包含目标函数进行模型训练所需的相应字段,例如 label 或 bounds。

请注意,此处仅支持基本的 xgb.DMatrix 类 - 不支持 xgb.QuantileDMatrixxgb.ExtMemDMatrix 等变体。

nrounds

最大提升迭代次数。

nfold

原始数据集被随机分成 nfold 个大小相等的子样本。

prediction

一个逻辑值,指示是否返回每个交叉验证模型的测试折叠预测。此参数会启用 xgb.cb.cv.predict() 回调函数。

showsd

是否显示交叉验证标准差的逻辑值。

metrics

交叉验证中使用的评估指标列表,未指定时,根据目标函数选择评估指标。可能的选项有:

  • error:二分类错误率

  • rmse:均方根误差

  • logloss:负对数似然函数

  • mae:平均绝对误差

  • mape:平均绝对百分比误差

  • auc:曲线下面积

  • aucpr:PR 曲线下面积

  • merror:用于评估多分类的精确匹配误差

objective

自定义目标函数。应接受两个参数:第一个是当前预测结果(根据目标数量/类别可以是数值向量或矩阵),第二个是用于训练的 data DMatrix 对象。

它应该返回一个包含两个元素 gradhess(按此顺序)的列表,根据目标数量/类别,它们可以是数值向量或数值矩阵(与作为第一个参数传递的预测结果维度相同)。

custom_metric

自定义评估函数。与 objective 类似,应接受两个参数,第一个是预测结果,第二个是 data DMatrix。

应返回一个包含两个元素的列表:metric(将为此指标显示的名称,应为字符串/字符)和 value(函数计算的数值,应为数值标量)。

请注意,即使传递了 custom_metric,目标函数也有一个相关的默认指标,将额外进行评估。要禁用内置指标,可以传递参数 disable_default_eval_metric = TRUE

stratified

一个逻辑标志,指示折叠采样是否应按结果标签的值进行分层。对于回归目标中的实值标签,分层将预先通过将标签离散化为最多 5 个桶来完成。

如果传递 "auto",当 params 中的目标是分类目标(来自 XGBoost 的内置目标,不适用于自定义目标)时,将设置为 TRUE,否则设置为 FALSE

data 具有 group 字段时,此参数将被忽略 - 在这种情况下,拆分将基于整个组进行(请注意,这可能导致折叠大小不同)。

自定义目标函数支持此处的值 TRUE

folds

包含预定义交叉验证折叠的列表(每个元素必须是测试折叠索引的向量)。提供折叠时,将忽略 nfoldstratified 参数。

如果 data 具有 group 字段且目标函数需要此字段,则每个折叠(列表元素)还必须有两个属性(可通过 attributes 获取),分别名为 group_testgroup_train,它们应保存通过 setinfo.xgb.DMatrix() 分配给结果 DMatrices 的 group

train_folds

指定用于训练的索引列表。如果为 NULL(默认值),则 folds 中未指定的所有索引将用于训练。

data 具有 group 字段时,不支持此功能。

verbose

如果为 0,xgboost 将保持静默。如果为 1,它将打印性能信息。如果为 2,将打印一些额外信息。请注意,设置 verbose > 0 会自动启用 xgb.cb.print.evaluation(period=1) 回调函数。

print_every_n

当传递 verbose>0 时,评估日志(在 evals 下传递的数据上计算的指标)将按照此处传递的值每 nth 迭代打印一次。无论此 'n' 值如何,第一次和最后一次迭代始终包含在内。

仅在 evals 下传递数据且传递 verbose>0 时有效。该参数传递给 xgb.cb.print.evaluation() 回调函数。

early_stopping_rounds

evals 下传递的评估数据上,性能(由提供的或目标函数默认选择的评估指标衡量)在连续指定的提升轮数内没有改善时,训练将停止。

必须传递 evals 才能使用此功能。设置此参数会添加 xgb.cb.early.stop() 回调函数。

如果为 NULL,则不使用早停。

maximize

如果设置了 fevalearly_stopping_rounds,则也必须设置此参数。当其为 TRUE 时,表示评估分数越大越好。此参数传递给 xgb.cb.early.stop() 回调函数。

callbacks

在提升过程中执行各种任务的回调函数列表。参见 xgb.Callback()。一些回调函数会根据参数值自动创建。用户可以提供现有或自己的回调方法来定制训练过程。

...

未使用。

在之前的 XGBoost 版本中属于此函数的一些参数目前已被弃用或已重命名。如果传递了已弃用或重命名的参数,将抛出警告(默认)并使用其当前的等效参数。如果在启用“严格模式”选项时,此警告将变为错误。

如果传递了既不是当前函数参数也不是已弃用或重命名参数的额外参数,则会根据“严格模式”选项抛出警告或错误。

重要提示: ... 将在未来版本中移除,并且所有当前的弃用警告将变为错误。请仅使用构成函数签名的参数。

一个 'xgb.cv.synchronous' 类的对象,包含以下元素:

  • call:函数调用。

  • params:传递给 xgboost 库的参数。请注意,它不捕获由 xgb.cb.reset.parameters() 回调函数更改的参数。

  • evaluation_log:评估历史记录,存储为 data.table,第一列对应于迭代次数,其余列对应于训练和测试交叉验证集的基于交叉验证的评估均值和标准差。它由 xgb.cb.evaluation.log() 回调函数创建。

  • niter:提升迭代次数。

  • nfeatures:训练数据中的特征数量。

  • folds:交叉验证折叠索引列表 - 可以是通过 folds 参数传递的,也可以是随机生成的。

  • best_iteration:具有最佳评估指标值的迭代次数(仅在使用早停时可用)。

此外还有其他潜在元素,它们是回调函数的结果,例如当传递 prediction = TRUE 时,会有一个包含子元素 pred 的列表 cv_predict,它由 xgb.cb.cv.predict() 回调函数添加(注意,也可以在 callbacks 下手动传递它并设置不同的参数,例如也保存交叉验证期间创建的模型);或者一个列表 early_stop,在使用早停回调函数 (xgb.cb.early.stop()) 时,它将包含诸如 best_iteration 等元素。

详细信息

原始样本被随机划分为 nfold 个大小相等的子样本。

在这 nfold 个子样本中,保留一个子样本作为验证数据用于测试模型,其余的 nfold - 1 个子样本用作训练数据。

然后将交叉验证过程重复 nrounds 次,每次都使用 nfold 个子样本中的一个作为验证数据。

所有观测值都用于训练和验证。

改编自 https://en.wikipedia.org/wiki/Cross-validation_%28statistics%29

示例

data(agaricus.train, package = "xgboost")

dtrain <- with(agaricus.train, xgb.DMatrix(data, label = label, nthread = 2))

cv <- xgb.cv(
  data = dtrain,
  nrounds = 3,
  params = xgb.params(
    nthread = 2,
    max_depth = 3,
    objective = "binary:logistic"
  ),
  nfold = 5,
  metrics = list("rmse","auc")
)
print(cv)
print(cv, verbose = TRUE)