xgboost
objective.h
前往此文件文档。
1 
7 #ifndef XGBOOST_OBJECTIVE_H_
8 #define XGBOOST_OBJECTIVE_H_
9 
10 #include <dmlc/registry.h>
11 #include <xgboost/base.h>
12 #include <xgboost/data.h>
14 #include <xgboost/linalg.h> // for Vector
15 #include <xgboost/model.h>
16 #include <xgboost/task.h>
17 
18 #include <cstdint> // for int32_t
19 #include <functional>
20 #include <string> // for string
21 
22 namespace xgboost {
23 
24 class RegTree;
25 struct Context;
26 
28 class ObjFunction : public Configurable {
29  protected
30  Context const* ctx_{nullptr};
31 
32  public
33  static constexpr float DefaultBaseScore() { return 0.5f; }
34 
35  public
36  ~ObjFunction() override = default;
42  virtual void Configure(Args const& args) = 0;
51  virtual void GetGradient(HostDeviceVector<float> const& preds, MetaInfo const& info,
52  std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) = 0;
53 
55  [[nodiscard]] virtual const char* DefaultEvalMetric() const = 0;
59  [[nodiscard]] virtual Json DefaultMetricConfig() const { return Json{Null{}}; }
67  virtual void PredTransform(HostDeviceVector<float>*) const {}
75  virtual void EvalTransform(HostDeviceVector<float>* io_preds) { this->PredTransform(io_preds); }
86  virtual void ProbToMargin(linalg::Vector<float>* /*base_score*/) const {}
96  virtual void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const;
100  [[nodiscard]] virtual struct ObjInfo Task() const = 0;
105  [[nodiscard]] virtual bst_target_t Targets(MetaInfo const& info) const {
106  if (info.labels.Shape(1) > 1) {
107  LOG(FATAL) << "当前目标函数不支持多输出";
108  }
109  return 1;
110  }
112  [[nodiscard]] Context const* Ctx() const { return this->ctx_; }
113 
129  virtual void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& /*position*/,
130  MetaInfo const& /*info*/, float /*learning_rate*/,
131  HostDeviceVector<float> const& /*prediction*/,
132  std::int32_t /*group_idx*/, RegTree* /*p_tree*/) const {}
139  static ObjFunction* Create(const std::string& name, Context const* ctx);
140 };
141 
146  : public dmlc::FunctionRegEntryBase<ObjFunctionReg,
147  std::function<ObjFunction* ()> > {
148 };
149 
162 #define XGBOOST_REGISTER_OBJECTIVE(UniqueId, Name) \
163  static DMLC_ATTRIBUTE_UNUSED ::xgboost::ObjFunctionReg & \
164  __make_ ## ObjFunctionReg ## _ ## UniqueId ## __ = \
165  ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->__REGISTER__(Name)
166 } // namespace xgboost
167 #endif // XGBOOST_OBJECTIVE_H_
为 xgboost 定义配置宏和基本类型。
定义: json.h:333
表示JSON格式的数据结构。
Definition: json.h:392
数据集的元信息,始终存储在内存中。
Definition: data.h:51
linalg::Tensor< float, 2 > 标签
每个实例的标签
定义: data.h:63
目标函数的接口。
定义: objective.h:28
虚拟 void EvalTransform(HostDeviceVector< float > *io_preds)
将逆链接(激活)函数应用于预测值。
定义: objective.h:75
虚拟 void InitEstimation(MetaInfo const &info, linalg::Vector< float > *base_score) const
获取预测的初始估计值(截距)。
静态 constexpr float DefaultBaseScore()
定义: objective.h:33
虚拟 void GetGradient(HostDeviceVector< float > const &preds, MetaInfo const &info, std::int32_t iter, linalg::Matrix< GradientPair > *out_gpair)=0
根据现有信息获取每个预测的梯度。
Context const * Ctx() const
上下文的获取器。
定义: objective.h:112
虚拟 void Configure(Args const &args)=0
使用指定参数配置目标函数。
Context const * ctx_
定义: objective.h:30
虚拟 void PredTransform(HostDeviceVector< float > *) const
将逆链接(激活)函数应用于预测值。
定义: objective.h:67
静态 ObjFunction * Create(const std::string &name, Context const *ctx)
根据名称创建目标函数。
虚拟 Json DefaultMetricConfig() const
返回默认指标的配置。
定义: objective.h:59
虚拟 struct ObjInfo Task() const =0
返回此目标的任务。
~ObjFunction() override=default
虚拟 void UpdateTreeLeaf(HostDeviceVector< bst_node_t > const &, MetaInfo const &, float, HostDeviceVector< float > const &, std::int32_t, RegTree *) const
构建树后更新叶子值。对于 Hessian 为 0 的目标函数是必需的。
定义: objective.h:129
虚拟 const char * DefaultEvalMetric() const =0
虚拟 void ProbToMargin(linalg::Vector< float > *) const
将链接函数应用于截距。
定义: objective.h:86
虚拟 bst_target_t Targets(MetaInfo const &info) const
返回输入矩阵的目标数量。目前 XGBoost 只支持多目标回归。
定义: objective.h:105
将回归树定义为最常见的树模型。
定义: tree_model.h:102
一个张量存储。要将其用于切片等其他功能,首先需要获取一个视图...
定义: linalg.h:745
auto Shape() const
定义: linalg.h:867
xgboost 的输入数据结构。
设备与主机向量抽象层。
线性代数相关工具。
定义 XGBoost 中不同组件的抽象接口。
集成目标、gbm和评估的学习器接口。这是用户面临的XGB...
Definition: base.h:97
std::vector< std::pair< std::string, std::string > > Args
定义: base.h:324
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义: base.h:127
定义: model.h:28
XGBoost的运行时上下文。包含线程和设备等信息。
Definition: context.h:133
目标函数工厂函数的注册条目。
定义: objective.h:147
目标返回的结构体,用于确定当前任务。此结构体未被任何算法使用...
定义: task.h:24