回调函数

本文档基本介绍了 XGBoost Python 包中使用的回调 API。在 XGBoost 1.3 中,为 Python 包设计了一个新的回调接口,它提供了为训练设计各种扩展的灵活性。此外,XGBoost 还内置了许多预定义的回调函数,用于支持提前停止、检查点等。

使用内置回调函数

默认情况下,XGBoost 中的训练方法具有以下参数,例如 early_stopping_roundsverbose/verbose_eval,指定这些参数后,训练过程将在内部定义相应回调。例如,当指定 early_stopping_rounds 时,EarlyStopping 回调将在迭代循环内部调用。您也可以将此回调函数直接传递给 XGBoost

D_train = xgb.DMatrix(X_train, y_train)
D_valid = xgb.DMatrix(X_valid, y_valid)

# Define a custom evaluation metric used for early stopping.
def eval_error_metric(predt, dtrain: xgb.DMatrix):
    label = dtrain.get_label()
    r = np.zeros(predt.shape)
    gt = predt > 0.5
    r[gt] = 1 - label[gt]
    le = predt <= 0.5
    r[le] = label[le]
    return 'CustomErr', np.sum(r)

# Specify which dataset and which metric should be used for early stopping.
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                        metric_name='CustomErr',
                                        data_name='Valid')

booster = xgb.train(
    {'objective': 'binary:logistic',
     'eval_metric': ['error', 'rmse'],
     'tree_method': 'hist'}, D_train,
    evals=[(D_train, 'Train'), (D_valid, 'Valid')],
    feval=eval_error_metric,
    num_boost_round=1000,
    callbacks=[early_stop],
    verbose_eval=False)

dump = booster.get_dump(dump_format='json')
assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)

定义自己的回调函数

XGBoost 提供了一个回调接口类:TrainingCallback,用户定义的回调应该继承此类并覆盖相应的方法。在使用和定义回调函数的示例中有一个可用的例子。