跳到内容

基于XGBoost模型预测数据值。

用法

# S3 method for class 'xgboost'
predict(
  object,
  newdata,
  type = "response",
  base_margin = NULL,
  iteration_range = NULL,
  validate_features = TRUE,
  ...
)

参数

对象

由函数xgboost()生成的xgboost类的XGBoost模型对象。

请注意,对于由xgb.train()生成的xgb.Booster类模型,还有一个较低级别的predict.xgb.Booster()方法,它也可以作为xgboost类模型的替代方法,执行更少的验证和后处理。

newdata

用于计算object中传递的模型预测值的数据。支持的输入类有:

  • 数据框(来自R基础包的data.frame类及其子类,如data.table)。

  • 矩阵(来自R基础包的matrix类)。

  • 来自Matrix包的稀疏矩阵,可以是dgRMatrix类(CSR)或dgCMatrix类(CSC)。

  • 来自Matrix包的稀疏向量,将被解释为包含单个观测值。

对于数据框,如果存在任何分类特征,它们应该是factor类,并且应该与构建模型时使用的数据的factor列具有相同的水平。任何类型不是factor的列都将被解释为数值型。

如果存在命名列,并且模型是使用命名列的数据拟合的,那么默认情况下它们将按名称匹配(参见validate_features)。

type

要进行的预测类型。支持的选项有:

  • "response":将输出响应变量尺度的模型预测(例如,二元分类中属于最后一类的概率)。结果将是一个长度与newdata行数匹配的数值向量,或者一个形状为[nrows(newdata), nscores]的数值矩阵(对于每个观测值产生多个分数的模型,如多类分类或多分位数回归)。

  • "raw":将输出未处理的提升分数(例如,对于目标binary:logistic,为对数几率)。输出形状和类型与"response"相同。

  • "class":将输出具有最高预测概率的类别,返回为factor(仅适用于分类目标),其长度与newdata的行数匹配。

  • "leaf":将输出每个观测值在每棵树上的叶节点索引,作为形状为[nrows(newdata), ntrees]的整数矩阵,或者作为具有额外一个或两个维度的整数数组,对于每个树产生多个分数和/或具有多个并行树的模型(例如随机森林),最多为[nrows(newdata), ntrees, nscores, n_parallel_trees]

    仅适用于基于树的增强器(不适用于gblinear)。

  • "contrib":将基于SHAP值,为给定观测值生成对模型分数的影响估计(按特征)。贡献值处于未转换边距的尺度上(例如,对于二元分类,这些值是与基线相比的对数几率偏差)。

    输出将是一个形状为[nrows, nfeatures+1]的数值矩阵,其中截距是最后一个特征;如果模型为每个观测值生成多个分数,则输出将是一个形状为[nrows, nscores, nfeatures+1]的数值数组。

  • "interaction":类似于"contrib",但计算每个特征对交互的SHAP值。请注意,此操作在计算和内存方面可能相当昂贵。

    由于它与特征数量呈平方关系,建议首先选择最重要的特征。

    输出将是一个形状为[nrows, nfeatures+1, nfeatures+1]的数值数组,或者(对于每个观测值产生多个分数的模型)形状为[nrows, nscores, nfeatures+1, nfeatures+1]的数值数组。

基本边距

从现有模型进行提升的基线边距(原始分数,独立于模型中的树添加到所有观测值中)。

如果提供,应该是一个向量,其长度等于newdata中的行数(对于每个观测值生成单个分数的模型),或者是一个矩阵,其行数与newdata中的行数匹配,列数与模型估计的分数数量匹配(例如,多类分类的类别数量)。

iteration_range

用于预测的模型轮次/迭代序列,通过传递一个包含序列开始和结束数字的二维向量来指定(与R的seq格式相同——即,基于1的索引,并包含两端)。

例如,传递c(1,20)将使用前二十次迭代进行预测,而传递c(1,1)将仅使用第一次迭代进行预测。

如果传递NULL,则如果模型使用了提前停止,则将在最佳迭代处停止,否则将使用所有迭代(轮次)。

如果传递"all",无论模型是否提前停止,都将使用所有轮次。

不适用于gblinear增强器。

validate_features

验证数据中的特征名称是否与列中的特征名称匹配,否则重新排序数据中的特征。

如果传递FALSE,则假定特征名称和类型相同,并且与训练数据中的顺序相同。

请注意,这仅适用于列名,而不适用于分类列中的因子水平。

请注意,此检查可能会为预测增加一些可观的延迟,因此建议在对性能敏感的应用程序中禁用它。

...

未使用。

可以是数值向量(对于一维输出)、数值矩阵(对于二维输出)、数值数组(对于三维及更高维度),或factor(对于类别预测)。有关输出类型和形状的详细信息,请参见参数type的文档。

示例

data("ToothGrowth")
y <- ToothGrowth$supp
x <- ToothGrowth[, -2L]
model <- xgboost(x, y, nthreads = 1L, nrounds = 3L, max_depth = 2L)
pred_prob <- predict(model, x[1:5, ], type = "response")
pred_raw <- predict(model, x[1:5, ], type = "raw")
pred_class <- predict(model, x[1:5, ], type = "class")

# Relationships between these
manual_probs <- 1 / (1 + exp(-pred_raw))
manual_class <- ifelse(manual_probs < 0.5, levels(y)[1], levels(y)[2])

# They should match up to numerical precision
round(pred_prob, 6) == round(manual_probs, 6)
pred_class == manual_class