自定义目标函数和评估指标
目录
概述
XGBoost 设计为可扩展的库。扩展它的一种方式是提供我们自己的训练目标函数和相应的性能监控指标。本文档介绍了如何为 XGBoost 实现自定义的逐元素评估指标和目标函数。尽管本文以 Python 为例进行演示,但这些概念应易于应用于其他语言绑定。
注意
排序任务不支持自定义函数。
XGBoost 1.6 版本中引入了重大变更。
有关更复杂目标的限制和变通方法,请参见高级用法示例:自定义目标函数高级用法
在接下来的两个章节中,我们将逐步介绍如何实现 平方对数误差 (SLE)
目标函数
及其默认指标 均方根对数误差(RMSLE)
尽管 XGBoost 原生支持上述函数,但使用它进行演示可以让我们有机会比较我们自己的实现与 XGBoost 内部实现的结果,以便于学习。完成本教程后,我们应该能够提供自己的函数来进行快速实验。最后,我们将提供关于非恒等链接函数的注意事项,并提供在使用 scikit-learn 接口时使用自定义指标和目标函数的示例。
如果我们计算上述目标函数的梯度
以及海塞矩阵(目标函数的二阶导数)
自定义目标函数
在模型训练过程中,目标函数起着重要作用:根据模型预测和观测数据标签(或目标)提供梯度信息,包括一阶和二阶梯度。因此,有效的目标函数应该接受两个输入,即预测和标签。为了实现 SLE
,我们定义
import numpy as np
import xgboost as xgb
from typing import Tuple
def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
'''Compute the gradient squared log error.'''
y = dtrain.get_label()
return (np.log1p(predt) - np.log1p(y)) / (predt + 1)
def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
'''Compute the hessian for squared log error.'''
y = dtrain.get_label()
return ((-np.log1p(predt) + np.log1p(y) + 1) /
np.power(predt + 1, 2))
def squared_log(predt: np.ndarray,
dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
'''Squared Log Error objective. A simplified version for RMSLE used as
objective function.
'''
predt[predt < -1] = -1 + 1e-6
grad = gradient(predt, dtrain)
hess = hessian(predt, dtrain)
return grad, hess
在上面的代码片段中,squared_log
是我们想要的目标函数。它接受一个 numpy 数组 predt
作为模型预测,并接受训练 DMatrix 以获取所需信息,包括标签和权重(此处未使用)。然后,通过将其作为参数传递给 xgb.train
,此目标函数在训练期间用作 XGBoost 的回调函数。
xgb.train({'tree_method': 'hist', 'seed': 1994}, # any other tree method is fine.
dtrain=dtrain,
num_boost_round=10,
obj=squared_log)
请注意,在我们定义目标函数时,我们是使用预测减去标签还是标签减去预测非常重要。如果您发现训练误差不降反升,这可能是原因所在。
自定义评估指标
因此,在有了自定义目标函数后,我们可能还需要一个相应的指标来监控模型的性能。如上所述,SLE
的默认指标是 RMSLE
。类似地,我们定义另一个类似回调的函数作为新的指标
def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
''' Root mean squared log error metric.'''
y = dtrain.get_label()
predt[predt < -1] = -1 + 1e-6
elements = np.power(np.log1p(y) - np.log1p(predt), 2)
return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))
由于我们在 Python 中进行演示,指标或目标函数不必是函数,任何可调用对象都应足够。与目标函数类似,我们的指标也接受 predt
和 dtrain
作为输入,但返回指标本身的名称和一个浮点值作为结果。将其作为 custom_metric
参数传递给 XGBoost 后
xgb.train({'tree_method': 'hist', 'seed': 1994,
'disable_default_eval_metric': 1},
dtrain=dtrain,
num_boost_round=10,
obj=squared_log,
custom_metric=rmsle,
evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
evals_result=results)
我们将能够看到 XGBoost 打印如下信息
[0] dtrain-PyRMSLE:1.37153 dtest-PyRMSLE:1.31487
[1] dtrain-PyRMSLE:1.26619 dtest-PyRMSLE:1.20899
[2] dtrain-PyRMSLE:1.17508 dtest-PyRMSLE:1.11629
[3] dtrain-PyRMSLE:1.09836 dtest-PyRMSLE:1.03871
[4] dtrain-PyRMSLE:1.03557 dtest-PyRMSLE:0.977186
[5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
...
请注意,参数 disable_default_eval_metric
用于禁止 XGBoost 中的默认指标。
有关完全可重现的源代码和比较图,请参见 定义自定义回归目标函数和指标的演示。
反向链接函数
使用内置目标函数时,原始预测会根据目标函数进行转换。提供自定义目标函数时,XGBoost 不知道其链接函数,因此用户负责为目标函数和自定义评估指标进行转换。对于像 squared error
这样的恒等链接函数,这很简单,但对于对数链接或逆链接等其他链接函数,差异是显著的。
对于 Python 包,预测的行为可以通过 predict
函数中的 output_margin
参数控制。在使用 custom_metric
参数但不使用自定义目标函数时,指标函数将接收转换后的预测,因为目标函数由 XGBoost 定义。然而,当同时提供自定义目标函数和该指标时,目标函数和自定义指标都将接收原始预测。以下示例通过一个多分类模型提供了两种不同行为的比较。首先我们定义 2 个不同的 Python 指标函数,它们实现相同的底层指标以进行比较,merror_with_transform 在同时使用自定义目标函数时使用,否则首选更简单的 merror,因为 XGBoost 可以自行执行转换。
import xgboost as xgb
import numpy as np
def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix):
"""Used when custom objective is supplied."""
y = dtrain.get_label()
n_classes = predt.size // y.shape[0]
# Like custom objective, the predt is untransformed leaf weight when custom objective
# is provided.
# With the use of `custom_metric` parameter in train function, custom metric receives
# raw input only when custom objective is also being used. Otherwise custom metric
# will receive transformed prediction.
assert predt.shape == (d_train.num_row(), n_classes)
out = np.zeros(dtrain.num_row())
for r in range(predt.shape[0]):
i = np.argmax(predt[r])
out[r] = i
assert y.shape == out.shape
errors = np.zeros(dtrain.num_row())
errors[y != out] = 1.0
return 'PyMError', np.sum(errors) / dtrain.num_row()
上述函数仅在我们需要使用自定义目标函数且 XGBoost 不知道如何转换预测时才需要。多分类误差函数的正常实现是
def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
"""Used when there's no custom objective."""
# No need to do transform, XGBoost handles it internally.
errors = np.zeros(dtrain.num_row())
errors[y != out] = 1.0
return 'PyMError', np.sum(errors) / dtrain.num_row()
接下来我们需要自定义 softprob 目标函数
def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
"""Loss function. Computing the gradient and approximated hessian (diagonal).
Reimplements the `multi:softprob` inside XGBoost.
"""
# Full implementation is available in the Python demo script linked below
...
return grad, hess
最后,我们可以使用 obj
和 custom_metric
参数训练模型
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{"num_class": kClasses, "disable_default_eval_metric": True},
m,
num_boost_round=kRounds,
obj=softprob_obj,
custom_metric=merror_with_transform,
evals_result=custom_results,
evals=[(m, "train")],
)
或者如果您不需要自定义目标函数,只希望提供一个 XGBoost 中没有的指标
booster = xgb.train(
{
"num_class": kClasses,
"disable_default_eval_metric": True,
"objective": "multi:softmax",
},
m,
num_boost_round=kRounds,
# Use a simpler metric implementation.
custom_metric=merror,
evals_result=custom_results,
evals=[(m, "train")],
)
我们使用 multi:softmax
来说明转换后预测的差异。对于 softprob
,输出预测数组的形状为 (n_samples, n_classes)
,而对于 softmax
,其形状为 (n_samples, )
。多分类目标函数的演示也可在 创建自定义多分类目标函数的演示 中找到。此外,请参见 截距 以获得更多解释。
Scikit-Learn 接口
XGBoost 的 scikit-learn 接口提供了一些实用工具来改善与标准 scikit-learn 函数的集成。例如,在 XGBoost 1.6.0 之后,用户可以直接使用 scikit-learn 中的成本函数(非评分函数)
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
tree_method="hist",
eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
此外,对于自定义目标函数,用户可以定义目标函数而无需访问 DMatrix
def softprob_obj(labels: np.ndarray, predt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
rows = labels.shape[0]
classes = predt.shape[1]
grad = np.zeros((rows, classes), dtype=float)
hess = np.zeros((rows, classes), dtype=float)
eps = 1e-6
for r in range(predt.shape[0]):
target = labels[r]
p = softmax(predt[r, :])
for c in range(predt.shape[1]):
g = p[c] - 1.0 if c == target else p[c]
h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)
grad[r, c] = g
hess[r, c] = h
grad = grad.reshape((rows * classes, 1))
hess = hess.reshape((rows * classes, 1))
return grad, hess
clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj)