xgboost
tree_updater.h
前往此文件的文档。
1 
8 #ifndef XGBOOST_TREE_UPDATER_H_
9 #define XGBOOST_TREE_UPDATER_H_
10 
11 #include <dmlc/registry.h>
12 #include <xgboost/base.h> // 用于 Args, GradientPair
13 #include <xgboost/data.h> // DMatrix
14 #include <xgboost/host_device_vector.h> // 用于 HostDeviceVector
15 #include <xgboost/linalg.h> // 用于 VectorView
16 #include <xgboost/model.h> // 用于 Configurable
17 #include <xgboost/span.h> // 用于 Span
18 #include <xgboost/tree_model.h> // 用于 RegTree
19 
20 #include <functional> // 用于 function
21 #include <string> // 用于 string
22 #include <vector> // 用于 vector
23 
24 namespace xgboost {
25 namespace tree {
26 struct TrainParam;
27 }
28 
29 class Json;
30 struct Context;
31 struct ObjInfo;
32 
36 /// interface of tree update module, that performs update of a tree.
37 class TreeUpdater : public Configurable {
38  protected
39  /// \brief Context for device and RNG.
40  Context const* ctx_ = nullptr;
41  public
43  /// \brief virtual destructor
48  /// \brief Initialize the updater with given arguments.
55  /// \brief Whether this updater can be used for updating existing trees.
60  /// \brief Wether the out_position in Update is valid. This determines whether adaptive tree can be used.
74  /// \brief perform update to the tree models
75  /// \param param Training parameters.
76  /// \param gpair Pointer to the gradient and hessian.
77  /// \param data DMatrix object storing the training data.
88  /// \brief determines whether updater has enough knowledge about a given dataset to quickly update prediction cache.
89  /// \param data DMatrix object.
90  /// \param out_preds Output cache.
91  return false;
92  }
93  [[nodiscard]] virtual char const* Name() const = 0;
101  /// \brief Create a tree updater given name.
102  static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
103 };
104 
107 /// \brief Registry entry for tree updater.
108 struct TreeUpdaterReg
109  : public dmlc::FunctionRegEntryBase<
110  TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
123 /// \brief Register a tree updater.
124 #define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \
125  static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeUpdaterReg& \
126  __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
127  ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(Name)
128 } // namespace xgboost
129 #endif // XGBOOST_TREE_UPDATER_H_
定义了 xgboost 的配置宏和基本类型。
XGBoost 用于存储所有外部数据的内部数据结构。
定义: data.h:549
定义: host_device_vector.h:87
表示 JSON 格式的数据结构。
定义: json.h:378
树更新模块的接口,用于执行树的更新。
定义: tree_updater.h:36
Context const * ctx_
定义: tree_updater.h:38
~TreeUpdater() override=default
虚析构函数
virtual bool UpdatePredictionCache(const DMatrix *, linalg::MatrixView< float >)
确定更新器是否对给定数据集有足够的了解,可以快速更新预测缓存。
定义: tree_updater.h:88
static TreeUpdater * Create(const std::string &name, Context const *ctx, ObjInfo const *task)
根据名称创建树更新器。
virtual bool CanModifyTree() const
此更新器是否可用于更新现有树。
定义: tree_updater.h:55
virtual bool HasNodePosition() const
Update 中的 out_position 是否有效。这决定了是否可以使用自适应树。
定义: tree_updater.h:60
TreeUpdater(const Context *ctx)
定义: tree_updater.h:41
virtual void Update(tree::TrainParam const *param, linalg::Matrix< GradientPair > *gpair, DMatrix *data, common::Span< HostDeviceVector< bst_node_t >> out_position, const std::vector< RegTree * > &out_trees)=0
执行树模型的更新
virtual void Configure(const Args &args)=0
使用给定参数初始化更新器。
virtual char const * Name() const =0
span 类实现,基于 ISO++20 span<T>。接口应该相同。
定义: span.h:431
具有静态类型和维度的张量视图。它实现了索引和切片。
定义: linalg.h:294
张量存储。要将其用于切片等其他功能,需要先获取一个视图。
定义: linalg.h:762
xgboost 的输入数据结构。
设备和主机向量的抽象层。
线性代数相关的实用工具。
定义了 XGBoost 中不同组件的抽象接口。
多目标树的核心数据结构。
定义: base.h:89
std::vector< std::pair< std::string, std::string > > Args
定义: base.h:316
定义: model.h:31
XGBoost 的运行时上下文。包含线程和设备等信息。
定义: context.h:133
目标函数返回的结构体,用于确定当前任务。该结构体不被任何算法使用。
定义: task.h:24
树更新器的注册表条目。
定义: tree_updater.h:109
树的模型结构