自定义目标函数的高级用法

目录

概述

XGBoost 允许根据用户为所需目标函数提供的梯度和 Hessian 矩阵来优化自定义用户定义函数。

为了使自定义目标函数按预期工作

  • 要优化的函数必须是光滑且二阶可导的。

  • 函数必须对行/观测值具有可加性,例如具有独立同分布 (i.i.d.) 假设的似然函数。

  • 函数的得分范围必须是无界的(例如,它不应该只适用于正数)。

  • 函数必须是凸的。请注意,如果 Hessian 矩阵包含负值,它们将被截断,这可能会导致模型无法很好地拟合函数。

  • 对于多输出目标函数,不同目标之间不应存在依赖关系(即 Hessian 矩阵对于每一行都应是对角矩阵)。

然而,可以通过放弃使用函数的真实 Hessian 矩阵,转而使用具有更好性质的近似 Hessian 矩阵来解决其中一些限制——在使用非真实 Hessian 矩阵时收敛可能会变慢,但许多理论保证仍然有效,并能产生可用的模型。例如,XGBoost 对多项逻辑回归的内部实现使用了具有对角结构的 Hessian 矩阵上界,而不是对于数据中每一行都是一个完整的方阵的真实 Hessian 矩阵。

本教程提供了一些针对不完全符合上述标准的用例的建议,通过展示如何解决由浓度参数化的狄利克雷回归问题。

狄利克雷回归模型对 XGBoost 提出了一些挑战

  • 浓度参数必须是正数。一种简单的方法是对原始无界值应用 ‘exp’ 变换来实现,但在这种情况下目标函数会变成非凸的。此外,请注意,与 GLM 模型中使用的典型分布不同,此函数不属于指数族。

  • Hessian 矩阵在目标之间存在依赖关系——也就是说,对于一个具有 ‘k’ 个参数的狄利克雷分布,每一行数据都会有一个维度为 [k, k] 的完整 Hessian 矩阵。

  • 此类模型的最优截距应是一个向量,而不是对每个目标都使用相同的值。

为了将此类模型用作自定义目标函数

  • 可以使用期望 Hessian 矩阵(也称为 Fisher 信息矩阵或期望信息)代替真实 Hessian 矩阵。对于可加似然函数,即使真实 Hessian 矩阵不是正定或半正定的,期望 Hessian 矩阵始终是半正定的。

  • 可以使用具有对角结构的期望 Hessian 矩阵的上界,这样在此对角界下的二阶近似总是产生大于或等于非对角期望 Hessian 矩阵下的函数值。

  • 由于 XGBoost 用于截距的 base_score 参数仅限于标量,因此可以使用 base_margin 功能代替,但请注意使用它需要付出更多努力。

狄利克雷回归公式

狄利克雷分布是 Beta 分布在多维上的推广。它对值总和为 1 的比例数据进行建模,通常用作复合模型(例如狄利克雷-多项式)的一部分或贝叶斯模型中的先验,但它也可以单独用于比例数据。

对于给定观测值 y 和给定预测值 x,其似然函数如下所示

\[L(\mathbf{y} | \mathbf{x}) = \frac{1}{\beta(\mathbf{x})} \prod_{i=1}^k y_i^{x_i - 1}\]

其中

\[\beta(\mathbf{x}) = \frac{ \prod_{i=1}^k \Gamma(x_i) }{\Gamma( \sum_{i=1}^k x_i )}\]

在这种情况下,我们希望优化按行求和的负对数似然函数。所得的函数、梯度和 Hessian 矩阵可以按如下方式实现

Python
import numpy as np
from scipy.special import loggamma, psi as digamma, polygamma
trigamma = lambda x: polygamma(1, x)

def dirichlet_fun(pred: np.ndarray, Y: np.ndarray) -> float:
    epred = np.exp(pred)
    sum_epred = np.sum(epred, axis=1, keepdims=True)
    return (
        loggamma(epred).sum()
        - loggamma(sum_epred).sum()
        - np.sum(np.log(Y) * (epred - 1))
    )
def dirichlet_grad(pred: np.ndarray, Y: np.ndarray) -> np.ndarray:
    epred = np.exp(pred)
    return epred * (
        digamma(epred)
        - digamma(np.sum(epred, axis=1, keepdims=True))
        - np.log(Y)
    )
