xgboost
tree_model.h
转到此文件的文档。
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
#include <dmlc/io.h>
#include <dmlc/parameter.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/feature_map.h>
#include <xgboost/linalg.h>  // for VectorView
#include <xgboost/logging.h>
#include <xgboost/model.h>
#include <xgboost/multi_target_tree_model.h>  // for MultiTargetTree

#include <algorithm>
#include <cmath>        // for isnan
#include <cstdint>      // for int32_t, uint32_t
#include <cstring>
#include <limits>
#include <memory>       // for make_unique
#include <stack>
#include <string>
#include <type_traits>  // for is_signed_v
#include <vector>
27 
28 namespace xgboost {
29 class Json;
30 
31 // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
32 // not be configured by users.
34 struct TreeParam : public dmlc::Parameter<TreeParam> {
38  /* \brief number of nodes in the tree */int num_nodes{1};
40  /* \brief number of deleted nodes */int num_deleted{0};
44  /* \brief number of features in the model */bst_feature_t num_feature{0};
49  /* \brief Size of leaf vector */bst_target_t size_leaf_vector{1};
51  int reserved[31];
54  // assert compact alignment
55  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64位对齐");
56  std::memset(reserved, 0, sizeof(reserved));
57  }
58 
59  // Swap byte order for all fields. Useful for transporting models between machines with different
60  // endianness (big endian vs little endian)
61  [[nodiscard]] TreeParam ByteSwap() const {
62  TreeParam x = *this;
63  dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
64  dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
65  dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
66  dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
67  dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
68  dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
69  dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
70  return x;
71  }
72 
73  // declare the parameters
75  // only declare the parameters that can be set by the user.
76  // other arguments are set by the algorithm.
77  DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
78  DMLC_DECLARE_FIELD(num_feature)
79  .set_default(0)
80  .describe("树构建中使用的特征数量。");
81  DMLC_DECLARE_FIELD(num_deleted).set_default(0);
82  DMLC_DECLARE_FIELD(size_leaf_vector)
83  .set_lower_bound(0)
84  .set_default(1)
85  .describe("叶子向量的大小,为向量树保留");
86  }
87 
88  bool operator==(const TreeParam& b) const {
89  return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
91  }
92 };
93 
95 struct RTreeNodeStat {
97  /* \brief loss reduction of this node */bst_float loss_chg;
99  /* \brief sum of hessian */bst_float sum_hess;
101  /* \brief base weight */bst_float base_weight;
103  /* \brief number of child node that are leaves */int leaf_child_cnt {0};
104 
105  RTreeNodeStat() = default;
106  RTreeNodeStat(float loss_chg, float sum_hess, float weight)
108  bool operator==(const RTreeNodeStat& b) const {
109  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
111  }
112  // Swap byte order for all fields. Useful for transporting models between machines with different
113  // endianness (big endian vs little endian)
114  [[nodiscard]] RTreeNodeStat ByteSwap() const {
115  RTreeNodeStat x = *this;
116  dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
117  dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
118  dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
119  dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
120  return x;
121  }
122 };
123 
127 /* \brief A unique pointer that can be copied. */template <typename T>
129  std::unique_ptr<T> ptr_{nullptr};
130 
131  public
132  CopyUniquePtr() = default;
134  ptr_.reset(nullptr);
135  if (that.ptr_) {
136  ptr_ = std::make_unique<T>(*that);
137  }
138  }
139  /* \brief get pointer */[[nodiscard]] XGBOOST_DEVICE T* get() const noexcept { return ptr_.get(); } // NOLINT
140 
141  /* \brief dereference */T& operator*() { return *ptr_; }
142  /* \brief member access */T* operator->() noexcept { return this->get(); }
143 
144  /* \brief const dereference */T const& operator*() const { return *ptr_; }
145  /* \brief const member access */T const* operator->() const noexcept { return this->get(); }
146 
147  /* \brief check if pointer is valid */explicit operator bool() const { return static_cast<bool>(ptr_); }
148  /* \brief check if pointer is not valid */bool operator!() const { return !ptr_; }
149  /* \brief reset pointer */void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
150 };
151 
157 /* \brief data structure for regression tree */class RegTree : public Model {
158  public
160  // invalid node idstatic constexpr bst_node_t kInvalidNodeId{MultiTargetTree::InvalidNodeId()};
161  // marker for deleted nodestatic constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
162  // root node idstatic constexpr bst_node_t kRoot{0};
163 
165  /* \brief regression tree node */class Node {
166  public
168  // assert compact alignment
169  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
170  "Node: 64 bit align");
171  }
172  Node(int32_t cleft, int32_t cright, int32_t parent,
173  uint32_t split_ind, float split_cond, bool default_left)
174  parent_{parent}, cleft_{cleft}, cright_{cright} {
175  this->SetParent(parent_);
176  this->SetSplit(split_ind, split_cond, default_left);
177  }
178 
180  /** \brief get left child * \return left child index */[[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
182  /** \brief get right child * \return right child index */[[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
184  /** \brief get default child when feature is missing * \return default child index */[[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
185  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
186  }
188  /** \brief get split feature index * \return split feature index */[[nodiscard]] XGBOOST_DEVICE bst_feature_t SplitIndex() const {
189  static_assert(!std::is_signed_v<bst_feature_t>);
190  return sindex_ & ((1U << 31) - 1U);
191  }
193  /** \brief check if it's default left branch * \return whether it's default left branch */[[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
195  /** \brief whether the node is leaf * \return whether is leaf */[[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
197  /** \brief get leaf value of leaf node * \return leaf value */[[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
199  /** \brief get split condition * \return split condition */[[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
201  /** \brief get parent of the node * \return parent index */[[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
203  /** \brief whether this node is left child * \return whether it is left child */[[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
205  /** \brief whether node is deleted * \return whether node is deleted */[[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
207  /** \brief check if node is root * \return whether is root */[[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
212  /** * \param nid node id * \brief set left child */XGBOOST_DEVICE void SetLeftChild(int nid) {
213  this->cleft_ = nid;
214  }
219  /** * \param nid node id * \brief set right child */XGBOOST_DEVICE void SetRightChild(int nid) {
220  this->cright_ = nid;
221  }
228  /** * \brief set split condition * \param split_index feature index to split * \param split_cond split condition * \param default_left the default direction when feature is missing */XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
229  bool default_left = false) {
230  if (default_left) split_index |= (1U << 31);
231  this->sindex_ = split_index;
232  (this->info_).split_cond = split_cond;
233  }
240  /** * \brief set the node to be leaf * \param value the value of the leaf */XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
241  (this->info_).leaf_value = value;
242  this->cleft_ = kInvalidNodeId;
243  this->cright_ = right;
244  }
246  /** \brief mark that this node is deleted */XGBOOST_DEVICE void MarkDelete() {
247  this->sindex_ = kDeletedNodeMarker;
248  }
250  /** \brief reuse this node */XGBOOST_DEVICE void Reuse() {
251  this->sindex_ = 0;
252  }
253  // set parent
254  /* \brief set parent of the node */XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
255  if (is_left_child) pidx |= (1U << 31);
256  this->parent_ = pidx;
257  }
258  bool operator==(const Node& b) const {
259  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260  cright_ == b.cright_ && sindex_ == b.sindex_ &&
261  info_.leaf_value == b.info_.leaf_value;
262  }
263 
264  [[nodiscard]] Node ByteSwap() const {
265  Node x = *this;
266  dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
267  dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
268  dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
269  dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
270  dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
271  return x;
272  }
273 
274  private
279  /* \brief node information, union of leaf value and split condition. */union Info{
280  bst_float leaf_value;
281  SplitCondT split_cond;
282  };
283  // pointer to parent, highest bit is used to
284  // indicate whether it's a left child or not
285  int32_t parent_{kInvalidNodeId};
286  // pointer to left, right
287  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
288  // split feature index, left split or right split depends on the highest bit
289  uint32_t sindex_{0};
290  // extra info
291  Info info_;
292  };
293 
299  /** \brief Change node to a leaf node. * \param rid node index * \param value new leaf value */void ChangeToLeaf(int rid, bst_float value) {
300  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
301  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
302  this->DeleteNode(nodes_[rid].LeftChild());
303  this->DeleteNode(nodes_[rid].RightChild());
304  nodes_[rid].SetLeaf(value);
305  }
311  /** \brief collapse a node to a leaf node, delete all children * \param rid node index * \param value new leaf value */void CollapseToLeaf(int rid, bst_float value) {
312  if (nodes_[rid].IsLeaf()) return;
313  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
314  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
315  }
316  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
317  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
318  }
319  this->ChangeToLeaf(rid, value);
320  }
321 
322  /* \brief constructor */RegTree() {
323  param_.Init(Args{});
324  nodes_.resize(param_.num_nodes);
325  stats_.resize(param_.num_nodes);
326  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
327  split_categories_segments_.resize(param_.num_nodes);
328  for (int i = 0; i < param_.num_nodes; i++) {
329  nodes_[i].SetLeaf(0.0f);
330  nodes_[i].SetParent(kInvalidNodeId);
331  }
332  }
336  /* \brief constructor with size and feature support */explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
337  param_.num_feature = n_features;
338  param_.size_leaf_vector = n_targets;
339  if (n_targets > 1) {
340  this->p_mt_tree_.reset(new MultiTargetTree{&param_});
341  }
342  }
343 
345  /* \brief reference to node */Node& operator[](int nid) {
346  return nodes_[nid];
347  }
349  /* \brief const reference to node */const Node& operator[](int nid) const {
350  return nodes_[nid];
351  }
352 
354  /* \brief get all nodes */[[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
355 
357  /* \brief get all stats */[[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
358 
360  /* \brief reference to node statistics */RTreeNodeStat& Stat(int nid) {
361  return stats_[nid];
362  }
364  /* \brief const reference to node statistics */[[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
365  return stats_[nid];
366  }
367 
372  /** \brief 从二进制流加载模型 * \param fi 输入流 */void Load(dmlc::Stream* fi);
377  /** \brief 将模型保存到二进制流 * \param fo 输出流 */void Save(dmlc::Stream* fo) const;
378 
379  /** \brief 从 JSON 对象加载模型。 * \param in JSON 对象。 */void LoadModel(Json const& in) override;
380  /** \brief 将模型保存到 JSON 对象。 * \param out JSON 对象。 */void SaveModel(Json* out) const override;
381 
382  /** \brief 测试两棵树是否相等。 */bool operator==(const RegTree& b) const {
383  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
385  }
386  /* \brief 遍历此树中的所有节点。
387  *
388  * \param func 一个接受节点索引的函数,当迭代应停止时返回 false,否则返回 true。
389  * stop, otherwise returns true.
390  */
391  template <typename Func> void WalkTree(Func func) const {
392  std::stack<bst_node_t> nodes;
393  nodes.push(kRoot);
394  auto &self = *this;
395  while (!nodes.empty()) {
396  auto nidx = nodes.top();
397  nodes.pop();
398  if (!func(nidx)) {
399  return;
400  }
401  auto left = self.LeftChild(nidx);
402  auto right = self.RightChild(nidx);
403  if (left != RegTree::kInvalidNodeId) {
404  nodes.push(left);
405  }
406  if (right != RegTree::kInvalidNodeId) {
407  nodes.push(right);
408  }
409  }
410  }
417  /** \brief 测试两棵树是否相等。 */[[nodiscard]] bool Equal(const RegTree& b) const;
418 
436  /** \brief 将叶子节点展开为分支节点 * \param nid 要展开的节点 id * \param split_index 分裂的特征索引 * \param split_value 分裂条件 * \param default_left 默认方向 * \param base_weight 节点的基权重 * \param left_leaf_weight 左子节点的叶子权重 * \param right_leaf_weight 右子节点的叶子权重 * \param loss_change 此分裂引起的训练损失变化 * \param sum_hess 当前节点的 hessian 之和 * \param left_sum 左子节点的 hessian 之和 * \param right_sum 右子节点的 hessian 之和 */void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
437  bool default_left, bst_float base_weight,
438  bst_float left_leaf_weight, bst_float right_leaf_weight,
439  bst_float loss_change, float sum_hess, float left_sum,
440  float right_sum,
441  bst_node_t leaf_right_child = kInvalidNodeId);
445  /** \brief 将多目标叶子节点展开为分支节点 * \param nidx 要展开的节点 id * \param split_index 分裂的特征索引 * \param split_cond 分裂条件 * \param default_left 默认方向 * \param base_weight 节点的基权重 * \param left_weight 左子节点的叶子权重 * \param right_weight 右子节点的叶子权重 */void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
448  linalg::VectorView<float const> right_weight);
449 
465  /** \brief 通过类别分裂展开叶子节点 * \param nid 要展开的节点 id * \param split_index 分裂的特征索引 * \param split_cat 左分支中的类别列表 * \param default_left 默认方向 * \param base_weight 节点的基权重 * \param left_leaf_weight 左子节点的叶子权重 * \param right_leaf_weight 右子节点的叶子权重 * \param loss_change 此分裂引起的训练损失变化 * \param sum_hess 当前节点的 hessian 之和 * \param left_sum 左子节点的 hessian 之和 * \param right_sum 右子节点的 hessian 之和 */void ExpandCategorical(bst_node_t nid, bst_feature_t split_index,
466  common::Span<const uint32_t> split_cat, bool default_left,
467  bst_float base_weight, bst_float left_leaf_weight,
468  bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
469  float left_sum, float right_sum);
473  /** \brief 检查树是否包含类别分裂 */[[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
477  /** \brief 检查是否是多目标树 */[[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
481  /** \brief 树输出的目标数量。 */[[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
485  /** \brief 如果是多目标树,获取 MultiTargetTree 对象 */[[nodiscard]] auto GetMultiTargetTree() const {
486  CHECK(IsMultiTarget());
487  return p_mt_tree_.get();
488  }
492  /** \brief 此树中使用的特征数量 */[[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
496  /** \brief 树中的节点数量 */[[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
500  /** \brief 树中的有效节点数量 */[[nodiscard]] bst_node_t NumValidNodes() const noexcept {
501  return param_.num_nodes - param_.num_deleted;
502  }
506  /** \brief 树中的额外节点数量 */[[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
507  return param_.num_nodes - 1 - param_.num_deleted;
508  }
509  /* \brief 计算树中的叶子节点数量。 */
510  [[nodiscard]] bst_node_t GetNumLeaves() const;
511  /** \brief 计算树中的分裂节点数量。 */[[nodiscard]] bst_node_t GetNumSplitNodes() const;
512 
517  /** \brief 获取节点的深度 * \param nid 节点 id * \return 节点的深度 */[[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
518  if (IsMultiTarget()) {
519  return this->p_mt_tree_->Depth(nid);
520  }
521  int depth = 0;
522  while (!nodes_[nid].IsRoot()) {
523  ++depth;
524  nid = nodes_[nid].Parent();
525  }
526  return depth;
527  }
531  /** \brief 设置节点的叶子向量。 * \param nidx 节点索引 * \param weight 叶子值 */void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight) {
532  CHECK(IsMultiTarget());
533  return this->p_mt_tree_->SetLeaf(nidx, weight);
534  }
535 
540  /** \brief 获取节点的最大深度 * \param nid 节点 id * \return 节点的最大深度 */[[nodiscard]] int MaxDepth(int nid) const {
541  if (nodes_[nid].IsLeaf()) return 0;
542  return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
543  }
544 
548  /** \brief 获取树的最大深度 */int MaxDepth() { return MaxDepth(0); }
549 
554  /** \brief 一个帮助使用特征向量的结构 * \note 这是一个临时的辅助结构,请谨慎使用 */struct FVec {
559  /** * \brief 初始化特征向量 * \param size 特征向量的大小 */void Init(size_t size);
564  /** * \brief 使用数据实例填充特征向量 * \param inst 数据实例 */void Fill(SparsePage::Inst const& inst);
565 
570  /** * \brief 清空特征向量 */void Drop();
575  /** * \brief 获取特征向量的大小 */[[nodiscard]] size_t Size() const;
581  [[nodiscard]] bst_float GetFvalue(size_t i) const;
587  [[nodiscard]] bool IsMissing(size_t i) const;
588  [[nodiscard]] bool HasMissing() const;
589  void HasMissing(bool has_missing) { this->has_missing_ = has_missing; }
590 
591  [[nodiscard]] common::Span<float> Data() { return data_; }
592 
593  private
599  std::vector<float> data_;
600  bool has_missing_;
601  };
602 
609  std::vector<float>* mean_values,
610  bst_float* out_contribs) const;
618  [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
619  std::string format) const;
625  [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
629  [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
630  return split_types_;
631  }
633  return split_categories_;
634  }
639  auto node_ptr = GetCategoriesMatrix().node_ptr;
640  auto categories = GetCategoriesMatrix().categories;
641  auto segment = node_ptr[nidx];
642  auto node_cats = categories.subspan(segment.beg, segment.size);
643  return node_cats;
644  }
645  [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
646 
655  struct Segment {
656  std::size_t beg{0};
657  std::size_t size{0};
658  };
662  };
663 
667  view.categories = this->GetSplitCategories();
668  view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
669  return view;
670  }
671 
672  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
673  if (IsMultiTarget()) {
674  return this->p_mt_tree_->SplitIndex(nidx);
675  }
676  return (*this)[nidx].SplitIndex();
677  }
678  [[nodiscard]] float SplitCond(bst_node_t nidx) const {
679  if (IsMultiTarget()) {
680  return this->p_mt_tree_->SplitCond(nidx);
681  }
682  return (*this)[nidx].SplitCond();
683  }
684  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
685  if (IsMultiTarget()) {
686  return this->p_mt_tree_->DefaultLeft(nidx);
687  }
688  return (*this)[nidx].DefaultLeft();
689  }
690  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
691  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
692  }
693  [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
694  if (IsMultiTarget()) {
695  return nidx == kRoot;
696  }
697  return (*this)[nidx].IsRoot();
698  }
699  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
700  if (IsMultiTarget()) {
701  return this->p_mt_tree_->IsLeaf(nidx);
702  }
703  return (*this)[nidx].IsLeaf();
704  }
705  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
706  if (IsMultiTarget()) {
707  return this->p_mt_tree_->Parent(nidx);
708  }
709  return (*this)[nidx].Parent();
710  }
711  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
712  if (IsMultiTarget()) {
713  return this->p_mt_tree_->LeftChild(nidx);
714  }
715  return (*this)[nidx].LeftChild();
716  }
717  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
718  if (IsMultiTarget()) {
719  return this->p_mt_tree_->RightChild(nidx);
720  }
721  return (*this)[nidx].RightChild();
722  }
723  [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
724  if (IsMultiTarget()) {
725  CHECK_NE(nidx, kRoot);
726  auto p = this->p_mt_tree_->Parent(nidx);
727  return nidx == this->p_mt_tree_->LeftChild(p);
728  }
729  return (*this)[nidx].IsLeftChild();
730  }
731  [[nodiscard]] bst_node_t Size() const {
732  if (IsMultiTarget()) {
733  return this->p_mt_tree_->Size();
734  }
735  return this->nodes_.size();
736  }
737 
738  private
739  template <bool typed>
740  void LoadCategoricalSplit(Json const& in);
741  void SaveCategoricalSplit(Json* p_out) const;
743  TreeParam param_;
744  // vector of nodes
745  std::vector<Node> nodes_;
746  // free node space, used during training process
747  std::vector<int> deleted_nodes_;
748  // stats of nodes
749  std::vector<RTreeNodeStat> stats_;
750  std::vector<FeatureType> split_types_;
751 
752  // Categories for each internal node.
753  std::vector<uint32_t> split_categories_;
754  // Ptr to split categories of each node.
755  std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
756  // ptr to multi-target tree with vector leaf.
758  // allocate a new node,
759  // !!!!!! NOTE: may cause BUG here, nodes.resize
760  bst_node_t AllocNode() {
761  if (param_.num_deleted != 0) {
762  int nid = deleted_nodes_.back();
763  deleted_nodes_.pop_back();
764  nodes_[nid].Reuse();
765  --param_.num_deleted;
766  return nid;
767  }
768  int nd = param_.num_nodes++;
769  CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
770  << "number of nodes in the tree exceed 2^31";
771  nodes_.resize(param_.num_nodes);
772  stats_.resize(param_.num_nodes);
773  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
774  split_categories_segments_.resize(param_.num_nodes);
775  return nd;
776  }
777  // delete a tree node, keep the parent field to allow trace back
778  void DeleteNode(int nid) {
779  CHECK_GE(nid, 1);
780  auto pid = (*this)[nid].Parent();
781  if (nid == (*this)[pid].LeftChild()) {
782  (*this)[pid].SetLeftChild(kInvalidNodeId);
783  } else {
784  (*this)[pid].SetRightChild(kInvalidNodeId);
785  }
786 
787  deleted_nodes_.push_back(nid);
788  nodes_[nid].MarkDelete();
789  ++param_.num_deleted;
790  }
791 };
792 
793 inline void RegTree::FVec::Init(size_t size) {
794  data_.resize(size);
795  std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
796  has_missing_ = true;
797 }
798 
799 inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
800  auto p_data = inst.data();
801  auto p_out = data_.data();
802 
803  for (std::size_t i = 0, n = inst.size(); i < n; ++i) {
804  auto const& entry = p_data[i];
805  p_out[entry.index] = entry.fvalue;
806  }
807  has_missing_ = data_.size() != inst.size();
808 }
809 
810 inline void RegTree::FVec::Drop() { this->Init(this->Size()); }
811 
812 inline size_t RegTree::FVec::Size() const {
813  return data_.size();
814 }
815 
816 inline float RegTree::FVec::GetFvalue(size_t i) const {
817  return data_[i];
818 }
819 
820 inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); }
821 
822 inline bool RegTree::FVec::HasMissing() const { return has_missing_; }
823 
824 // Multi-target tree not yet implemented error
826  return " support for multi-target tree is not yet implemented.";
827 }
828 } // namespace xgboost
829 #endif // XGBOOST_TREE_MODEL_H_
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
Helper for defining copyable data structure that contains unique pointers.
Definition: tree_model.h:128
T const * operator->() const noexcept
Definition: tree_model.h:145
T * get() const noexcept
Definition: tree_model.h:139
bool operator!() const
Definition: tree_model.h:148
CopyUniquePtr(CopyUniquePtr const &that)
Definition: tree_model.h:133
T * operator->() noexcept
Definition: tree_model.h:142
T & operator*()
Definition: tree_model.h:141
T const & operator*() const
Definition: tree_model.h:144
void reset(T *ptr)
Definition: tree_model.h:149
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:378
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
tree node
Definition: tree_model.h:165
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:182
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:197
XGBOOST_DEVICE Node()
Definition: tree_model.h:167
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:254
Node ByteSwap() const
Definition: tree_model.h:264
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:212
XGBOOST_DEVICE bst_feature_t SplitIndex() const
feature index of split condition
Definition: tree_model.h:188
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:195
bool operator==(const Node &b) const
Definition: tree_model.h:258
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:172
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition: tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:180
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:184
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:199
define regression tree to be the most common tree model.
Definition: tree_model.h:157
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:481
void WalkTree(Func func) const
Definition: tree_model.h:391
void Save(dmlc::Stream *fo) const
save model to stream
bool IsLeaf(bst_node_t nidx) const
Definition: tree_model.h:699
bool operator==(const RegTree &b) const
Definition: tree_model.h:382
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:364
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
bst_node_t Parent(bst_node_t nidx) const
Definition: tree_model.h:705
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:496
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:349
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: tree_model.h:690
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:345
RegTree()
Definition: tree_model.h:322
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:160
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: tree_model.h:672
bool IsRoot(bst_node_t nidx) const
Definition: tree_model.h:693
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:161
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:506
bool DefaultLeft(bst_node_t nidx) const
Definition: tree_model.h:684
auto GetMultiTargetTree() const
Get the underlying implementation of multi-target tree.
Definition: tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:711
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:336
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:717
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:638
bool IsLeftChild(bst_node_t nidx) const
Definition: tree_model.h:723
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:664
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:360
bst_float SplitCondT
Definition: tree_model.h:159
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:629
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:500
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:299
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:357
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition: tree_model.h:531
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:354
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:625
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:492
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:632
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:473
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition: tree_model.h:517
static constexpr bst_node_t kRoot
Definition: tree_model.h:162
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:645
float SplitCond(bst_node_t nidx) const
Definition: tree_model.h:678
int MaxDepth()
get maximum depth
Definition: tree_model.h:548
bst_node_t Size() const
Definition: tree_model.h:731
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:597
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:294
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Core data structure for multi-target trees.
Definition: base.h:89
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:316
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:111
FeatureType
Definition: data.h:41
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:119
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:99
float bst_float
float type, used for storing statistics
Definition: base.h:95
StringView MTNotImplemented()
Definition: tree_model.h:825
Definition: model.h:17
node statistics used in regression tree
Definition: tree_model.h:95
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:106
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:97
int leaf_child_cnt
number of children that are leaf nodes, as known so far
Definition: tree_model.h:103
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:99
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:108
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:114
bst_float base_weight
weight of current node
Definition: tree_model.h:101
std::size_t size
Definition: tree_model.h:657
std::size_t beg
Definition: tree_model.h:656
CSR-like matrix for categorical splits.
Definition: tree_model.h:654
common::Span< uint32_t const > categories
Definition: tree_model.h:660
common::Span< Segment const > node_ptr
Definition: tree_model.h:661
common::Span< FeatureType const > split_type
Definition: tree_model.h:659
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:554
void HasMissing(bool has_missing)
Definition: tree_model.h:589
void Drop()
drop the trace after fill, must be called after fill.
Definition: tree_model.h:810
bool HasMissing() const
Definition: tree_model.h:822
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:820
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:812
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:793
common::Span< float > Data()
Definition: tree_model.h:591
void Fill(SparsePage::Inst const &inst)
fill the vector with sparse vector
Definition: tree_model.h:799
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:816
Definition: string_view.h:16
meta parameters of the tree
Definition: tree_model.h:34
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:44
int num_nodes
total number of nodes
Definition: tree_model.h:38
int num_deleted
number of deleted nodes
Definition: tree_model.h:40
bool operator==(const TreeParam &b) const
Definition: tree_model.h:88
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:51
TreeParam ByteSwap() const
Definition: tree_model.h:61
TreeParam()
constructor
Definition: tree_model.h:53
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:74
bst_target_t size_leaf_vector
leaf vector size; used by vector trees to store more than one dimension of information in each leaf
Definition: tree_model.h:49
int deprecated_num_roots
(Deprecated) number of start roots
Definition: tree_model.h:36
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:42