xgboost
tree_model.h
前往此文件的文档。
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/feature_map.h>
#include <xgboost/linalg.h>  // for VectorView
#include <xgboost/logging.h>
#include <xgboost/model.h>
#include <xgboost/multi_target_tree_model.h>  // for MultiTargetTree

#include <algorithm>
#include <cmath>        // for isnan
#include <cstddef>      // for size_t
#include <cstdint>      // for int32_t, uint32_t
#include <cstring>
#include <limits>
#include <memory>       // for make_unique
#include <stack>
#include <string>
#include <type_traits>  // for is_signed_v
#include <vector>
25 
26 namespace xgboost {
27 class Json;
28 
30 struct TreeParam {
39 
40  bool operator==(const TreeParam& b) const {
41  return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
43  }
44 
45  void FromJson(Json const& in);
46  void ToJson(Json* p_out) const;
47 };
48 
50 struct RTreeNodeStat {
58  int leaf_child_cnt {0};
59 
60  RTreeNodeStat() = default;
61  RTreeNodeStat(float loss_chg, float sum_hess, float weight)
63  bool operator==(const RTreeNodeStat& b) const {
64  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
66  }
67 };
68 
72 template <typename T>
74  std::unique_ptr<T> ptr_{nullptr};
75 
76  public
77  CopyUniquePtr() = default;
79  ptr_.reset(nullptr);
80  if (that.ptr_) {
81  ptr_ = std::make_unique<T>(*that);
82  }
83  }
84  T* get() const noexcept { return ptr_.get(); } // NOLINT
85 
86  T& operator*() { return *ptr_; }
87  T* operator->() noexcept { return this->get(); }
88 
89  T const& operator*() const { return *ptr_; }
90  T const* operator->() const noexcept { return this->get(); }
91 
92  explicit operator bool() const { return static_cast<bool>(ptr_); }
93  bool operator!() const { return !ptr_; }
94  void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
95 };
96 
102 class RegTree : public Model {
103  public
106  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
107  static constexpr bst_node_t kRoot{0};
108 
110  class Node {
111  public
113  // 断言紧凑对齐
114  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info), "Node: 64 位对齐");
115  }
116  Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond,
117  bool default_left)
118  : parent_{parent}, cleft_{cleft}, cright_{cright} {
119  this->SetParent(parent_);
120  this->SetSplit(split_ind, split_cond, default_left);
121  }
122 
124  [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
126  [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
128  [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
129  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
130  }
132  [[nodiscard]] XGBOOST_DEVICE bst_feature_t SplitIndex() const {
133  static_assert(!std::is_signed_v<bst_feature_t>);
134  return sindex_ & ((1U << 31) - 1U);
135  }
137  [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
139  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
141  [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
143  [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
145  [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
147  [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
149  [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
151  [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
156  XGBOOST_DEVICE void SetLeftChild(int nid) {
157  this->cleft_ = nid;
158  }
164  this->cright_ = nid;
165  }
172  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
173  bool default_left = false) {
174  if (default_left) split_index |= (1U << 31);
175  this->sindex_ = split_index;
176  (this->info_).split_cond = split_cond;
177  }
184  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
185  (this->info_).leaf_value = value;
186  this->cleft_ = kInvalidNodeId;
187  this->cright_ = right;
188  }
191  this->sindex_ = kDeletedNodeMarker;
192  }
195  this->sindex_ = 0;
196  }
197  // 设置父节点
198  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
199  if (is_left_child) pidx |= (1U << 31);
200  this->parent_ = pidx;
201  }
202  bool operator==(const Node& b) const {
203  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
204  cright_ == b.cright_ && sindex_ == b.sindex_ &&
205  info_.leaf_value == b.info_.leaf_value;
206  }
207 
208  private
213  union Info{
214  bst_float leaf_value;
215  SplitCondT split_cond;
216  };
217  // 指向父节点的指针,最高位用于
218  // 指示它是否为左子节点
219  int32_t parent_{kInvalidNodeId};
220  // 指向左子节点、右子节点的指针
221  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
222  // 分裂特征索引,左分裂或右分裂取决于最高位
223  uint32_t sindex_{0};
224  // 额外信息
225  Info info_;
226  };
227 
233  void ChangeToLeaf(int rid, bst_float value) {
234  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
235  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
236  this->DeleteNode(nodes_[rid].LeftChild());
237  this->DeleteNode(nodes_[rid].RightChild());
238  nodes_[rid].SetLeaf(value);
239  }
245  void CollapseToLeaf(int rid, bst_float value) {
246  if (nodes_[rid].IsLeaf()) return;
247  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
248  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
249  }
250  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
251  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
252  }
253  this->ChangeToLeaf(rid, value);
254  }
255 
257  nodes_.resize(param_.num_nodes);
258  stats_.resize(param_.num_nodes);
259  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
260  split_categories_segments_.resize(param_.num_nodes);
261  for (int i = 0; i < param_.num_nodes; i++) {
262  nodes_[i].SetLeaf(0.0f);
263  nodes_[i].SetParent(kInvalidNodeId);
264  }
265  }
269  explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
270  param_.num_feature = n_features;
271  param_.size_leaf_vector = n_targets;
272  if (n_targets > 1) {
273  this->p_mt_tree_.reset(new MultiTargetTree{&param_});
274  }
275  }
276 
278  Node& operator[](int nid) {
279  return nodes_[nid];
280  }
282  const Node& operator[](int nid) const {
283  return nodes_[nid];
284  }
285 
287  [[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
288 
290  [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
291 
293  RTreeNodeStat& Stat(int nid) {
294  return stats_[nid];
295  }
297  [[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
298  return stats_[nid];
299  }
300 
301  void LoadModel(Json const& in) override;
302  void SaveModel(Json* out) const override;
303 
304  bool operator==(const RegTree& b) const {
305  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
306  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
307  }
308  /* \brief 遍历此树中的所有节点。
309  *
310  * \param Function 接受节点索引,并在迭代应停止时返回 false,否则返回 true。
311  * 停止,否则返回 true。
312  */
313  template <typename Func> void WalkTree(Func func) const {
314  std::stack<bst_node_t> nodes;
315  nodes.push(kRoot);
316  auto &self = *this;
317  while (!nodes.empty()) {
318  auto nidx = nodes.top();
319  nodes.pop();
320  if (!func(nidx)) {
321  return;
322  }
323  auto left = self.LeftChild(nidx);
324  auto right = self.RightChild(nidx);
325  if (left != RegTree::kInvalidNodeId) {
326  nodes.push(left);
327  }
328  if (right != RegTree::kInvalidNodeId) {
329  nodes.push(right);
330  }
331  }
332  }
339  [[nodiscard]] bool Equal(const RegTree& b) const;
340 
358  void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
359  bool default_left, bst_float base_weight,
360  bst_float left_leaf_weight, bst_float right_leaf_weight,
361  bst_float loss_change, float sum_hess, float left_sum,
362  float right_sum,
363  bst_node_t leaf_right_child = kInvalidNodeId);
367  void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
370  linalg::VectorView<float const> right_weight);
371 
388  common::Span<const uint32_t> split_cat, bool default_left,
389  bst_float base_weight, bst_float left_leaf_weight,
390  bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
391  float left_sum, float right_sum);
395  [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
399  [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
403  [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
407  [[nodiscard]] auto GetMultiTargetTree() const {
408  CHECK(IsMultiTarget());
409  return p_mt_tree_.get();
410  }
414  [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
418  [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
422  [[nodiscard]] bst_node_t NumValidNodes() const noexcept {
423  return param_.num_nodes - param_.num_deleted;
424  }
428  [[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
429  return param_.num_nodes - 1 - param_.num_deleted;
430  }
431  /* \brief 计算树中的叶子数量。 */
432  [[nodiscard]] bst_node_t GetNumLeaves() const;
433  [[nodiscard]] bst_node_t GetNumSplitNodes() const;
434 
439  [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
440  if (IsMultiTarget()) {
441  return this->p_mt_tree_->Depth(nid);
442  }
443  int depth = 0;
444  while (!nodes_[nid].IsRoot()) {
445  ++depth;
446  nid = nodes_[nid].Parent();
447  }
448  return depth;
449  }
454  CHECK(IsMultiTarget());
455  return this->p_mt_tree_->SetLeaf(nidx, weight);
456  }
457 
462  [[nodiscard]] int MaxDepth(int nid) const {
463  if (nodes_[nid].IsLeaf()) return 0;
464  return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
465  }
466 
470  int MaxDepth() { return MaxDepth(0); }
471 
476  struct FVec {
481  void Init(size_t size);
486  void Fill(SparsePage::Inst const& inst);
487 
492  void Drop();
497  [[nodiscard]] size_t Size() const;
503  [[nodiscard]] bst_float GetFvalue(size_t i) const;
509  [[nodiscard]] bool IsMissing(size_t i) const;
510  [[nodiscard]] bool HasMissing() const;
511  void HasMissing(bool has_missing) { this->has_missing_ = has_missing; }
512 
513  [[nodiscard]] common::Span<float> Data() { return data_; }
514 
515  private
521  std::vector<float> data_;
522  bool has_missing_;
523  };
524 
532  [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
533  std::string format) const;
539  [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
543  [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
544  return split_types_;
545  }
547  return split_categories_;
548  }
553  auto node_ptr = GetCategoriesMatrix().node_ptr;
554  auto categories = GetCategoriesMatrix().categories;
555  auto segment = node_ptr[nidx];
556  auto node_cats = categories.subspan(segment.beg, segment.size);
557  return node_cats;
558  }
559  [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
560 
569  struct Segment {
570  std::size_t beg{0};
571  std::size_t size{0};
572  };
576  };
577 
581  view.categories = this->GetSplitCategories();
582  view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
583  return view;
584  }
585 
586  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
587  if (IsMultiTarget()) {
588  return this->p_mt_tree_->SplitIndex(nidx);
589  }
590  return (*this)[nidx].SplitIndex();
591  }
592  [[nodiscard]] float SplitCond(bst_node_t nidx) const {
593  if (IsMultiTarget()) {
594  return this->p_mt_tree_->SplitCond(nidx);
595  }
596  return (*this)[nidx].SplitCond();
597  }
598  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
599  if (IsMultiTarget()) {
600  return this->p_mt_tree_->DefaultLeft(nidx);
601  }
602  return (*this)[nidx].DefaultLeft();
603  }
604  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
605  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
606  }
607  [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
608  if (IsMultiTarget()) {
609  return nidx == kRoot;
610  }
611  return (*this)[nidx].IsRoot();
612  }
613  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
614  if (IsMultiTarget()) {
615  return this->p_mt_tree_->IsLeaf(nidx);
616  }
617  return (*this)[nidx].IsLeaf();
618  }
619  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
620  if (IsMultiTarget()) {
621  return this->p_mt_tree_->Parent(nidx);
622  }
623  return (*this)[nidx].Parent();
624  }
625  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
626  if (IsMultiTarget()) {
627  return this->p_mt_tree_->LeftChild(nidx);
628  }
629  return (*this)[nidx].LeftChild();
630  }
631  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
632  if (IsMultiTarget()) {
633  return this->p_mt_tree_->RightChild(nidx);
634  }
635  return (*this)[nidx].RightChild();
636  }
637  [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
638  if (IsMultiTarget()) {
639  CHECK_NE(nidx, kRoot);
640  auto p = this->p_mt_tree_->Parent(nidx);
641  return nidx == this->p_mt_tree_->LeftChild(p);
642  }
643  return (*this)[nidx].IsLeftChild();
644  }
645  [[nodiscard]] bst_node_t Size() const {
646  if (IsMultiTarget()) {
647  return this->p_mt_tree_->Size();
648  }
649  return this->nodes_.size();
650  }
651 
652  private
  // JSON (de)serialization helpers for the categorical split data.
  // NOTE(review): the meaning of `typed` is defined in the implementation file — confirm there.
  template <bool typed>
  void LoadCategoricalSplit(Json const& in);
  void SaveCategoricalSplit(Json* p_out) const;
  // Tree meta-parameters (node count, deleted count, feature count, leaf-vector size).
  TreeParam param_;
  // vector of nodes; deleted nodes stay in place and are only marked
  std::vector<Node> nodes_;
  // free node space, used during training process; AllocNode() reuses entries from the back
  std::vector<int> deleted_nodes_;
  // stats of nodes, indexed like nodes_ (resized together in AllocNode())
  std::vector<RTreeNodeStat> stats_;
  // Per-node split type; new nodes default to FeatureType::kNumerical.
  std::vector<FeatureType> split_types_;

  // Categories for each internal node, stored back to back for all nodes.
  std::vector<uint32_t> split_categories_;
  // Ptr to split categories of each node: a (beg, size) segment into split_categories_.
  std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