def dirichlet_hess(pred: np.ndarray, Y: np.ndarray) -> np.ndarray:
    epred = np.exp(pred)
    grad = dirichlet_grad(pred, Y)
    k = Y.shape[1]
    H = np.empty((pred.shape[0], k, k))
    for row in range(pred.shape[0]):
        H[row, :, :] = (
            - trigamma(epred[row].sum()) * np.outer(epred[row], epred[row])
            + np.diag(grad[row] + trigamma(epred[row]) * epred[row] ** 2)
        )
    return H
softmax <- function(x) {
    max.x <- max(x)
    e <- exp(x - max.x)
    return(e / sum(e))
}

dirichlet.fun <- function(pred, y) {
    epred <- exp(pred)
    sum_epred <- rowSums(epred)
    return(
        sum(lgamma(epred))
        - sum(lgamma(sum_epred))
        - sum(log(y) * (epred - 1))
    )
}

dirichlet.grad <- function(pred, y) {
    epred <- exp(pred)
    return(
        epred * (
            digamma(epred)
            - digamma(rowSums(epred))
            - log(y)
        )
    )
}

dirichlet.hess <- function(pred, y) {
    epred <- exp(pred)
    grad <- dirichlet.grad(pred, y)
    k <- ncol(y)
    H <- array(dim = c(nrow(y), k, k))
    for (row in seq_len(nrow(y))) {
        H[row, , ] <- (
            - trigamma(sum(epred[row,])) * tcrossprod(epred[row,])
            + diag(grad[row,] + trigamma(epred[row,]) * epred[row,]^2)
        )
    }
    return(H)
}

验证实现的正确性

Python
from math import isclose
from scipy import stats
from scipy.optimize import check_grad
from scipy.special import softmax

def gen_random_dirichlet(rng: np.random.Generator, m: int, k: int):
    alpha = np.exp(rng.standard_normal(size=k))
    return rng.dirichlet(alpha, size=m)

def test_dirichlet_fun_grad_hess():
    k = 3
    m = 10
    rng = np.random.default_rng(seed=123)
    Y = gen_random_dirichlet(rng, m, k)
    x0 = rng.standard_normal(size=k)
    for row in range(Y.shape[0]):
        fun_row = dirichlet_fun(x0.reshape((1,-1)), Y[[row]])
        ref_logpdf = stats.dirichlet.logpdf(
            Y[row] / Y[row].sum(), # <- avoid roundoff error
            np.exp(x0),
        )
        assert isclose(fun_row, -ref_logpdf)

        gdiff = check_grad(
            lambda pred: dirichlet_fun(pred.reshape((1,-1)), Y[[row]]),
            lambda pred: dirichlet_grad(pred.reshape((1,-1)), Y[[row]]),
            x0
        )
        assert gdiff <= 1e-6

        H_numeric = np.empty((k,k))
        eps = 1e-7
        for ii in range(k):
            x0_plus_eps = x0.reshape((1,-1)).copy()
            x0_plus_eps[0,ii] += eps
            for jj in range(k):
                H_numeric[ii, jj] = (
                    dirichlet_grad(x0_plus_eps, Y[[row]])[0][jj]
                    - dirichlet_grad(x0.reshape((1,-1)), Y[[row]])[0][jj]
                ) / eps
        H = dirichlet_hess(x0.reshape((1,-1)), Y[[row]])[0]
        np.testing.assert_almost_equal(H, H_numeric, decimal=6)
test_dirichlet_fun_grad_hess()
library(DirichletReg)
library(testthat)

