使用加速失效时间进行生存分析
什么是生存分析?
生存分析(回归)对事件发生的时间进行建模。生存分析是一种特殊的回归,与传统的回归任务不同,其特点如下:
标签始终为正,因为事件发生前的等待时间不能为负。
标签可能无法完全得知,或称为截尾,因为“测量时间需要时间”。
第二点至关重要,我们应该对此进行更深入的探讨。正如您可能从名称中猜到的那样,生存分析最早的应用之一是模拟给定人群的死亡率。我们以NCCTG 肺癌数据集为例。前 8 列代表特征,最后一列“死亡时间”代表标签。
机构 |
年龄 |
性别 |
ph.ecog |
ph.karno |
pat.karno |
餐食热量 |
体重减轻 |
死亡时间(天) |
---|---|---|---|---|---|---|---|---|
3 |
74 |
1 |
1 |
90 |
100 |
1175 |
不适用 |
306 |
3 |
68 |
1 |
0 |
90 |
90 |
1225 |
15 |
455 |
3 |
56 |
1 |
0 |
90 |
90 |
不适用 |
15 |
\([1010, +\infty)\) |
5 |
57 |
1 |
1 |
90 |
60 |
1150 |
11 |
210 |
1 |
60 |
1 |
0 |
100 |
90 |
不适用 |
0 |
883 |
12 |
74 |
1 |
1 |
50 |
80 |
513 |
0 |
\([1022, +\infty)\) |
7 |
68 |
2 |
2 |
70 |
60 |
384 |
10 |
310 |
仔细查看第三位患者的标签。他的标签是一个范围,而不是一个单一的数字。第三位患者的标签被称为截尾,因为由于某种原因,实验人员无法获得该标签的完整测量值。一种可能的情况是:患者在第 1010 天存活下来,并在第 1011 天离开诊所,因此他的死亡未被直接观察到。另一种可能性是:实验在他死亡被观察到之前就提前结束了(因为你不能永远进行实验)。无论如何,他的标签是 \([1010, +\infty)\),这意味着他的死亡时间可以是任何大于 1010 的数字,例如 2000、3000 或 10000。
有四种截尾类型
无截尾:标签未截尾,以单个数字给出。
右截尾:标签形式为 \([a, +\infty)\),其中 \(a\) 是下限。
左截尾:标签形式为 \([0, b]\),其中 \(b\) 是上限。
区间截尾:标签形式为 \([a, b]\),其中 \(a\) 和 \(b\) 分别是下限和上限。
右截尾是最常用的。
加速失效时间模型
加速失效时间(AFT)模型是生存分析中最常用的模型之一。该模型形式如下:
其中
\(\mathbf{x}\) 是 \(\mathbb{R}^d\) 中的一个向量,表示特征。
\(\mathbf{w}\) 是一个包含 \(d\) 个系数的向量,每个系数对应一个特征。
\(\langle \cdot, \cdot \rangle\) 是 \(\mathbb{R}^d\) 中的常用点积。
\(\ln{(\cdot)}\) 是自然对数。
\(Y\) 和 \(Z\) 是随机变量。
\(Y\) 是输出标签。
\(Z\) 是已知概率分布的随机变量。常见的选择是正态分布、逻辑分布和极值分布。直观地说,\(Z\) 代表将预测 \(\langle \mathbf{w}, \mathbf{x} \rangle\) 从真实对数标签 \(\ln{Y}\) 拉开的“噪声”。
\(\sigma\) 是一个参数,用于缩放 \(Z\) 的大小。
请注意,此模型是线性回归模型 \(Y = \langle \mathbf{w}, \mathbf{x} \rangle\) 的广义形式。为了使 AFT 与梯度提升协同工作,我们对模型进行如下修改:
其中 \(\mathcal{T}(\mathbf{x})\) 表示给定输入 \(\mathbf{x}\) 时,决策树集成的输出。由于 \(Z\) 是一个随机变量,我们为表达式 \(\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z\) 定义了一个似然。因此,XGBoost 的目标是通过拟合一个好的树集成 \(\mathcal{T}(\mathbf{x})\) 来最大化(对数)似然。
如何使用
第一步是将标签表示为范围形式,以便每个数据点都与其相关联两个数字,即标签的下限和上限。对于未截尾的标签,使用 \([a, a]\) 形式的退化区间。
截尾类型 |
区间形式 |
下限有限? |
上限有限? |
---|---|---|---|
无截尾 |
\([a, a]\) |
✔ |
✔ |
右截尾 |
\([a, +\infty)\) |
✔ |
✘ |
左截尾 |
\([0, b]\) |
✔ |
✔ |
区间截尾 |
\([a, b]\) |
✔ |
✔ |
将下限数字收集到一个数组中(我们称之为 y_lower_bound
),将上限数字收集到另一个数组中(称之为 y_upper_bound
)。通过调用 xgboost.DMatrix.set_float_info()
将范围标签与数据矩阵对象关联起来。
import numpy as np
import xgboost as xgb
# 4-by-2 Data matrix
X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]])
dtrain = xgb.DMatrix(X)
# Associate ranged labels with the data matrix.
# This example shows each kind of censored labels.
# uncensored right left interval
y_lower_bound = np.array([ 2.0, 3.0, 0.0, 4.0])
y_upper_bound = np.array([ 2.0, +np.inf, 4.0, 5.0])
dtrain.set_float_info('label_lower_bound', y_lower_bound)
dtrain.set_float_info('label_upper_bound', y_upper_bound)
library(xgboost)
# 4-by-2 Data matrix
X <- matrix(c(1., -1., -1., 1., 0., 1., 1., 0.),
nrow=4, ncol=2, byrow=TRUE)
dtrain <- xgb.DMatrix(X)
# Associate ranged labels with the data matrix.
# This example shows each kind of censored labels.
# uncensored right left interval
y_lower_bound <- c( 2., 3., 0., 4.)
y_upper_bound <- c( 2., +Inf, 4., 5.)
setinfo(dtrain, 'label_lower_bound', y_lower_bound)
setinfo(dtrain, 'label_upper_bound', y_upper_bound)
现在我们准备调用训练 API
params = {'objective': 'survival:aft',
'eval_metric': 'aft-nloglik',
'aft_loss_distribution': 'normal',
'aft_loss_distribution_scale': 1.20,
'tree_method': 'hist', 'learning_rate': 0.05, 'max_depth': 2}
bst = xgb.train(params, dtrain, num_boost_round=5,
evals=[(dtrain, 'train')])
params <- list(objective='survival:aft',
eval_metric='aft-nloglik',
aft_loss_distribution='normal',
aft_loss_distribution_scale=1.20,
tree_method='hist',
learning_rate=0.05,
max_depth=2)
watchlist <- list(train = dtrain)
bst <- xgb.train(params, dtrain, nrounds=5, watchlist)
我们将 objective
参数设置为 survival:aft
,将 eval_metric
设置为 aft-nloglik
,以便最大化 AFT 模型的对数似然。(XGBoost 实际上会最小化负对数似然,因此得名 aft-nloglik
。)
参数 aft_loss_distribution
对应于 AFT 模型中 \(Z\) 项的分布,aft_loss_distribution_scale
对应于比例因子 \(\sigma\)。
目前,您可以为 aft_loss_distribution
选择三种概率分布:
|
概率密度函数(PDF) |
---|---|
|
\(\dfrac{\exp{(-z^2/2)}}{\sqrt{2\pi}}\) |
|
\(\dfrac{e^z}{(1+e^z)^2}\) |
|
\(e^z e^{-\exp{z}}\) |
请注意,目前尚无法使用 scikit-learn 接口(例如 xgboost.XGBRegressor
)设置范围标签。目前,您应该使用 xgboost.train
和 xgboost.DMatrix
。有关 Python 示例的集合,请参阅生存分析演练