跳至内容

可视化 SHAP 值与特征值的关系,以了解特征效应。

用法

xgb.plot.shap(
  data,
  shap_contrib = NULL,
  features = NULL,
  top_n = 1,
  model = NULL,
  trees = NULL,
  target_class = NULL,
  approxcontrib = FALSE,
  subsample = NULL,
  n_col = 1,
  col = rgb(0, 0, 1, 0.2),
  pch = ".",
  discrete_n_uniq = 5,
  discrete_jitter = 0.01,
  ylab = "SHAP",
  plot_NA = TRUE,
  col_NA = rgb(0.7, 0, 1, 0.6),
  pch_NA = ".",
  pos_NA = 1.07,
  plot_loess = TRUE,
  col_loess = 2,
  span_loess = 0.5,
  which = c("1d", "2d"),
  plot = TRUE,
  ...
)

参数

data

待解释的数据,可以是 matrix, dgCMatrixdata.frame

shap_contrib

data 的 SHAP 贡献矩阵。默认值 (NULL) 会从 modeldata 计算得出。

features

要绘制的列索引或特征名称向量。当为 NULL (默认) 时,将通过 xgb.importance() 选择 top_n 个最重要的特征。

top_n

应选择多少个最重要的特征 (<= 100)? 默认 SHAP 依赖图为 1 个,SHAP 摘要图为 10 个。仅在 features = NULL 时使用。

model

一个 xgb.Booster 模型。仅在 shap_contrib = NULLfeatures = NULL 时需要。

trees

features = NULL 时,此参数传递给 xgb.importance()

target_class

仅与多分类模型相关。默认值 (NULL) 会将 SHAP 值在所有类别上平均。传递一个 (基于 0 的) 类别索引以仅显示该类别的 SHAP 值。

approxcontrib

shap_contrib = NULL 时,此参数传递给 predict.xgb.Booster()

subsample

随机选取用于绘图的数据点比例。默认值 (NULL) 将最多使用 10 万个数据点。

n_col

绘图网格中的列数。

col

散点图标记的颜色。

pch

散点图标记样式。

discrete_n_uniq

将特征视为离散型特征的最大唯一特征值数量。

discrete_jitter

添加到离散特征值中的抖动量。

ylab

一维图中的 y 轴标签。

plot_NA

是否绘制包含缺失值情况的贡献?默认为 TRUE

col_NA

缺失值贡献标记的颜色。

pch_NA

NA 值的标记类型。

pos_NA

NA 值显示位置的 x 坐标相对位置:min(x) + (max(x) - min(x)) * pos_NA

plot_loess

是否绘制 LOESS 平滑曲线?(默认为 TRUE)。仅对具有超过 5 个不同值的特征进行平滑处理。

col_loess

LOESS 曲线的颜色。

span_loess

stats::loess() 函数的 span 参数。

which

是进行单变量还是双变量绘图。目前仅实现了“1d”。

plot

是否绘制图表?(默认为 TRUE)。如果为 FALSE,则仅返回一个矩阵列表。

...

传递给 graphics::plot() 的其他参数。

返回值

除了生成图表 (当 plot = TRUE 时) 外,它还会默默返回一个包含两个矩阵的列表:

  • data:特征值矩阵。

  • shap_contrib:相应的 SHAP 值矩阵。

详情

这些散点图显示了 SHAP 特征贡献如何依赖于特征值。它与部分依赖图(partial dependence plots)的相似之处在于,它们都能提示特征值如何影响预测。然而,在部分依赖图中,我们看到的是模型预测对特征值的边际依赖性,而 SHAP 依赖图则显示了特征对每个个体情况预测的估计贡献。

plot_loess = TRUE 时,特征值将四舍五入到三个有效数字,并计算和绘制加权 LOESS,其中权重是每个四舍五入值对应的数据点数量。

注意:SHAP 贡献位于模型边际的尺度上。例如,对于逻辑二项式目标,边际在 log-odds 尺度上。此外,由于 SHAP 代表“SHapley Additive exPlanation”(Shapley 加性解释,即模型预测 = 所有特征的 SHAP 贡献之和 + 偏差),根据使用的目标函数,将一个特征的 SHAP 贡献从边际空间转换到预测空间不一定有意义。

参考文献

  1. Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions", NIPS Proceedings 2017, https://arxiv.org/abs/1705.07874

  2. Scott M. Lundberg, Su-In Lee, "Consistent feature attribution for tree ensembles", https://arxiv.org/abs/1706.06060

示例


data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")

## Keep the number of threads to 1 for examples
nthread <- 1
data.table::setDTthreads(nthread)
nrounds <- 20

model_binary <- xgboost(
  agaricus.train$data, factor(agaricus.train$label),
  nrounds = nrounds,
  verbosity = 0L,
  learning_rate = 0.1,
  max_depth = 3L,
  subsample = 0.5,
  nthreads = nthread
)

xgb.plot.shap(agaricus.test$data, model = model_binary, features = "odor=none")

contr <- predict(model_binary, agaricus.test$data, type = "contrib")
xgb.plot.shap(agaricus.test$data, contr, model = model_binary, top_n = 12, n_col = 3)

# Summary plot
xgb.ggplot.shap.summary(agaricus.test$data, contr, model = model_binary, top_n = 12)

# Multiclass example - plots for each class separately:
x <- as.matrix(iris[, -5])
set.seed(123)
is.na(x[sample(nrow(x) * 4, 30)]) <- TRUE # introduce some missing values

model_multiclass <- xgboost(
  x, iris$Species,
  nrounds = nrounds,
  verbosity = 0,
  max_depth = 2,
  subsample = 0.5,
  nthreads = nthread
)
nclass <- 3
trees0 <- seq(from = 1, by = nclass, length.out = nrounds)
col <- rgb(0, 0, 1, 0.5)

xgb.plot.shap(
  x,
  model = model_multiclass,
  trees = trees0,
  target_class = 0,
  top_n = 4,
  n_col = 2,
  col = col,
  pch = 16,
  pch_NA = 17
)

xgb.plot.shap(
  x,
  model = model_multiclass,
  trees = trees0 + 1,
  target_class = 1,
  top_n = 4,
  n_col = 2,
  col = col,
  pch = 16,
  pch_NA = 17
)

xgb.plot.shap(
  x,
  model = model_multiclass,
  trees = trees0 + 2,
  target_class = 2,
  top_n = 4,
  n_col = 2,
  col = col,
  pch = 16,
  pch_NA = 17
)

# Summary plot
xgb.ggplot.shap.summary(x, model = model_multiclass, target_class = 0, top_n = 4)