test_that("dirichlet formulae", {
    k <- 3L
    m <- 10L
    set.seed(123)
    alpha <- exp(rnorm(k))
    y <- rdirichlet(m, alpha)
    x0 <- rnorm(k)

    for (row in seq_len(m)) {
        logpdf <- dirichlet.fun(matrix(x0, nrow=1), y[row,,drop=F])
        ref_logpdf <- ddirichlet(y[row,,drop=F], exp(x0), log = T)
        expect_equal(logpdf, -ref_logpdf)

        eps <- 1e-7
        grad_num <- numeric(k)
        for (col in seq_len(k)) {
            xplus <- x0
            xplus[col] <- x0[col] + eps
            grad_num[col] <- (
                dirichlet.fun(matrix(xplus, nrow=1), y[row,,drop=F])
                - dirichlet.fun(matrix(x0, nrow=1), y[row,,drop=F])
            ) / eps
        }

        grad <- dirichlet.grad(matrix(x0, nrow=1), y[row,,drop=F])
        expect_equal(grad |> as.vector(), grad_num, tolerance=1e-6)

        H_numeric <- array(dim=c(k, k))
        for (ii in seq_len(k)) {
            xplus <- x0
            xplus[ii] <- x0[ii] + eps
            for (jj in seq_len(k)) {
                H_numeric[ii, jj] <- (
                    dirichlet.grad(matrix(xplus, nrow=1), y[row,,drop=F])[1, jj]
                    - grad[1L, jj]
                ) / eps
            }
        }

        H <- dirichlet.hess(matrix(xplus, nrow=1), y[row,,drop=F])
        expect_equal(H[1,,], H_numeric, tolerance=1e-6)
    }
})

作为目标函数的狄利克雷回归

如前所述,此函数的 Hessian 矩阵对 XGBoost 来说存在问题:它可能具有负行列式,甚至对角线上可能存在负值,这对于优化方法来说是个问题——在 XGBoost 中,这些值将被截断,并且由此产生的模型可能无法生成合理的预测结果。

一个潜在的解决方法是使用期望 Hessian 矩阵代替——即,如果响应变量按照预测分布,梯度的期望外积。有关更多信息,请参见维基百科文章

https://en.wikipedia.org/wiki/Fisher_information

通常,对于指数族中的目标函数,可以很容易地从链接函数的梯度和概率分布的方差中获得期望 Hessian 矩阵,但对于其他函数,通常可能涉及其他类型的计算(例如,狄利克雷分布的协方差和对数协方差)。

然而,它仍然产生了一种与 Hessian 矩阵非常相似的形式。从这里的差异也可以看出,在最优解处(梯度为零),狄利克雷分布的期望 Hessian 矩阵和真实 Hessian 矩阵将匹配,这对于优化来说是一个很好的特性(即 Hessian 矩阵在驻点处将是正定的,这意味着它是一个最小值而不是最大值或鞍点)。

Python
def dirichlet_expected_hess(pred: np.ndarray) -> np.ndarray:
    epred = np.exp(pred)
    k = pred.shape[1]
    Ehess = np.empty((pred.shape[0], k, k))
    for row in range(pred.shape[0]):
        Ehess[row, :, :] = (
            - trigamma(epred[row].sum()) * np.outer(epred[row], epred[row])
            + np.diag(trigamma(epred[row]) * epred[row] ** 2)
        )
    return Ehess
def test_dirichlet_expected_hess():
    k = 3
    rng = np.random.default_rng(seed=123)
    x0 = rng.standard_normal(size=k)
    y_sample = rng.dirichlet(np.exp(x0), size=5_000_000)
    x_broadcast = np.broadcast_to(x0, (y_sample.shape[0], k))
    g_sample = dirichlet_grad(x_broadcast, y_sample)
    ref = (g_sample.T @ g_sample) / y_sample.shape[0]
    Ehess = dirichlet_expected_hess(x0.reshape((1,-1)))[0]
    np.testing.assert_almost_equal(Ehess, ref, decimal=2)
test_dirichlet_expected_hess()
dirichlet.expected.hess <- function(pred) {
    epred <- exp(pred)
    k <- ncol(pred)
    H <- array(dim = c(nrow(pred), k, k))
    for (row in seq_len(nrow(pred))) {
        H[row, , ] <- (
            - trigamma(sum(epred[row,])) * tcrossprod(epred[row,])
            + diag(trigamma(epred[row,]) * epred[row,]^2)
        )
    }
    return(H)
}

test_that("expected hess", {
    k <- 3L
    set.seed(123)
    x0 <- rnorm(k)
    alpha <- exp(x0)
    n.samples <- 5e6
    y.samples <- rdirichlet(n.samples, alpha)

    x.broadcast <- rep(x0, n.samples) |> matrix(ncol=k, byrow=T)
    grad.samples <- dirichlet.grad(x.broadcast, y.samples)
    ref <- crossprod(grad.samples) / n.samples
    Ehess <- dirichlet.expected.hess(matrix(x0, nrow=1))
    expect_equal(Ehess[1,,], ref, tolerance=1e-2)
})

