跳到内容

将 SHAP 值与特征值进行可视化,以了解特征效应。

用法

xgb.plot.shap(
  data,
  shap_contrib = NULL,
  features = NULL,
  top_n = 1,
  model = NULL,
  trees = NULL,
  target_class = NULL,
  approxcontrib = FALSE,
  subsample = NULL,
  n_col = 1,
  col = rgb(0, 0, 1, 0.2),
  pch = ".",
  discrete_n_uniq = 5,
  discrete_jitter = 0.01,
  ylab = "SHAP",
  plot_NA = TRUE,
  col_NA = rgb(0.7, 0, 1, 0.6),
  pch_NA = ".",
  pos_NA = 1.07,
  plot_loess = TRUE,
  col_loess = 2,
  span_loess = 0.5,
  which = c("1d", "2d"),
  plot = TRUE,
  ...
)

参数

数据

用于解释的数据,可以是 matrixdgCMatrixdata.frame

shap_contrib

data 的 SHAP 贡献矩阵。默认值 (NULL) 会根据 modeldata 计算。

features

要绘制的列索引或特征名称的向量。当为 NULL(默认值)时,将通过 xgb.importance() 选择 top_n 个最重要的特征。

top_n

应该选择多少个最重要的特征(<= 100)?对于 SHAP 依赖图,默认为 1;对于 SHAP 摘要图,默认为 10。仅当 features = NULL 时使用。

model

一个 xgb.Booster 模型。仅当 shap_contrib = NULLfeatures = NULL 时需要。

trees

features = NULL 时,传递给 xgb.importance()

target_class

仅与多类别模型相关。默认值 (NULL) 会对所有类别的 SHAP 值取平均。传入一个(基于 0 的)类别索引以仅显示该类别的 SHAP 值。

approxcontrib

shap_contrib = NULL 时,传递给 predict.xgb.Booster()

subsample

随机选择用于绘图的数据点的比例。默认值 (NULL) 将使用多达 10 万个数据点。

n_col

绘图网格中的列数。

col

散点图标记的颜色。

pch

散点图标记。

discrete_n_uniq

将特征视为离散特征的最大唯一特征值数量。

discrete_jitter

添加到离散特征值中的抖动量。

ylab

一维绘图中的 y 轴标签。

plot_NA

是否应绘制缺失值案例的贡献?默认值为 TRUE

col_NA

缺失值贡献标记的颜色。

pch_NA

NA 值的标记类型。

pos_NA

显示 NA 值的 x 位置的相对位置:min(x) + (max(x) - min(x)) * pos_NA

plot_loess

是否应绘制局部加权回归 (loess) 平滑曲线?(默认为 TRUE)。平滑仅针对具有超过 5 个不同值的特征进行。

col_loess

loess 曲线的颜色。

span_loess

stats::loess()span 参数。

which

进行单变量还是双变量绘图。目前,仅实现了“1d”。

plot

是否应绘制图表?(默认为 TRUE)。如果为 FALSE,则仅返回矩阵列表。

...

传递给 graphics::plot() 的其他参数。

除了生成绘图(当 plot = TRUE 时),它还会静默返回包含两个矩阵的列表

  • data:特征值矩阵。

  • shap_contrib:对应的 SHAP 值矩阵。

详细信息

这些散点图表示 SHAP 特征贡献如何依赖于特征值。它们与部分依赖图的相似之处在于,它们也提供了特征值如何影响预测的线索。然而,在部分依赖图中,我们看到模型预测对特征值的边际依赖性,而 SHAP 依赖图显示了每个个体案例中特征对预测的估计贡献。

plot_loess = TRUE 时,特征值四舍五入到三个有效数字,并计算并绘制加权 LOESS,其中权重是每个四舍五入值处的数据点数量。

注意:SHAP 贡献在模型边际尺度上。例如,对于逻辑二项式目标,边际在对数几率尺度上。此外,由于 SHAP 代表“SHapley Additive exPlanation”(模型预测 = 所有特征的 SHAP 贡献之和 + 偏差),因此,根据所使用的目标,将特征的 SHAP 贡献从边际空间转换为预测空间不一定是有意义的。

参考文献

  1. Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions", NIPS Proceedings 2017, https://arxiv.org/abs/1705.07874

  2. Scott M. Lundberg, Su-In Lee, "Consistent feature attribution for tree ensembles", https://arxiv.org/abs/1706.06060

示例


data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")

## Keep the number of threads to 1 for examples
nthread <- 1
data.table::setDTthreads(nthread)
nrounds <- 20

model_binary <- xgboost(
  agaricus.train$data, factor(agaricus.train$label),
  nrounds = nrounds,
  verbosity = 0L,
  learning_rate = 0.1,
  max_depth = 3L,
  subsample = 0.5,
  nthreads = nthread
)

xgb.plot.shap(agaricus.test$data, model = model_binary, features = "odor=none")

contr <- predict(model_binary, agaricus.test$data, type = "contrib")
xgb.plot.shap(agaricus.test$data, contr, model = model_binary, top_n = 12, n_col = 3)

# Summary plot
xgb.ggplot.shap.summary(agaricus.test$data, contr, model = model_binary, top_n = 12)

# Multiclass example - plots for each class separately:
x <- as.matrix(iris[, -5])
set.seed(123)
is.na(x[sample(nrow(x) * 4, 30)]) <- TRUE # introduce some missing values

model_multiclass <- xgboost(
  x, iris$Species,
  nrounds = nrounds,
  verbosity = 0,
  max_depth = 2,
  subsample = 0.5,
  nthreads = nthread
)
nclass <- 3
trees0 <- seq(from = 1, by = nclass, length.out = nrounds)
col <- rgb(0, 0, 1, 0.5)

xgb.plot.shap(
  x,
  model = model_multiclass,
  trees = trees0,
  target_class = 0,
  top_n = 4,
  n_col = 2,
  col = col,
  pch = 16,
  pch_NA = 17
)

xgb.plot.shap(
  x,
  model = model_multiclass,
  trees = trees0 + 1,
  target_class = 1,
  top_n = 4,
  n_col = 2,
  col = col,
  pch = 16,
  pch_NA = 17
)

xgb.plot.shap(
  x,
  model = model_multiclass,
  trees = trees0 + 2,
  target_class = 2,
  top_n = 4,
  n_col = 2,
  col = col,
  pch = 16,
  pch_NA = 17
)

# Summary plot
xgb.ggplot.shap.summary(x, model = model_multiclass, target_class = 0, top_n = 4)