自定义目标和评估指标
目录
概述
XGBoost 被设计成一个可扩展的库。扩展它的一种方式是为训练提供我们自己的目标函数,并为性能监控提供相应的指标。本文档介绍了如何为 XGBoost 实现自定义的逐元素评估指标和目标。尽管本文使用 Python 进行演示,但这些概念应该很容易适用于其他语言绑定。
注意
排名任务不支持自定义函数。
XGBoost 1.6 中做了重大更改。
有关更复杂目标函数的限制和变通方法的更多信息,请参阅高级用法示例:自定义目标函数的高级用法
在接下来的两节中,我们将逐步讲解如何实现平方对数误差(SLE)
目标函数
及其默认指标均方根对数误差(RMSLE)
尽管 XGBoost 本身支持这些函数,但使用它进行演示为我们提供了比较我们自己的实现结果与 XGBoost 内部实现结果的机会,以便学习。完成本教程后,我们应该能够提供自己的函数进行快速实验。最后,我们将提供一些关于非恒等链接函数的注释,以及使用自定义指标和目标函数与 scikit-learn 接口的示例。
如果我们计算该目标函数的梯度
以及 Hessian(目标函数的二阶导数)
自定义目标函数
在模型训练期间,目标函数起着重要作用:根据模型预测和观察到的数据标签(或目标)提供梯度信息,包括一阶和二阶梯度。因此,一个有效的目标函数应该接受两个输入,即预测和标签。为了实现SLE
,我们定义
import numpy as np
import xgboost as xgb
from typing import Tuple
def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
'''Compute the gradient squared log error.'''
y = dtrain.get_label()
return (np.log1p(predt) - np.log1p(y)) / (predt + 1)
def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
'''Compute the hessian for squared log error.'''
y = dtrain.get_label()
return ((-np.log1p(predt) + np.log1p(y) + 1) /
np.power(predt + 1, 2))
def squared_log(predt: np.ndarray,
dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
'''Squared Log Error objective. A simplified version for RMSLE used as
objective function.
'''
predt[predt < -1] = -1 + 1e-6
grad = gradient(predt, dtrain)
hess = hessian(predt, dtrain)
return grad, hess
在上述代码片段中,squared_log
是我们想要的目标函数。它接受一个 numpy 数组predt
作为模型预测,以及用于获取所需信息(包括标签和权重(此处未使用))的训练 DMatrix。然后,通过将其作为参数传递给xgb.train
,将此目标函数用作 XGBoost 训练期间的回调函数。
xgb.train({'tree_method': 'hist', 'seed': 1994}, # any other tree method is fine.
dtrain=dtrain,
num_boost_round=10,
obj=squared_log)
请注意,在我们的目标函数定义中,我们是从预测中减去标签还是反过来,这一点很重要。如果你发现训练误差上升而不是下降,这可能就是原因。
自定义指标函数
因此,在有了自定义目标函数之后,我们可能还需要一个相应的指标来监控模型的性能。如上所述,SLE
的默认指标是RMSLE
。类似地,我们定义另一个类似回调的函数作为新指标
def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
''' Root mean squared log error metric.'''
y = dtrain.get_label()
predt[predt < -1] = -1 + 1e-6
elements = np.power(np.log1p(y) - np.log1p(predt), 2)
return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
由于我们正在用 Python 进行演示,因此指标或目标函数不必是函数,任何可调用对象都应该足够。与目标函数类似,我们的指标也接受predt
和dtrain
作为输入,但返回指标本身的名称和一个浮点值作为结果。在将其作为custom_metric
参数传递给 XGBoost 后
xgb.train({'tree_method': 'hist', 'seed': 1994,
'disable_default_eval_metric': 1},
dtrain=dtrain,
num_boost_round=10,
obj=squared_log,
custom_metric=rmsle,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)
我们将能够看到 XGBoost 打印类似如下的内容
[0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487
[1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899
[2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629
[3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871
[4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186
[5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
...
请注意,参数disable_default_eval_metric
用于抑制 XGBoost 中的默认指标。
有关完全可重现的源代码和比较图,请参阅定义自定义回归目标和指标的演示。
反向链接函数
使用内置目标函数时,原始预测会根据目标函数进行转换。当提供自定义目标函数时,XGBoost 不知道其链接函数,因此用户负责对目标函数和自定义评估指标进行转换。对于像平方误差
这样的具有恒等链接的目标函数,这很简单,但对于其他链接函数(如对数链接或逆链接),差异很大。
对于 Python 包,预测的行为可以通过predict
函数中的output_margin
参数控制。当使用custom_metric
参数而不使用自定义目标函数时,指标函数将接收转换后的预测,因为目标函数是由 XGBoost 定义的。然而,当自定义目标函数也与该指标一起提供时,目标函数和自定义指标都将接收原始预测。以下示例通过多类分类模型比较了两种不同的行为。首先,我们定义了 2 个不同的 Python 指标函数,实现了相同的底层指标进行比较,当也使用自定义目标函数时使用merror_with_transform,否则首选更简单的merror,因为 XGBoost 可以自行执行转换。
import xgboost as xgb
import numpy as np
def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix):
"""Used when custom objective is supplied."""
y = dtrain.get_label()
n_classes = predt.size // y.shape[0]
# Like custom objective, the predt is untransformed leaf weight when custom objective
# is provided.
# With the use of `custom_metric` parameter in train function, custom metric receives
# raw input only when custom objective is also being used. Otherwise custom metric
# will receive transformed prediction.
assert predt.shape == (d_train.num_row(), n_classes)
out = np.zeros(dtrain.num_row())
for r in range(predt.shape[0]):
i = np.argmax(predt[r])
out[r] = i
assert y.shape == out.shape
errors = np.zeros(dtrain.num_row())
errors[y != out] = 1.0
return 'PyMError', np.sum(errors) / dtrain.num_row()
上述函数仅在我们想要使用自定义目标函数且 XGBoost 不知道如何转换预测时才需要。多类误差函数的常规实现是
def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
"""Used when there's no custom objective."""
# No need to do transform, XGBoost handles it internally.
errors = np.zeros(dtrain.num_row())
errors[y != out] = 1.0
return 'PyMError', np.sum(errors) / dtrain.num_row()
接下来我们需要自定义的 softprob 目标
def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
"""Loss function. Computing the gradient and approximated hessian (diagonal).
Reimplements the `multi:softprob` inside XGBoost.
"""
# Full implementation is available in the Python demo script linked below
...
return grad, hess
最后我们可以使用obj
和custom_metric
参数训练模型
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{"num_class": kClasses, "disable_default_eval_metric": True},
m,
num_boost_round=kRounds,
obj=softprob_obj,
custom_metric=merror_with_transform,
evals_result=custom_results,
evals=[(m, "train")],
)
或者,如果您不需要自定义目标函数,而只想提供一个 XGBoost 中不具备的指标
booster = xgb.train(
{
"num_class": kClasses,
"disable_default_eval_metric": True,
"objective": "multi:softmax",
},
m,
num_boost_round=kRounds,
# Use a simpler metric implementation.
custom_metric=merror,
evals_result=custom_results,
evals=[(m, "train")],
)
我们使用multi:softmax
来演示转换预测的差异。对于softprob
,输出预测数组的形状为(n_samples, n_classes)
,而对于softmax
,其形状为(n_samples, )
。多类目标函数的演示也位于创建自定义多类目标函数的演示。此外,请参阅截距以获取更多解释。
Scikit-Learn 接口
XGBoost 的 scikit-learn 接口提供了一些实用程序,以改进与标准 scikit-learn 函数的集成。例如,在 XGBoost 1.6.0 之后,用户可以直接使用 scikit-learn 的成本函数(而非评分函数)
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
tree_method="hist",
eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
此外,对于自定义目标函数,用户可以在不访问DMatrix
的情况下定义目标函数
def softprob_obj(labels: np.ndarray, predt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
rows = labels.shape[0]
classes = predt.shape[1]
grad = np.zeros((rows, classes), dtype=float)
hess = np.zeros((rows, classes), dtype=float)
eps = 1e-6
for r in range(predt.shape[0]):
target = labels[r]
p = softmax(predt[r, :])
for c in range(predt.shape[1]):
g = p[c] - 1.0 if c == target else p[c]
h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)
grad[r, c] = g
hess[r, c] = h
grad = grad.reshape((rows * classes, 1))
hess = hess.reshape((rows * classes, 1))
return grad, hess
clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj)