xgboost
objective.h
前往此文件的文档。
1 
7 #ifndef XGBOOST_OBJECTIVE_H_
8 #define XGBOOST_OBJECTIVE_H_
9 
10 #include <dmlc/registry.h>
11 #include <xgboost/base.h>
12 #include <xgboost/data.h>
14 #include <xgboost/model.h>
15 #include <xgboost/task.h>
16 
17 #include <cstdint> // std::int32_t
18 #include <functional>
19 #include <string>
20 
21 namespace xgboost {
22 
23 class RegTree;
24 struct Context;
25 
27 class ObjFunction : public Configurable {
28  protected
29  Context const* ctx_;
30 
31  public
32  static constexpr float DefaultBaseScore() { return 0.5f; }
33 
34  public
36  ~ObjFunction() override = default;
41  virtual void Configure(Args const& args) = 0;
50  virtual void GetGradient(HostDeviceVector<float> const& preds, MetaInfo const& info,
51  std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) = 0;
52 
54  virtual const char* DefaultEvalMetric() const = 0;
58  virtual Json DefaultMetricConfig() const { return Json{Null{}}; }
59 
60  // the following functions are optional, most of time default implementation is good enough
68  virtual void PredTransform(HostDeviceVector<float>*) const {}
76  virtual void EvalTransform(HostDeviceVector<float>* io_preds) { this->PredTransform(io_preds); }
85  [[nodiscard]] virtual float ProbToMargin(float base_score) const { return base_score; }
94  virtual void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const;
98  [[nodiscard]] virtual struct ObjInfo Task() const = 0;
103  [[nodiscard]] virtual bst_target_t Targets(MetaInfo const& info) const {
104  if (info.labels.Shape(1) > 1) {
105  LOG(FATAL) << "multioutput is not supported by the current objective function";
106  }
107  return 1;
108  }
109 
125  virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
126  MetaInfo const& /*info*/, float /*learning_rate*/,
127  HostDeviceVector<float> const& /*prediction*/,
128  std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
129 
135  static ObjFunction* Create(const std::string& name, Context const* ctx);
136 };
137 
142  : public dmlc::FunctionRegEntryBase<ObjFunctionReg,
143  std::function<ObjFunction* ()> > {
144 };
145 
158 #define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \
159  static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg & \
160  __make_ ## ObjFunctionReg ## _ ## UniqueId ## __ = \
161  ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
162 } // namespace xgboost
163 #endif // XGBOOST_OBJECTIVE_H_
定义了 xgboost 的配置宏和基本类型。
定义于: json.h:319
表示 JSON 格式的数据结构。
定义于: json.h:378
数据集的元信息,始终位于内存中。
定义于: data.h:48
linalg::Tensor< float, 2 > labels
每个实例的标签
定义于: data.h:60
目标函数的接口
定义于: objective.h:27
virtual void EvalTransform(HostDeviceVector< float > *io_preds)
将逆链接(激活)函数应用于预测值。
定义于: objective.h:76
static constexpr float DefaultBaseScore()
定义于: objective.h:32
virtual void GetGradient(HostDeviceVector< float > const &preds, MetaInfo const &info, std::int32_t iter, linalg::Matrix< GradientPair > *out_gpair)=0
根据现有信息,计算每个预测的梯度。
virtual void Configure(Args const &args)=0
使用指定参数配置目标函数。
Context const * ctx_
定义于: objective.h:29
virtual void PredTransform(HostDeviceVector< float > *) const
将逆链接(激活)函数应用于预测值。
定义于: objective.h:68
static ObjFunction * Create(const std::string &name, Context const *ctx)
根据名称创建一个目标函数。
virtual void InitEstimation(MetaInfo const &info, linalg::Tensor< float, 1 > *base_score) const
获取预测的初始估计。
virtual Json DefaultMetricConfig() const
返回默认评估指标的配置。
定义于: objective.h:58
virtual struct ObjInfo Task() const =0
返回此目标函数的任务。
~ObjFunction() override=default
虚析构函数
virtual void UpdateTreeLeaf(HostDeviceVector< bst_node_t > const &, MetaInfo const &, float, HostDeviceVector< float > const &, std::int32_t, RegTree *) const
在树构建后更新叶子值。无hessian的目标函数需要此函数。
定义于: objective.h:125
virtual const char * DefaultEvalMetric() const =0
virtual bst_target_t Targets(MetaInfo const &info) const
返回输入矩阵的目标数量。目前 XGBoost 仅支持多目标回归。
定义于: objective.h:103
virtual float ProbToMargin(float base_score) const
将链接函数应用于截距。
定义于: objective.h:85
定义回归树为最常见的树模型。
定义于: tree_model.h:157
一个张量存储。要将其用于切片等其他功能,首先需要获取一个视图...
定义于: linalg.h:762
auto Shape() const
定义于: linalg.h:884
xgboost 的输入数据结构。
设备和主机向量的抽象层。
定义了 XGBoost 中不同组件的抽象接口。
多目标树的核心数据结构。
定义于: base.h:89
std::vector< std::pair< std::string, std::string > > Args
定义于: base.h:316
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义于: base.h:119
定义于: model.h:31
XGBoost 的运行时上下文。包含线程和设备等信息。
定义于: context.h:133
目标函数工厂函数的注册项。
定义于: objective.h:143
由目标函数返回的结构体,用于确定当前的任务。该结构体不被任何算法使用...
定义于: task.h:24