跳到内容

根据 XGBoost 模型预测数据值。

用法

# S3 method for class 'xgboost'
predict(
  object,
  newdata,
  type = "response",
  base_margin = NULL,
  iteration_range = NULL,
  validate_features = TRUE,
  ...
)

参数

object

一个类为 xgboost 的 XGBoost 模型对象,由函数 xgboost() 生成。

请注意,对于由 xgb.train() 生成的类为 xgb.Booster 的模型,还有一个低级的 predict.xgb.Booster() 方法,它也可以作为 xgboost 类模型的替代方法使用,执行的验证和后处理较少。

newdata

用于计算 object 中模型预测值的数据。支持的输入类别包括:

  • 数据框(来自 base R 的类 data.frame 及其子类,如 data.table)。

  • 矩阵(来自 base R 的类 matrix)。

  • 来自 Matrix 包的稀疏矩阵,可以是类 dgRMatrix (CSR) 或 dgCMatrix (CSC)。

  • 来自 Matrix 包的稀疏向量,将被解释为包含一个单独的观察值。

对于数据框,如果存在任何分类特征,它们应该是类 factor,并且必须与构建模型时使用的数据中的 factor 列具有相同的水平。任何类型不是 factor 的列将被解释为数值型。

如果存在命名的列,并且模型是使用命名的列的数据进行拟合的,默认情况下它们将按名称进行匹配(参见 validate_features)。

type

预测的类型。支持的选项包括:

  • "response":将输出模型在响应变量尺度上的预测值(例如,二分类情况下属于最后一个类别的概率)。结果将是与 newdata 行数相匹配的数值向量,或者是一个形状为 [nrows(newdata), nscores] 的数值矩阵(对于每个观察值产生多个得分的目标,如多类别分类或多分位数回归)。

  • "raw":将输出未处理的 boosting 得分(例如,对于目标 binary:logistic,为 log-odds)。输出形状和类型与 "response" 相同。

  • "class":将输出预测概率最高的类别,返回为与 newdata 行数相匹配的 factor 类型(仅适用于分类目标)。

  • "leaf":将输出每个观察值在每棵树上的终端节点索引,为一个形状为 [nrows(newdata), ntrees] 的整数矩阵,或者为一个具有额外一个或两个维度的整数数组,最大形状为 [nrows(newdata), ntrees, nscores, n_parallel_trees],用于每个树产生多个得分和/或具有多个并行树的模型(例如,随机森林)。

    仅适用于基于树的 booster(不包括 gblinear)。

  • "contrib":将基于 SHAP 值,生成每个特征对给定观察值模型得分的贡献估计。贡献值在未转换的 margin 尺度上(例如,对于二分类,值为相对于基线的 log-odds 偏差)。

    输出将是一个形状为 [nrows, nfeatures+1] 的数值矩阵,其中最后一个特征是截距,或者如果模型对每个观察值产生多个得分,则输出形状为 [nrows, nscores, nfeatures+1] 的数值数组。

  • "interaction":类似于 "contrib",但计算每一对特征交互贡献的 SHAP 值。请注意,此操作在计算和内存方面可能相当昂贵。

    由于它与特征数量呈平方关系,建议首先进行最重要的特征选择。

    输出将是一个形状为 [nrows, nfeatures+1, nfeatures+1] 的数值数组,或者(对于每个观察值产生多个得分的目标)形状为 [nrows, nscores, nfeatures+1, nfeatures+1]

base_margin

用于从现有模型进行 boosting 的基础 margin(原始得分,独立于模型中的树添加到所有观察值)。

如果提供,应该是一个长度等于 newdata 行数的向量(对于每个观察值产生单个得分的目标),或者是一个行数与 newdata 行数匹配且列数与模型估计的得分数匹配的矩阵(例如,多类别分类的类别数)。

iteration_range

用于预测的模型回合/迭代序列,通过传递一个包含序列开始和结束数字的二维向量来指定(格式与 R 的 seq 相同 - 即,基于 1 的索引,并包含两端)。

例如,传递 c(1,20) 将使用前二十次迭代进行预测,而传递 c(1,1) 将仅使用第一次迭代。

如果传递 NULL,则如果模型使用了早停,将停止在最佳迭代处;否则,将使用所有迭代(回合)。

如果传递 "all",无论模型是否使用了早停,都将使用所有回合。

不适用于 gblinear booster。

validate_features

验证数据中的特征名称是否与列中的特征名称匹配,如果不匹配则重新排序数据中的特征。

如果传递 FALSE,则假定特征名称和类型与训练数据中的相同且顺序一致。

请注意,这仅适用于列名,而不适用于分类列中的因子水平。

请注意,此检查可能会为预测增加相当大的延迟,因此建议在对性能敏感的应用中禁用它。

...

未使用。

返回值

一个数值向量(对于一维输出)、数值矩阵(对于二维输出)、数值数组(对于三维及更高维度),或 factor(对于类别预测)。有关输出类型和形状的详细信息,请参见参数 type 的文档。

示例

data("ToothGrowth")
y <- ToothGrowth$supp
x <- ToothGrowth[, -2L]
model <- xgboost(x, y, nthreads = 1L, nrounds = 3L, max_depth = 2L)
pred_prob <- predict(model, x[1:5, ], type = "response")
pred_raw <- predict(model, x[1:5, ], type = "raw")
pred_class <- predict(model, x[1:5, ], type = "class")

# Relationships between these
manual_probs <- 1 / (1 + exp(-pred_raw))
manual_class <- ifelse(manual_probs < 0.5, levels(y)[1], levels(y)[2])

# They should match up to numerical precision
round(pred_prob, 6) == round(manual_probs, 6)
pred_class == manual_class