基于XGBoost模型预测数据值。
用法
# S3 method for class 'xgboost'
predict(
object,
newdata,
type = "response",
base_margin = NULL,
iteration_range = NULL,
validate_features = TRUE,
...
)参数
- 对象
由函数
xgboost()生成的xgboost类的XGBoost模型对象。请注意,对于由
xgb.train()生成的xgb.Booster类模型,还有一个较低级别的predict.xgb.Booster()方法,它也可以作为xgboost类模型的替代方法,执行更少的验证和后处理。- newdata
用于计算
object中传递的模型预测值的数据。支持的输入类有:数据框(来自R基础包的
data.frame类及其子类,如data.table)。矩阵(来自R基础包的
matrix类)。来自
Matrix包的稀疏矩阵,可以是dgRMatrix类(CSR)或dgCMatrix类(CSC)。来自
Matrix包的稀疏向量,将被解释为包含单个观测值。
对于数据框,如果存在任何分类特征,它们应该是
factor类,并且应该与构建模型时使用的数据的factor列具有相同的水平。任何类型不是factor的列都将被解释为数值型。如果存在命名列,并且模型是使用命名列的数据拟合的,那么默认情况下它们将按名称匹配(参见
validate_features)。- type
要进行的预测类型。支持的选项有:
"response":将输出响应变量尺度的模型预测(例如,二元分类中属于最后一类的概率)。结果将是一个长度与newdata行数匹配的数值向量,或者一个形状为[nrows(newdata), nscores]的数值矩阵(对于每个观测值产生多个分数的模型,如多类分类或多分位数回归)。"raw":将输出未处理的提升分数(例如,对于目标binary:logistic,为对数几率)。输出形状和类型与"response"相同。"class":将输出具有最高预测概率的类别,返回为factor(仅适用于分类目标),其长度与newdata的行数匹配。"leaf":将输出每个观测值在每棵树上的叶节点索引,作为形状为[nrows(newdata), ntrees]的整数矩阵,或者作为具有额外一个或两个维度的整数数组,对于每个树产生多个分数和/或具有多个并行树的模型(例如随机森林),最多为[nrows(newdata), ntrees, nscores, n_parallel_trees]。仅适用于基于树的增强器(不适用于
gblinear)。"contrib":将基于SHAP值,为给定观测值生成对模型分数的影响估计(按特征)。贡献值处于未转换边距的尺度上(例如,对于二元分类,这些值是与基线相比的对数几率偏差)。输出将是一个形状为
[nrows, nfeatures+1]的数值矩阵,其中截距是最后一个特征;如果模型为每个观测值生成多个分数,则输出将是一个形状为[nrows, nscores, nfeatures+1]的数值数组。"interaction":类似于"contrib",但计算每个特征对交互的SHAP值。请注意,此操作在计算和内存方面可能相当昂贵。由于它与特征数量呈平方关系,建议首先选择最重要的特征。
输出将是一个形状为
[nrows, nfeatures+1, nfeatures+1]的数值数组,或者(对于每个观测值产生多个分数的模型)形状为[nrows, nscores, nfeatures+1, nfeatures+1]的数值数组。
- 基本边距
从现有模型进行提升的基线边距(原始分数,独立于模型中的树添加到所有观测值中)。
如果提供,应该是一个向量,其长度等于
newdata中的行数(对于每个观测值生成单个分数的模型),或者是一个矩阵,其行数与newdata中的行数匹配,列数与模型估计的分数数量匹配(例如,多类分类的类别数量)。- iteration_range
用于预测的模型轮次/迭代序列,通过传递一个包含序列开始和结束数字的二维向量来指定(与R的
seq格式相同——即,基于1的索引,并包含两端)。例如,传递
c(1,20)将使用前二十次迭代进行预测,而传递c(1,1)将仅使用第一次迭代进行预测。如果传递
NULL,则如果模型使用了提前停止,则将在最佳迭代处停止,否则将使用所有迭代(轮次)。如果传递"all",无论模型是否提前停止,都将使用所有轮次。
不适用于
gblinear增强器。- validate_features
验证数据中的特征名称是否与列中的特征名称匹配,否则重新排序数据中的特征。
如果传递
FALSE,则假定特征名称和类型相同,并且与训练数据中的顺序相同。请注意,这仅适用于列名,而不适用于分类列中的因子水平。
请注意,此检查可能会为预测增加一些可观的延迟,因此建议在对性能敏感的应用程序中禁用它。
- ...
未使用。
示例
data("ToothGrowth")
y <- ToothGrowth$supp
x <- ToothGrowth[, -2L]
model <- xgboost(x, y, nthreads = 1L, nrounds = 3L, max_depth = 2L)
pred_prob <- predict(model, x[1:5, ], type = "response")
pred_raw <- predict(model, x[1:5, ], type = "raw")
pred_class <- predict(model, x[1:5, ], type = "class")
# Relationships between these
manual_probs <- 1 / (1 + exp(-pred_raw))
manual_class <- ifelse(manual_probs < 0.5, levels(y)[1], levels(y)[2])
# They should match up to numerical precision
round(pred_prob, 6) == round(manual_probs, 6)
pred_class == manual_class