但请注意,这对于 XGBoost 仍然不可用,因为期望 Hessian 矩阵与真实 Hessian 矩阵一样,形状为 [nrows, k, k],而 XGBoost 需要形状为 [nrows, k] 的输入。

可以使用期望 Hessian 矩阵的对角线,但可以做得更好:可以改用具有对角结构的期望 Hessian 矩阵的上界,因为它应该会带来更好的收敛特性,就像其他基于 Hessian 的优化方法一样。

在没有明显方法获得上界的情况下,这里的一种可能性是直接基于对角占优矩阵的定义数值构建这样的上界

https://en.wikipedia.org/wiki/Diagonally_dominant_matrix

也就是说:取数据中每一行的期望 Hessian 矩阵的绝对值,并按该行数据的 [k, k] 形状 Hessian 矩阵的行求和

Python
def dirichlet_diag_upper_bound_expected_hess(
    pred: np.ndarray, Y: np.ndarray
) -> np.ndarray:
    Ehess = dirichlet_expected_hess(pred)
    diag_bound_Ehess = np.empty((pred.shape[0], Y.shape[1]))
    for row in range(pred.shape[0]):
        diag_bound_Ehess[row, :] = np.abs(Ehess[row, :, :]).sum(axis=1)
    return diag_bound_Ehess
dirichlet.diag.upper.bound.expected.hess <- function(pred, y) {
    Ehess <- dirichlet.expected.hess(pred)
    diag.bound.Ehess <- array(dim=dim(pred))
    for (row in seq_len(nrow(pred))) {
        diag.bound.Ehess[row,] <- abs(Ehess[row,,]) |> rowSums()
    }
    return(diag.bound.Ehess)
}

(注意:可以通过不计算完整的矩阵来比此处所示更高效地进行计算,在 R 中,可以通过将行设为最后一个维度并在之后进行转置来实现)

具备了所有这些部分后,现在可以将此模型转换为 XGBoost 自定义目标函数所需的格式

Python
import xgboost as xgb
from typing import Tuple

def dirichlet_xgb_objective(
    pred: np.ndarray, dtrain: xgb.DMatrix
) -> Tuple[np.ndarray, np.ndarray]:
    Y = dtrain.get_label().reshape(pred.shape)
    return (
        dirichlet_grad(pred, Y),
        dirichlet_diag_upper_bound_expected_hess(pred, Y),
    )
library(xgboost)

dirichlet.xgb.objective <- function(pred, dtrain) {
    y <- getinfo(dtrain, "label")
    return(
        list(
            grad = dirichlet.grad(pred, y),
            hess = dirichlet.diag.upper.bound.expected.hess(pred, y)
        )
    )
}

以及基于狄利克雷对数似然函数的评估指标监测

Python
def dirichlet_eval_metric(
    pred: np.ndarray, dtrain: xgb.DMatrix
) -> Tuple[str, float]:
    Y = dtrain.get_label().reshape(pred.shape)
    return "dirichlet_ll", dirichlet_fun(pred, Y)
dirichlet.eval.metric <- function(pred, dtrain) {
    y <- getinfo(dtrain, "label")
    ll <- dirichlet.fun(pred, y)
    return(
        list(
            metric = "dirichlet_ll",
            value = ll
        )
    )
}

实际示例

比例数据的良好测试数据集来源是 R 包 DirichletReg

https://cran.r-project.org.cn/package=DirichletReg

在本示例中,我们将使用 Arctic Lake 数据集(Aitchison, J. (2003). The Statistical Analysis of Compositional Data. The Blackburn Press, Caldwell, NJ.),该数据集取自 DirichletReg R 包,包含 39 行数据,一个预测变量‘depth’以及一个表示该北极湖泊测量沉积物组成(沙、淤泥、粘土)的三值响应变量。

数据