670  // ptr to multi-target tree with vector leaf.
672  // allocate a new node,
673  // !!!!!! NOTE: may cause BUG here, nodes.resize
674  bst_node_t AllocNode() {
675  if (param_.num_deleted != 0) {
676  int nid = deleted_nodes_.back();
677  deleted_nodes_.pop_back();
678  nodes_[nid].Reuse();
679  --param_.num_deleted;
680  return nid;
681  }
682  int nd = param_.num_nodes++;
683  CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
684  << "number of nodes in the tree exceed 2^31";
685  nodes_.resize(param_.num_nodes);
686  stats_.resize(param_.num_nodes);
687  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
688  split_categories_segments_.resize(param_.num_nodes);
689  return nd;
690  }
691  // delete a tree node, keep the parent field to allow trace back
692  void DeleteNode(int nid) {
693  CHECK_GE(nid, 1);
694  auto pid = (*this)[nid].Parent();
695  if (nid == (*this)[pid].LeftChild()) {
696  (*this)[pid].SetLeftChild(kInvalidNodeId);
697  } else {
698  (*this)[pid].SetRightChild(kInvalidNodeId);
699  }
700 
701  deleted_nodes_.push_back(nid);
702  nodes_[nid].MarkDelete();
703  ++param_.num_deleted;
704  }
705 };
706 
707 inline void RegTree::FVec::Init(size_t size) {
708  data_.resize(size);
709  std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
710  has_missing_ = true;
711 }
712 
713 inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
714  auto p_data = inst.data();
715  auto p_out = data_.data();
716 
717  for (std::size_t i = 0, n = inst.size(); i < n; ++i) {
718  auto const& entry = p_data[i];
719  p_out[entry.index] = entry.fvalue;
720  }
721  has_missing_ = data_.size() != inst.size();
722 }
723 
724 inline void RegTree::FVec::Drop() { this->Init(this->Size()); }
725 
726 inline size_t RegTree::FVec::Size() const {
727  return data_.size();
728 }
729 
730 inline float RegTree::FVec::GetFvalue(size_t i) const {
731  return data_[i];
732 }
733 
734 inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); }
735 
736 inline bool RegTree::FVec::HasMissing() const { return has_missing_; }
737 
738 // Multi-target tree not yet implemented error
740  return " support for multi-target tree is not yet implemented.";
741 }
742 } // namespace xgboost
743 #endif // XGBOOST_TREE_MODEL_H_
为 xgboost 定义配置宏和基本类型。
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
用于定义包含唯一指针的可复制数据结构的辅助类。
定义: tree_model.h:73
T const * operator->() const noexcept
定义: tree_model.h:90
T * get() const noexcept
定义: tree_model.h:84
bool operator!() const
定义: tree_model.h:93
CopyUniquePtr(CopyUniquePtr const &that)
定义: tree_model.h:78
T * operator->() noexcept
定义: tree_model.h:87
T & operator*()
定义: tree_model.h:86
T const & operator*() const
定义: tree_model.h:89
void reset(T *ptr)
定义: tree_model.h:94
特征映射数据结构,用于辅助文本模型转储。TODO(tqchen) 考虑使其更轻量级...
定义: feature_map.h:22
表示JSON格式的数据结构。
Definition: json.h:392
多目标模型的树结构。
定义: multi_target_tree_model.h:69
static constexpr bst_node_t InvalidNodeId()
定义: multi_target_tree_model.h:71
树节点
定义: tree_model.h:110
XGBOOST_DEVICE int Parent() const
获取节点的父节点
定义: tree_model.h:145
XGBOOST_DEVICE void MarkDelete()
标记此节点已删除
定义: tree_model.h:190
XGBOOST_DEVICE bool IsRoot() const
当前节点是否为根节点
定义: tree_model.h:151
XGBOOST_DEVICE int RightChild() const
右子节点索引
定义: tree_model.h:126
XGBOOST_DEVICE float LeafValue() const
定义: tree_model.h:141
XGBOOST_DEVICE Node()
定义: tree_model.h:112
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
定义: tree_model.h:198
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
设置节点的叶子值
定义: tree_model.h:184
XGBOOST_DEVICE bool IsLeftChild() const
当前节点是否为左子节点
定义: tree_model.h:147
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
设置当前节点的分割条件
定义: tree_model.h:172
XGBOOST_DEVICE void SetLeftChild(int nid)
设置左子节点
定义: tree_model.h:156
XGBOOST_DEVICE bst_feature_t SplitIndex() const
分割条件的特征索引
定义: tree_model.h:132
XGBOOST_DEVICE bool IsDeleted() const
此节点是否已删除
定义: tree_model.h:149
XGBOOST_DEVICE bool IsLeaf() const
当前节点是否为叶子节点
定义: tree_model.h:139
bool operator==(const Node &b) const
定义: tree_model.h:202
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
定义: tree_model.h:116
XGBOOST_DEVICE void Reuse()
重用此已删除节点。
定义: tree_model.h:194
XGBOOST_DEVICE void SetRightChild(int nid)
设置右子节点
定义: tree_model.h:163
XGBOOST_DEVICE bool DefaultLeft() const
当特征未知时,是否走向左子节点
定义: tree_model.h:137
XGBOOST_DEVICE int LeftChild() const
左子节点索引
定义: tree_model.h:124
XGBOOST_DEVICE int DefaultChild() const
当特征缺失时,默认子节点的索引
定义: tree_model.h:128
XGBOOST_DEVICE SplitCondT SplitCond() const
定义: tree_model.h:143
将回归树定义为最常见的树模型。
定义: tree_model.h:102
int MaxDepth(int nid) const
获取最大深度
定义: tree_model.h:462
void SaveModel(Json *out) const override
将模型配置保存到JSON对象
bst_target_t NumTargets() const
叶子权重向量的长度。
定义: tree_model.h:403
void WalkTree(Func func) const
定义: tree_model.h:313
bool IsLeaf(bst_node_t nidx) const
定义: tree_model.h:613
bool operator==(const RegTree &b) const
定义: tree_model.h:304
const RTreeNodeStat & Stat(int nid) const
获取给定nid的节点统计信息
定义: tree_model.h:297
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
将叶子节点扩展为多目标树的两个额外的叶子节点。
bst_node_t Parent(bst_node_t nidx) const
定义: tree_model.h:619
bst_node_t NumNodes() const noexcept
获取此树中包括已删除节点在内的总节点数。
定义: tree_model.h:418
const Node & operator[](int nid) const
获取给定nid的节点
定义: tree_model.h:282
bst_node_t DefaultChild(bst_node_t nidx) const
定义: tree_model.h:604
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
将叶子节点扩展为两个额外的叶子节点。
Node & operator[](int nid)
获取给定nid的节点
定义: tree_model.h:278
RegTree()
定义: tree_model.h:256
static constexpr bst_node_t kInvalidNodeId
定义: tree_model.h:105
bst_feature_t SplitIndex(bst_node_t nidx) const
定义: tree_model.h:586
bool IsRoot(bst_node_t nidx) const
定义: tree_model.h:607
static constexpr uint32_t kDeletedNodeMarker
定义: tree_model.h:106
bool IsMultiTarget() const
这是否是多目标树。
定义: tree_model.h:399
bst_node_t NumExtraNodes() const noexcept
除根节点外的额外节点数
定义: tree_model.h:428
bool DefaultLeft(bst_node_t nidx) const
定义: tree_model.h:598
auto GetMultiTargetTree() const
获取多目标树的底层实现。
定义: tree_model.h:407
bst_node_t LeftChild(bst_node_t nidx) const
定义: tree_model.h:625
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
构造函数,用形状初始化树模型。
定义: tree_model.h:269
bst_node_t RightChild(bst_node_t nidx) const
定义: tree_model.h:631
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
获取类别的位存储。
定义: tree_model.h:552
bool IsLeftChild(bst_node_t nidx) const
定义: tree_model.h:637
CategoricalSplitMatrix GetCategoriesMatrix() const
定义: tree_model.h:578
RTreeNodeStat & Stat(int nid)
获取给定nid的节点统计信息
定义: tree_model.h:293
bst_float SplitCondT
定义: tree_model.h:104
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
用类别扩展叶子节点。
bool Equal(const RegTree &b) const
从用户角度比较两棵树是否相等。相等性仅比较未删除的...
std::vector< FeatureType > const & GetSplitTypes() const
获取所有节点的分割类型。
定义: tree_model.h:543
void CollapseToLeaf(int rid, bst_float value)
将非叶节点折叠为叶节点,删除其子节点
定义: tree_model.h:245
bst_node_t NumValidNodes() const noexcept
获取此树中有效节点的总数。
定义: tree_model.h:422
void ChangeToLeaf(int rid, bst_float value)
将非叶节点更改为叶节点,删除其子节点
定义: tree_model.h:233
const std::vector< RTreeNodeStat > & GetStats() const
获取stats的常量引用
定义: tree_model.h:290
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
设置多目标树的叶子权重。
定义: tree_model.h:453
const std::vector< Node > & GetNodes() const
获取节点的常量引用
定义: tree_model.h:287
void LoadModel(Json const &in) override
从JSON对象加载模型
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
以请求的格式将模型转储为文本字符串
FeatureType NodeSplitType(bst_node_t nidx) const
获取节点的分割类型。
定义: tree_model.h:539
bst_feature_t NumFeatures() const noexcept
获取特征数量。
定义: tree_model.h:414
common::Span< uint32_t const > GetSplitCategories() const
定义: tree_model.h:546
bool HasCategoricalSplit() const
这棵树是否有分类分割。
定义: tree_model.h:395
std::int32_t GetDepth(bst_node_t nid) const
获取当前深度
定义: tree_model.h:439
static constexpr bst_node_t kRoot
定义: tree_model.h:107
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
定义: tree_model.h:559
float SplitCond(bst_node_t nidx) const
定义: tree_model.h:592
int MaxDepth()
获取最大深度
定义: tree_model.h:470
bst_node_t Size() const
定义: tree_model.h:645
span 类的实现,基于 ISO C++20 的 span<T>,接口应与之相同。
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
定义: span.h:597
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
具有静态类型和维度的张量视图。它实现了索引和切片。
定义: linalg.h:277
xgboost 的输入数据结构。
特征映射数据结构,用于帮助可视化和模型导出。
线性代数相关工具。
定义 XGBoost 中不同组件的抽象接口。
集成目标、gbm 和评估的学习器接口。这是面向用户的 XGB...
Definition: base.h:97
std::int32_t bst_node_t
树节点索引的类型。
定义: base.h:119
FeatureType
定义: data.h:41
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义: base.h:127
std::uint32_t bst_feature_t
数据列(特征)索引的类型。
Definition: base.h:107
float bst_float
浮点类型,用于存储统计信息
Definition: base.h:103
StringView MTNotImplemented()
定义: tree_model.h:739
定义: model.h:14
回归树中使用的节点统计信息
定义: tree_model.h:50
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
定义: tree_model.h:61
bst_float loss_chg
当前分割导致的损失变化
定义: tree_model.h:52
int leaf_child_cnt
目前已知的叶子节点子节点数量
定义: tree_model.h:58
bst_float sum_hess
海森值之和,用于衡量数据覆盖率
定义: tree_model.h:54
bool operator==(const RTreeNodeStat &b) const
定义: tree_model.h:63
bst_float base_weight
当前节点的权重
定义: tree_model.h:56
std::size_t size
定义: tree_model.h:571
std::size_t beg
定义: tree_model.h:570
用于分类分割的CSR类矩阵。
定义: tree_model.h:568
common::Span< uint32_t const > categories
定义: tree_model.h:574
common::Span< Segment const > node_ptr
定义: tree_model.h:575
common::Span< FeatureType const > split_type
定义: tree_model.h:573
可由RegTree接收并可由稀疏特征向量构建的密集特征向量。
定义: tree_model.h:476
void HasMissing(bool has_missing)
定义: tree_model.h:511
void Drop()
丢弃已填入的值(将所有条目重置为缺失值),必须在填充之后调用。
定义: tree_model.h:724
bool HasMissing() const
定义: tree_model.h:736
bool IsMissing(size_t i) const
检查第i个条目是否缺失
定义: tree_model.h:734
size_t Size() const
返回特征向量的大小
定义: tree_model.h:726
void Init(size_t size)
按给定大小初始化向量,所有条目均设为缺失值。
定义: tree_model.h:707
common::Span< float > Data()
定义: tree_model.h:513
void Fill(SparsePage::Inst const &inst)
用稀疏向量填充向量
定义: tree_model.h:713
bst_float GetFvalue(size_t i) const
获取第i个值
定义: tree_model.h:730
Definition: string_view.h:16
树的元参数
定义: tree_model.h:30
bst_node_t num_deleted
已删除节点的数量。
定义: tree_model.h:34
bst_feature_t num_feature
用于树构建的特征数量。
定义: tree_model.h:36
bool operator==(const TreeParam &b) const
定义: tree_model.h:40
bst_node_t num_nodes
节点数量。
定义: tree_model.h:32
void ToJson(Json *p_out) const
void FromJson(Json const &in)
bst_target_t size_leaf_vector
叶向量大小。由向量叶使用。
定义: tree_model.h:38