自定义目标函数的高级用法
目录
概述
XGBoost 允许根据用户为所需目标函数提供的梯度和 Hessian 来优化自定义的用户定义函数。
为了使自定义目标按预期工作
要优化的函数必须是平滑的且二次可微。
该函数必须对行/观测值具有可加性,例如具有 i.i.d. 假设的似然函数。
函数的得分范围必须是无界的(即,它不应仅适用于正数,例如)。
该函数必须是凸的。请注意,如果 Hessian 具有负值,它们将被截断,这很可能导致模型无法很好地拟合该函数。
对于多输出目标,不同目标之间不应存在依赖关系(即,对于每一行,Hessian 应为对角线)。
尽管如此,其中一些限制可以通过放弃函数的真实 Hessian,转而使用其他具有更好性质的近似值来解决——当不使用函数的真实 Hessian 时,收敛可能会更慢,但许多理论保证仍然成立并产生可用的模型。例如,XGBoost 的多项式逻辑回归内部实现使用具有对角线结构的 Hessian 的上限,而不是数据中每一行的完整方阵形式的真实 Hessian。
本教程通过展示如何解决由浓度参数化的狄利克雷回归,为不完全符合上述标准的用例提供了一些建议。
狄利克雷回归模型给 XGBoost 带来了一些挑战
浓度参数必须为正。实现这一目标的一种简单方法是对原始无界值应用“exp”变换,但在这种情况下,目标变为非凸。此外,请注意,与 GLM 模型中使用的典型分布不同,此函数不属于指数族。
Hessian 在目标之间存在依赖关系——也就是说,对于具有“k”个参数的狄利克雷分布,每一行将具有维度为
[k, k]
的完整 Hessian 矩阵。这种模型的最佳截距将涉及一个值向量,而不是每个目标都相同的值。
为了将这种类型的模型用作自定义目标
可以使用期望 Hessian(又称 Fisher 信息矩阵或期望信息)代替真实 Hessian。即使真实 Hessian 不为正半定,期望 Hessian 对于加性似然函数也始终为正半定。
可以使用具有对角线结构的期望 Hessian 的上限,这样在此对角线上限下的二阶近似将始终产生大于或等于非对角线期望 Hessian 下的函数值。
由于 XGBoost 用于截距的
base_score
参数仅限于标量,因此可以使用base_margin
功能,但请注意,使用它需要更多的精力。
狄利克雷回归公式
狄利克雷分布是 Beta 分布到多个维度的一般化。它模拟值总和为 1 的比例数据,通常用作复合模型(例如狄利克雷-多项式)的一部分或贝叶斯模型中的先验,但它也可以单独用于比例数据。
对于给定观测值 y
和给定预测 x
,其似然性如下所示
其中
在这种情况下,我们希望优化按行求和的负对数似然。由此产生的函数、梯度和 Hessian 可以按如下方式实现
import numpy as np
from scipy.special import loggamma, psi as digamma, polygamma
trigamma = lambda x: polygamma(1, x)
def dirichlet_fun(pred: np.ndarray, Y: np.ndarray) -> float:
epred = np.exp(pred)
sum_epred = np.sum(epred, axis=1, keepdims=True)
return (
loggamma(epred).sum()
- loggamma(sum_epred).sum()
- np.sum(np.log(Y) * (epred - 1))
)
def dirichlet_grad(pred: np.ndarray, Y: np.ndarray) -> np.ndarray:
epred = np.exp(pred)
return epred * (
digamma(epred)
- digamma(np.sum(epred, axis=1, keepdims=True))
- np.log(Y)
)
def dirichlet_hess(pred: np.ndarray, Y: np.ndarray) -> np.ndarray:
epred = np.exp(pred)
grad = dirichlet_grad(pred, Y)
k = Y.shape[1]
H = np.empty((pred.shape[0], k, k))
for row in range(pred.shape[0]):
H[row, :, :] = (
- trigamma(epred[row].sum()) * np.outer(epred[row], epred[row])
+ np.diag(grad[row] + trigamma(epred[row]) * epred[row] ** 2)
)
return H
softmax <- function(x) {
max.x <- max(x)
e <- exp(x - max.x)
return(e / sum(e))
}
dirichlet.fun <- function(pred, y) {
epred <- exp(pred)
sum_epred <- rowSums(epred)
return(
sum(lgamma(epred))
- sum(lgamma(sum_epred))
- sum(log(y) * (epred - 1))
)
}
dirichlet.grad <- function(pred, y) {
epred <- exp(pred)
return(
epred * (
digamma(epred)
- digamma(rowSums(epred))
- log(y)
)
)
}
dirichlet.hess <- function(pred, y) {
epred <- exp(pred)
grad <- dirichlet.grad(pred, y)
k <- ncol(y)
H <- array(dim = c(nrow(y), k, k))
for (row in seq_len(nrow(y))) {
H[row, , ] <- (
- trigamma(sum(epred[row,])) * tcrossprod(epred[row,])
+ diag(grad[row,] + trigamma(epred[row,]) * epred[row,]^2)
)
}
return(H)
}
请自行验证实现是否正确
from math import isclose
from scipy import stats
from scipy.optimize import check_grad
from scipy.special import softmax
def gen_random_dirichlet(rng: np.random.Generator, m: int, k: int):
alpha = np.exp(rng.standard_normal(size=k))
return rng.dirichlet(alpha, size=m)
def test_dirichlet_fun_grad_hess():
k = 3
m = 10
rng = np.random.default_rng(seed=123)
Y = gen_random_dirichlet(rng, m, k)
x0 = rng.standard_normal(size=k)
for row in range(Y.shape[0]):
fun_row = dirichlet_fun(x0.reshape((1,-1)), Y[[row]])
ref_logpdf = stats.dirichlet.logpdf(
Y[row] / Y[row].sum(), # <- avoid roundoff error
np.exp(x0),
)
assert isclose(fun_row, -ref_logpdf)
gdiff = check_grad(
lambda pred: dirichlet_fun(pred.reshape((1,-1)), Y[[row]]),
lambda pred: dirichlet_grad(pred.reshape((1,-1)), Y[[row]]),
x0
)
assert gdiff <= 1e-6
H_numeric = np.empty((k,k))
eps = 1e-7
for ii in range(k):
x0_plus_eps = x0.reshape((1,-1)).copy()
x0_plus_eps[0,ii] += eps
for jj in range(k):
H_numeric[ii, jj] = (
dirichlet_grad(x0_plus_eps, Y[[row]])[0][jj]
- dirichlet_grad(x0.reshape((1,-1)), Y[[row]])[0][jj]
) / eps
H = dirichlet_hess(x0.reshape((1,-1)), Y[[row]])[0]
np.testing.assert_almost_equal(H, H_numeric, decimal=6)
test_dirichlet_fun_grad_hess()
library(DirichletReg)
library(testthat)
test_that("dirichlet formulae", {
k <- 3L
m <- 10L
set.seed(123)
alpha <- exp(rnorm(k))
y <- rdirichlet(m, alpha)
x0 <- rnorm(k)
for (row in seq_len(m)) {
logpdf <- dirichlet.fun(matrix(x0, nrow=1), y[row,,drop=F])
ref_logpdf <- ddirichlet(y[row,,drop=F], exp(x0), log = T)
expect_equal(logpdf, -ref_logpdf)
eps <- 1e-7
grad_num <- numeric(k)
for (col in seq_len(k)) {
xplus <- x0
xplus[col] <- x0[col] + eps
grad_num[col] <- (
dirichlet.fun(matrix(xplus, nrow=1), y[row,,drop=F])
- dirichlet.fun(matrix(x0, nrow=1), y[row,,drop=F])
) / eps
}
grad <- dirichlet.grad(matrix(x0, nrow=1), y[row,,drop=F])
expect_equal(grad |> as.vector(), grad_num, tolerance=1e-6)
H_numeric <- array(dim=c(k, k))
for (ii in seq_len(k)) {
xplus <- x0
xplus[ii] <- x0[ii] + eps
for (jj in seq_len(k)) {
H_numeric[ii, jj] <- (
dirichlet.grad(matrix(xplus, nrow=1), y[row,,drop=F])[1, jj]
- grad[1L, jj]
) / eps
}
}
H <- dirichlet.hess(matrix(xplus, nrow=1), y[row,,drop=F])
expect_equal(H[1,,], H_numeric, tolerance=1e-6)
}
})
狄利克雷回归作为目标函数
如前所述,此函数的 Hessian 对于 XGBoost 来说是有问题的:它可能具有负行列式,甚至对角线上可能具有负值,这对于优化方法来说是有问题的——在 XGBoost 中,这些值将被裁剪,并且由此产生的模型最终可能无法产生合理的预测。
一个潜在的解决方法是使用期望 Hessian——即,如果响应变量根据预测分布,则梯度的期望外积。有关更多信息,请参阅维基百科文章
https://en.wikipedia.org/wiki/Fisher_information
通常,对于指数族中的目标函数,这很容易从链接函数的梯度和概率分布的方差中获得,但对于其他一般函数,它可能涉及其他类型的计算(例如,狄利克雷的协方差和对数协方差)。
然而,它仍然产生与 Hessian 非常相似的形式。从这里的差异中也可以看出,在最佳点(梯度为零)处,狄利克雷的期望 Hessian 和真实 Hessian 将匹配,这对于优化来说是一个很好的特性(即,在驻点处,Hessian 将为正,这意味着它将是最小值而不是最大值或鞍点)。
def dirichlet_expected_hess(pred: np.ndarray) -> np.ndarray:
epred = np.exp(pred)
k = pred.shape[1]
Ehess = np.empty((pred.shape[0], k, k))
for row in range(pred.shape[0]):
Ehess[row, :, :] = (
- trigamma(epred[row].sum()) * np.outer(epred[row], epred[row])
+ np.diag(trigamma(epred[row]) * epred[row] ** 2)
)
return Ehess
def test_dirichlet_expected_hess():
k = 3
rng = np.random.default_rng(seed=123)
x0 = rng.standard_normal(size=k)
y_sample = rng.dirichlet(np.exp(x0), size=5_000_000)
x_broadcast = np.broadcast_to(x0, (y_sample.shape[0], k))
g_sample = dirichlet_grad(x_broadcast, y_sample)
ref = (g_sample.T @ g_sample) / y_sample.shape[0]
Ehess = dirichlet_expected_hess(x0.reshape((1,-1)))[0]
np.testing.assert_almost_equal(Ehess, ref, decimal=2)
test_dirichlet_expected_hess()
dirichlet.expected.hess <- function(pred) {
epred <- exp(pred)
k <- ncol(pred)
H <- array(dim = c(nrow(pred), k, k))
for (row in seq_len(nrow(pred))) {
H[row, , ] <- (
- trigamma(sum(epred[row,])) * tcrossprod(epred[row,])
+ diag(trigamma(epred[row,]) * epred[row,]^2)
)
}
return(H)
}
test_that("expected hess", {
k <- 3L
set.seed(123)
x0 <- rnorm(k)
alpha <- exp(x0)
n.samples <- 5e6
y.samples <- rdirichlet(n.samples, alpha)
x.broadcast <- rep(x0, n.samples) |> matrix(ncol=k, byrow=T)
grad.samples <- dirichlet.grad(x.broadcast, y.samples)
ref <- crossprod(grad.samples) / n.samples
Ehess <- dirichlet.expected.hess(matrix(x0, nrow=1))
expect_equal(Ehess[1,,], ref, tolerance=1e-2)
})
但请注意,这仍然不适用于 XGBoost,因为期望 Hessian,就像真实 Hessian 一样,形状为 [nrows, k, k]
,而 XGBoost 需要形状为 [nrows, k]
的东西。
可以使用每行的期望 Hessian 的对角线,但可以做得更好:可以改用具有对角线结构的上限,因为它应该带来更好的收敛特性,就像其他基于 Hessian 的优化方法一样。
在没有明显方法获得上限的情况下,这里的一种可能性是直接基于对角占优矩阵的定义数值构建这样的上限
https://en.wikipedia.org/wiki/Diagonally_dominant_matrix
也就是说:取数据每行的期望 Hessian 的绝对值,并按该数据行中 [k, k]
形状的 Hessian 的行求和
def dirichlet_diag_upper_bound_expected_hess(
pred: np.ndarray, Y: np.ndarray
) -> np.ndarray:
Ehess = dirichlet_expected_hess(pred)
diag_bound_Ehess = np.empty((pred.shape[0], Y.shape[1]))
for row in range(pred.shape[0]):
diag_bound_Ehess[row, :] = np.abs(Ehess[row, :, :]).sum(axis=1)
return diag_bound_Ehess
dirichlet.diag.upper.bound.expected.hess <- function(pred, y) {
Ehess <- dirichlet.expected.hess(pred)
diag.bound.Ehess <- array(dim=dim(pred))
for (row in seq_len(nrow(pred))) {
diag.bound.Ehess[row,] <- abs(Ehess[row,,]) |> rowSums()
}
return(diag.bound.Ehess)
}
(注意:可以通过不计算完整的矩阵来更有效地进行计算,在 R 中,通过使行成为最后一个维度并在之后转置来更有效地进行计算)
有了所有这些部分,现在可以将此模型构建为 XGBoost 自定义目标所需的格式
import xgboost as xgb
from typing import Tuple
def dirichlet_xgb_objective(
pred: np.ndarray, dtrain: xgb.DMatrix
) -> Tuple[np.ndarray, np.ndarray]:
Y = dtrain.get_label().reshape(pred.shape)
return (
dirichlet_grad(pred, Y),
dirichlet_diag_upper_bound_expected_hess(pred, Y),
)
library(xgboost)
dirichlet.xgb.objective <- function(pred, dtrain) {
y <- getinfo(dtrain, "label")
return(
list(
grad = dirichlet.grad(pred, y),
hess = dirichlet.diag.upper.bound.expected.hess(pred, y)
)
)
}
以及基于狄利克雷对数似然的评估指标监控
def dirichlet_eval_metric(
pred: np.ndarray, dtrain: xgb.DMatrix
) -> Tuple[str, float]:
Y = dtrain.get_label().reshape(pred.shape)
return "dirichlet_ll", dirichlet_fun(pred, Y)
dirichlet.eval.metric <- function(pred, dtrain) {
y <- getinfo(dtrain, "label")
ll <- dirichlet.fun(pred, y)
return(
list(
metric = "dirichlet_ll",
value = ll
)
)
}
实际例子
R 包 DirichletReg
是比例数据测试数据集的一个很好的来源
https://cran.r-project.cn/package=DirichletReg
对于本例,我们现在将使用 Arctic Lake 数据集 (Aitchison, J. (2003). The Statistical Analysis of Compositional Data. The Blackburn Press, Caldwell, NJ.),取自 DirichletReg
R 包,它包含 39 行数据,其中一个预测变量“深度”和表示这个北极湖测量沉积物成分(沙子、淤泥、粘土)的三值响应变量。
数据
# depth
X = np.array([
10.4,11.7,12.8,13,15.7,16.3,18,18.7,20.7,22.1,
22.4,24.4,25.8,32.5,33.6,36.8,37.8,36.9,42.2,47,
47.1,48.4,49.4,49.5,59.2,60.1,61.7,62.4,69.3,73.6,
74.4,78.5,82.9,87.7,88.1,90.4,90.6,97.7,103.7,
]).reshape((-1,1))
# sand, silt, clay
Y = np.array([
[0.775,0.195,0.03], [0.719,0.249,0.032], [0.507,0.361,0.132],
[0.522,0.409,0.066], [0.7,0.265,0.035], [0.665,0.322,0.013],
[0.431,0.553,0.016], [0.534,0.368,0.098], [0.155,0.544,0.301],
[0.317,0.415,0.268], [0.657,0.278,0.065], [0.704,0.29,0.006],
[0.174,0.536,0.29], [0.106,0.698,0.196], [0.382,0.431,0.187],
[0.108,0.527,0.365], [0.184,0.507,0.309], [0.046,0.474,0.48],
[0.156,0.504,0.34], [0.319,0.451,0.23], [0.095,0.535,0.37],
[0.171,0.48,0.349], [0.105,0.554,0.341], [0.048,0.547,0.41],
[0.026,0.452,0.522], [0.114,0.527,0.359], [0.067,0.469,0.464],
[0.069,0.497,0.434], [0.04,0.449,0.511], [0.074,0.516,0.409],
[0.048,0.495,0.457], [0.045,0.485,0.47], [0.066,0.521,0.413],
[0.067,0.473,0.459], [0.074,0.456,0.469], [0.06,0.489,0.451],
[0.063,0.538,0.399], [0.025,0.48,0.495], [0.02,0.478,0.502],
])
data("ArcticLake", package="DirichletReg")
x <- ArcticLake[, c("depth"), drop=F]
y <- ArcticLake[, c("sand", "silt", "clay")] |> as.matrix()
拟合 XGBoost 模型并进行预测
from typing import Dict, List
dtrain = xgb.DMatrix(X, label=Y)
results: Dict[str, Dict[str, List[float]]] = {}
booster = xgb.train(
params={
"tree_method": "hist",
"num_target": Y.shape[1],
"base_score": 0,
"disable_default_eval_metric": True,
"max_depth": 3,
"seed": 123,
},
dtrain=dtrain,
num_boost_round=10,
obj=dirichlet_xgb_objective,
evals=[(dtrain, "Train")],
evals_result=results,
custom_metric=dirichlet_eval_metric,
)
yhat = softmax(booster.inplace_predict(X), axis=1)
dtrain <- xgb.DMatrix(x, y)
booster <- xgb.train(
params = list(
tree_method="hist",
num_target=ncol(y),
base_score=0,
disable_default_eval_metric=TRUE,
max_depth=3,
seed=123
),
data = dtrain,
nrounds = 10,
obj = dirichlet.xgb.objective,
evals = list(Train=dtrain),
eval_metric = dirichlet.eval.metric
)
raw.pred <- predict(booster, x, reshape=TRUE)
yhat <- apply(raw.pred, 1, softmax) |> t()
应产生如下评估日志(注意:函数按预期递减——但与其他目标不同,这里的最小值可以达到零以下)
[0] Train-dirichlet_ll:-40.25009
[1] Train-dirichlet_ll:-47.69122
[2] Train-dirichlet_ll:-52.64620
[3] Train-dirichlet_ll:-56.36977
[4] Train-dirichlet_ll:-59.33048
[5] Train-dirichlet_ll:-61.93359
[6] Train-dirichlet_ll:-64.17280
[7] Train-dirichlet_ll:-66.29709
[8] Train-dirichlet_ll:-68.21001
[9] Train-dirichlet_ll:-70.03442
可以通过简单地查看 yhat
和 Y
来确认所获得的 yhat
在很大程度上与实际浓度相似,超出随机预测的预期。
为了获得更好的结果,可能需要添加截距。XGBoost 只允许使用标量作为截距,但对于向量值模型,最佳截距也应具有向量形式。
这可以通过提供 base_margin
来完成——与截距不同,必须在此处专门为每一行提供值,并且在进行预测时必须再次提供该 base_margin
(即不会像 base_score
那样自动添加)。
对于狄利克雷模型,可以使用通用求解器(例如 SciPy 的牛顿求解器)结合专门用于截距部分的似然、梯度和 Hessian 函数来高效地获得最佳向量值截距。此外,请注意,如果将其框架为有界优化而不对浓度应用“exp”变换,它将变为凸问题,其中真实 Hessian 可以毫无问题地用于其他类的求解器中。
为简单起见,本例仍将重用前面定义的相同似然和梯度函数,并结合 SciPy 的/ R 的 L-BFGS 求解器来获得最佳向量值截距
from scipy.optimize import minimize
def get_optimal_intercepts(Y: np.ndarray) -> np.ndarray:
k = Y.shape[1]
res = minimize(
fun=lambda pred: dirichlet_fun(
np.broadcast_to(pred, (Y.shape[0], k)),
Y
),
x0=np.zeros(k),
jac=lambda pred: dirichlet_grad(
np.broadcast_to(pred, (Y.shape[0], k)),
Y
).sum(axis=0)
)
return res["x"]
intercepts = get_optimal_intercepts(Y)
get.optimal.intercepts <- function(y) {
k <- ncol(y)
broadcast.vec <- function(x) rep(x, nrow(y)) |> matrix(ncol=k, byrow=T)
res <- optim(
par = numeric(k),
fn = function(x) dirichlet.fun(broadcast.vec(x), y),
gr = function(x) dirichlet.grad(broadcast.vec(x), y) |> colSums(),
method = "L-BFGS-B"
)
return(res$par)
}
intercepts <- get.optimal.intercepts(y)
现在再次拟合模型,这次带有截距
base_margin = np.broadcast_to(intercepts, Y.shape)
dtrain_w_intercept = xgb.DMatrix(X, label=Y, base_margin=base_margin)
results: Dict[str, Dict[str, List[float]]] = {}
booster = xgb.train(
params={
"tree_method": "hist",
"num_target": Y.shape[1],
"base_score": 0,
"disable_default_eval_metric": True,
"max_depth": 3,
"seed": 123,
},
dtrain=dtrain_w_intercept,
num_boost_round=10,
obj=dirichlet_xgb_objective,
evals=[(dtrain, "Train")],
evals_result=results,
custom_metric=dirichlet_eval_metric,
)
yhat = softmax(
booster.predict(
xgb.DMatrix(X, base_margin=base_margin)
),
axis=1
)
base.margin <- rep(intercepts, nrow(y)) |> matrix(nrow=nrow(y), byrow=T)
dtrain <- xgb.DMatrix(x, y, base_margin=base.margin)
booster <- xgb.train(
params = list(
tree_method="hist",
num_target=ncol(y),
base_score=0,
disable_default_eval_metric=TRUE,
max_depth=3,
seed=123
),
data = dtrain,
nrounds = 10,
obj = dirichlet.xgb.objective,
evals = list(Train=dtrain),
eval_metric = dirichlet.eval.metric
)
raw.pred <- predict(
booster,
x,
base_margin=base.margin,
reshape=TRUE
)
yhat <- apply(raw.pred, 1, softmax) |> t()
[0] Train-dirichlet_ll:-37.01861
[1] Train-dirichlet_ll:-42.86120
[2] Train-dirichlet_ll:-46.55133
[3] Train-dirichlet_ll:-49.15111
[4] Train-dirichlet_ll:-51.02638
[5] Train-dirichlet_ll:-52.53880
[6] Train-dirichlet_ll:-53.77409
[7] Train-dirichlet_ll:-54.88851
[8] Train-dirichlet_ll:-55.95961
[9] Train-dirichlet_ll:-56.95497
对于这个小的示例问题,两种模型之间的预测应该非常相似,并且不带截距的版本在训练数据中实现了较低的目标函数(至少对于 Python 版本),但对于真实数据的更严肃用法,添加截距时很可能会观察到更好的结果。