Python
# depth
X = np.array([
    10.4,11.7,12.8,13,15.7,16.3,18,18.7,20.7,22.1,
    22.4,24.4,25.8,32.5,33.6,36.8,37.8,36.9,42.2,47,
    47.1,48.4,49.4,49.5,59.2,60.1,61.7,62.4,69.3,73.6,
    74.4,78.5,82.9,87.7,88.1,90.4,90.6,97.7,103.7,
]).reshape((-1,1))
# sand, silt, clay
Y = np.array([
    [0.775,0.195,0.03], [0.719,0.249,0.032], [0.507,0.361,0.132],
    [0.522,0.409,0.066], [0.7,0.265,0.035], [0.665,0.322,0.013],
    [0.431,0.553,0.016], [0.534,0.368,0.098], [0.155,0.544,0.301],
    [0.317,0.415,0.268], [0.657,0.278,0.065], [0.704,0.29,0.006],
    [0.174,0.536,0.29], [0.106,0.698,0.196], [0.382,0.431,0.187],
    [0.108,0.527,0.365], [0.184,0.507,0.309], [0.046,0.474,0.48],
    [0.156,0.504,0.34], [0.319,0.451,0.23], [0.095,0.535,0.37],
    [0.171,0.48,0.349], [0.105,0.554,0.341], [0.048,0.547,0.41],
    [0.026,0.452,0.522], [0.114,0.527,0.359], [0.067,0.469,0.464],
    [0.069,0.497,0.434], [0.04,0.449,0.511], [0.074,0.516,0.409],
    [0.048,0.495,0.457], [0.045,0.485,0.47], [0.066,0.521,0.413],
    [0.067,0.473,0.459], [0.074,0.456,0.469], [0.06,0.489,0.451],
    [0.063,0.538,0.399], [0.025,0.48,0.495], [0.02,0.478,0.502],
])
data("ArcticLake", package="DirichletReg")
x <- ArcticLake[, c("depth"), drop=F]
y <- ArcticLake[, c("sand", "silt", "clay")] |> as.matrix()

拟合 XGBoost 模型并进行预测

Python
from typing import Dict, List

dtrain = xgb.DMatrix(X, label=Y)
results: Dict[str, Dict[str, List[float]]] = {}
booster = xgb.train(
    params={
        "tree_method": "hist",
        "num_target": Y.shape[1],
        "base_score": 0,
        "disable_default_eval_metric": True,
        "max_depth": 3,
        "seed": 123,
    },
    dtrain=dtrain,
    num_boost_round=10,
    obj=dirichlet_xgb_objective,
    evals=[(dtrain, "Train")],
    evals_result=results,
    custom_metric=dirichlet_eval_metric,
)
yhat = softmax(booster.inplace_predict(X), axis=1)
dtrain <- xgb.DMatrix(x, y)
booster <- xgb.train(
    params = list(
        tree_method="hist",
        num_target=ncol(y),
        base_score=0,
        disable_default_eval_metric=TRUE,
        max_depth=3,
        seed=123
    ),
    data = dtrain,
    nrounds = 10,
    obj = dirichlet.xgb.objective,
    evals = list(Train=dtrain),
    eval_metric = dirichlet.eval.metric
)
raw.pred <- predict(booster, x, reshape=TRUE)
yhat <- apply(raw.pred, 1, softmax) |> t()

应该会生成如下评估日志(注意:函数如预期般递减——但与其他目标函数不同,这里的最小值可以达到零以下)

[0] Train-dirichlet_ll:-40.25009
[1] Train-dirichlet_ll:-47.69122
[2] Train-dirichlet_ll:-52.64620
[3] Train-dirichlet_ll:-56.36977
[4] Train-dirichlet_ll:-59.33048
[5] Train-dirichlet_ll:-61.93359
[6] Train-dirichlet_ll:-64.17280
[7] Train-dirichlet_ll:-66.29709
[8] Train-dirichlet_ll:-68.21001
[9] Train-dirichlet_ll:-70.03442

通过简单地查看 yhatY,可以确认获得的 yhat 在很大程度上与实际浓度相似,超出了随机预测的预期。

为了获得更好的结果,可能需要添加一个截距。XGBoost 只允许使用标量作为截距,但对于向量值模型,最优截距也应采用向量形式。

这可以通过提供 base_margin 代替来实现——与截距不同,这里必须为每一行专门提供值,并且在进行预测时必须再次提供所述 base_margin(即,它不会像 base_score 那样自动添加)。

对于狄利克雷模型,可以使用通用求解器(例如 SciPy 的 Newton 求解器)高效地获得最优截距,该求解器只需针对截距部分使用专门的似然函数、梯度和 Hessian 函数。此外,请注意,如果将其框定为有界优化而不对浓度应用 'exp' 变换,它将变为一个凸问题,对于此类问题,真实 Hessian 矩阵可以在其他类别的求解器中毫无问题地使用。

为简单起见,本示例仍将重用先前定义的似然函数和梯度函数,并结合 SciPy / R 的 L-BFGS 求解器来获得最优向量值截距

Python
from scipy.optimize import minimize

def get_optimal_intercepts(Y: np.ndarray) -> np.ndarray:
    k = Y.shape[1]
    res = minimize(
        fun=lambda pred: dirichlet_fun(
            np.broadcast_to(pred, (Y.shape[0], k)),
            Y
        ),
        x0=np.zeros(k),
        jac=lambda pred: dirichlet_grad(
            np.broadcast_to(pred, (Y.shape[0], k)),
            Y
        ).sum(axis=0)
    )
    return res["x"]
intercepts = get_optimal_intercepts(Y)
get.optimal.intercepts <- function(y) {
    k <- ncol(y)
    broadcast.vec <- function(x) rep(x, nrow(y)) |> matrix(ncol=k, byrow=T)
    res <- optim(
        par = numeric(k),
        fn = function(x) dirichlet.fun(broadcast.vec(x), y),
        gr = function(x) dirichlet.grad(broadcast.vec(x), y) |> colSums(),
        method = "L-BFGS-B"
    )
    return(res$par)
}
intercepts <- get.optimal.intercepts(y)

现在再次拟合模型,这次包含截距

Python
base_margin = np.broadcast_to(intercepts, Y.shape)
dtrain_w_intercept = xgb.DMatrix(X, label=Y, base_margin=base_margin)
results: Dict[str, Dict[str, List[float]]] = {}
booster = xgb.train(
    params={
        "tree_method": "hist",
        "num_target": Y.shape[1],
        "base_score": 0,
        "disable_default_eval_metric": True,
        "max_depth": 3,
        "seed": 123,
    },
    dtrain=dtrain_w_intercept,
    num_boost_round=10,
    obj=dirichlet_xgb_objective,
    evals=[(dtrain, "Train")],
    evals_result=results,
    custom_metric=dirichlet_eval_metric,
)
yhat = softmax(
    booster.predict(
        xgb.DMatrix(X, base_margin=base_margin)
    ),
    axis=1
)
base.margin <- rep(intercepts, nrow(y)) |> matrix(nrow=nrow(y), byrow=T)
dtrain <- xgb.DMatrix(x, y, base_margin=base.margin)
booster <- xgb.train(
    params = list(
        tree_method="hist",
        num_target=ncol(y),
        base_score=0,
        disable_default_eval_metric=TRUE,
        max_depth=3,
        seed=123
    ),
    data = dtrain,
    nrounds = 10,
    obj = dirichlet.xgb.objective,
    evals = list(Train=dtrain),
    eval_metric = dirichlet.eval.metric
)
raw.pred <- predict(
    booster,
    x,
    base_margin=base.margin,
    reshape=TRUE
)
yhat <- apply(raw.pred, 1, softmax) |> t()
[0] Train-dirichlet_ll:-37.01861
[1] Train-dirichlet_ll:-42.86120
[2] Train-dirichlet_ll:-46.55133
[3] Train-dirichlet_ll:-49.15111
[4] Train-dirichlet_ll:-51.02638
[5] Train-dirichlet_ll:-52.53880
[6] Train-dirichlet_ll:-53.77409
[7] Train-dirichlet_ll:-54.88851
[8] Train-dirichlet_ll:-55.95961
[9] Train-dirichlet_ll:-56.95497

对于这个小型示例问题,两者的预测结果应该非常相似,并且没有截距的版本在训练数据中实现了较低的目标函数(至少在 Python 版本中),但对于涉及真实世界数据的更严肃用法,添加截距可能会获得更好的结果。