7 #ifndef XGBOOST_TREE_MODEL_H_ 
 8 #define XGBOOST_TREE_MODEL_H_ 
 14 #include <xgboost/logging.h> 
 74  std::unique_ptr<T> ptr_{
nullptr};
 
 81  ptr_ = std::make_unique<T>(*that);
 
 84  T* 
get() const noexcept { 
return ptr_.get(); } 
 
 92  explicit operator bool()
 const { 
return static_cast<bool>(ptr_); }
 
 94  void reset(T* ptr) { ptr_.reset(ptr); } 
 
 114  static_assert(
sizeof(
Node) == 4 * 
sizeof(
int) + 
sizeof(Info), 
"Node: 64 位对齐");
 
 116  Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, 
float split_cond,
 
 118  : parent_{parent}, cleft_{cleft}, cright_{cright} {
 
 120  this->
SetSplit(split_ind, split_cond, default_left);
 
 133  static_assert(!std::is_signed_v<bst_feature_t>);
 
 134  return sindex_ & ((1U << 31) - 1U);
 
 173  bool default_left = 
false) {
 
 174  if (default_left) split_index |= (1U << 31);
 
 175  this->sindex_ = split_index;
 
 176  (this->info_).split_cond = split_cond;
 
 185  (this->info_).leaf_value = value;
 
 187  this->cright_ = right;
 
 199  if (is_left_child) pidx |= (1U << 31);
 
 200  this->parent_ = pidx;
 
 203  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
 
 204  cright_ == b.cright_ && sindex_ == b.sindex_ &&
 
 205  info_.leaf_value == b.info_.leaf_value;
 
 236  this->DeleteNode(nodes_[rid].
LeftChild());
 
 238  nodes_[rid].SetLeaf(value);
 
 246  if (nodes_[rid].
IsLeaf()) 
return;
 
 260  split_categories_segments_.resize(param_.
num_nodes);
 
 261  for (
int i = 0; i < param_.
num_nodes; i++) {
 
 262  nodes_[i].SetLeaf(0.0f);
 
 287  [[nodiscard]] 
const std::vector<Node>& 
GetNodes()
 const { 
return nodes_; }
 
 290  [[nodiscard]] 
const std::vector<RTreeNodeStat>& 
GetStats()
 const { 
return stats_; }
 
 305  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
 
 306  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
 
 313  template <
typename Func> 
void WalkTree(Func func)
 const {
 
 314  std::stack<bst_node_t> nodes;
 
 317  while (!nodes.empty()) {
 
 318  auto nidx = nodes.top();
 
 323  auto left = 
self.LeftChild(nidx);
 
 324  auto right = 
self.RightChild(nidx);
 
 359  bool default_left, 
bst_float base_weight,
 
 361  bst_float loss_change, 
float sum_hess, 
float left_sum,
 
 391  float left_sum, 
float right_sum);
 
 399  [[nodiscard]] 
bool IsMultiTarget()
 const { 
return static_cast<bool>(p_mt_tree_); }
 
 409  return p_mt_tree_.get();
 
 441  return this->p_mt_tree_->Depth(nid);
 
 444  while (!nodes_[nid].
IsRoot()) {
 
 446  nid = nodes_[nid].Parent();
 
 455  return this->p_mt_tree_->SetLeaf(nidx, weight);
 
 463  if (nodes_[nid].
IsLeaf()) 
return 0;
 
 481  void Init(
size_t size);
 
 497  [[nodiscard]] 
size_t Size() 
const;
 
 509  [[nodiscard]] 
bool IsMissing(
size_t i) 
const;
 
 511  void HasMissing(
bool has_missing) { this->has_missing_ = has_missing; }
 
 521  std::vector<float> data_;
 
 533  std::string format) 
const;
 
 547  return split_categories_;
 
 555  auto segment = node_ptr[nidx];
 
 556  auto node_cats = categories.
subspan(segment.beg, segment.size);
 
 588  return this->p_mt_tree_->SplitIndex(nidx);
 
 590  return (*
this)[nidx].SplitIndex();
 
 594  return this->p_mt_tree_->SplitCond(nidx);
 
 596  return (*
this)[nidx].SplitCond();
 
 600  return this->p_mt_tree_->DefaultLeft(nidx);
 
 602  return (*
this)[nidx].DefaultLeft();
 
 609  return nidx == 
kRoot;
 
 611  return (*
this)[nidx].IsRoot();
 
 615  return this->p_mt_tree_->IsLeaf(nidx);
 
 617  return (*
this)[nidx].IsLeaf();
 
 621  return this->p_mt_tree_->Parent(nidx);
 
 623  return (*
this)[nidx].Parent();
 
 627  return this->p_mt_tree_->LeftChild(nidx);
 
 629  return (*
this)[nidx].LeftChild();
 
 633  return this->p_mt_tree_->RightChild(nidx);
 
 635  return (*
this)[nidx].RightChild();
 
 639  CHECK_NE(nidx, 
kRoot);
 
 640  auto p = this->p_mt_tree_->Parent(nidx);
 
 641  return nidx == this->p_mt_tree_->LeftChild(p);
 
 643  return (*
this)[nidx].IsLeftChild();
 
 647  return this->p_mt_tree_->Size();
 
 649  return this->nodes_.size();
 
 653  template <
bool typed>
 
 654  void LoadCategoricalSplit(
Json const& in);
 
 655  void SaveCategoricalSplit(
Json* p_out) 
const;
 
 659  std::vector<Node> nodes_;
 
 661  std::vector<int> deleted_nodes_;
 
 663  std::vector<RTreeNodeStat> stats_;
 
 664  std::vector<FeatureType> split_types_;
 
 667  std::vector<uint32_t> split_categories_;
 
 669  std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
 
 676  int nid = deleted_nodes_.back();
 
 677  deleted_nodes_.pop_back();
 
 683  CHECK_LT(param_.
num_nodes, std::numeric_limits<int>::max())
 
 684  << 
"number of nodes in the tree exceed 2^31";
 
 688  split_categories_segments_.resize(param_.
num_nodes);
 
 692  void DeleteNode(
int nid) {
 
 694  auto pid = (*this)[nid].Parent();
 
 701  deleted_nodes_.push_back(nid);
 
 702  nodes_[nid].MarkDelete();
 
 709  std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
 
 714  auto p_data = inst.
data();
 
 715  auto p_out = data_.data();
 
 717  for (std::size_t i = 0, n = inst.
size(); i < n; ++i) {
 
 718  auto const& entry = p_data[i];
 
 719  p_out[entry.index] = entry.fvalue;
 
 721  has_missing_ = data_.size() != inst.
size();
 
 740  return " support for multi-target tree is not yet implemented.";
 
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
用于定义包含唯一指针的可复制数据结构的辅助类。
定义: tree_model.h:73
T const * operator->() const noexcept
定义: tree_model.h:90
T * get() const noexcept
定义: tree_model.h:84
bool operator!() const
定义: tree_model.h:93
CopyUniquePtr(CopyUniquePtr const &that)
定义: tree_model.h:78
T * operator->() noexcept
定义: tree_model.h:87
T & operator*()
定义: tree_model.h:86
T const & operator*() const
定义: tree_model.h:89
void reset(T *ptr)
定义: tree_model.h:94
特征映射数据结构,用于辅助文本模型转储。TODO(tqchen) 考虑使其更轻量级...
定义: feature_map.h:22
表示JSON格式的数据结构。
Definition: json.h:392
多目标模型的树结构。
定义: multi_target_tree_model.h:69
static constexpr bst_node_t InvalidNodeId()
定义: multi_target_tree_model.h:71
XGBOOST_DEVICE int Parent() const
获取节点的父节点
定义: tree_model.h:145
XGBOOST_DEVICE void MarkDelete()
标记此节点已删除
定义: tree_model.h:190
XGBOOST_DEVICE bool IsRoot() const
当前节点是否为根节点
定义: tree_model.h:151
XGBOOST_DEVICE int RightChild() const
右子节点索引
定义: tree_model.h:126
XGBOOST_DEVICE float LeafValue() const
定义: tree_model.h:141
XGBOOST_DEVICE Node()
定义: tree_model.h:112
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
定义: tree_model.h:198
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
设置节点的叶子值
定义: tree_model.h:184
XGBOOST_DEVICE bool IsLeftChild() const
当前节点是否为左子节点
定义: tree_model.h:147
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
设置当前节点的分割条件
定义: tree_model.h:172
XGBOOST_DEVICE void SetLeftChild(int nid)
设置左子节点
定义: tree_model.h:156
XGBOOST_DEVICE bst_feature_t SplitIndex() const
分割条件的特征索引
定义: tree_model.h:132
XGBOOST_DEVICE bool IsDeleted() const
此节点是否已删除
定义: tree_model.h:149
XGBOOST_DEVICE bool IsLeaf() const
当前节点是否为叶子节点
定义: tree_model.h:139
bool operator==(const Node &b) const
定义: tree_model.h:202
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
定义: tree_model.h:116
XGBOOST_DEVICE void Reuse()
重用此已删除节点。
定义: tree_model.h:194
XGBOOST_DEVICE void SetRightChild(int nid)
设置右子节点
定义: tree_model.h:163
XGBOOST_DEVICE bool DefaultLeft() const
当特征未知时,是否走向左子节点
定义: tree_model.h:137
XGBOOST_DEVICE int LeftChild() const
左子节点索引
定义: tree_model.h:124
XGBOOST_DEVICE int DefaultChild() const
当特征缺失时,默认子节点的索引
定义: tree_model.h:128
XGBOOST_DEVICE SplitCondT SplitCond() const
定义: tree_model.h:143
将回归树定义为最常见的树模型。
定义: tree_model.h:102
int MaxDepth(int nid) const
获取最大深度
定义: tree_model.h:462
void SaveModel(Json *out) const override
将模型配置保存到JSON对象
bst_target_t NumTargets() const
叶子权重的尺寸。
定义: tree_model.h:403
void WalkTree(Func func) const
定义: tree_model.h:313
bool IsLeaf(bst_node_t nidx) const
定义: tree_model.h:613
bool operator==(const RegTree &b) const
定义: tree_model.h:304
const RTreeNodeStat & Stat(int nid) const
获取给定nid的节点统计信息
定义: tree_model.h:297
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
将叶子节点扩展为多目标树的两个额外的叶子节点。
bst_node_t Parent(bst_node_t nidx) const
定义: tree_model.h:619
bst_node_t NumNodes() const noexcept
获取此树中包括已删除节点在内的总节点数。
定义: tree_model.h:418
const Node & operator[](int nid) const
获取给定nid的节点
定义: tree_model.h:282
bst_node_t DefaultChild(bst_node_t nidx) const
定义: tree_model.h:604
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
将叶子节点扩展为两个额外的叶子节点。
Node & operator[](int nid)
获取给定nid的节点
定义: tree_model.h:278
RegTree()
定义: tree_model.h:256
static constexpr bst_node_t kInvalidNodeId
定义: tree_model.h:105
bst_feature_t SplitIndex(bst_node_t nidx) const
定义: tree_model.h:586
bool IsRoot(bst_node_t nidx) const
定义: tree_model.h:607
static constexpr uint32_t kDeletedNodeMarker
定义: tree_model.h:106
bool IsMultiTarget() const
这是否是多目标树。
定义: tree_model.h:399
bst_node_t NumExtraNodes() const noexcept
除根节点外的额外节点数
定义: tree_model.h:428
bool DefaultLeft(bst_node_t nidx) const
定义: tree_model.h:598
auto GetMultiTargetTree() const
获取多目标树的底层实现。
定义: tree_model.h:407
bst_node_t LeftChild(bst_node_t nidx) const
定义: tree_model.h:625
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
构造函数,用形状初始化树模型。
定义: tree_model.h:269
bst_node_t RightChild(bst_node_t nidx) const
定义: tree_model.h:631
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
获取类别的位存储。
定义: tree_model.h:552
bool IsLeftChild(bst_node_t nidx) const
定义: tree_model.h:637
CategoricalSplitMatrix GetCategoriesMatrix() const
定义: tree_model.h:578
RTreeNodeStat & Stat(int nid)
获取给定nid的节点统计信息
定义: tree_model.h:293
bst_float SplitCondT
定义: tree_model.h:104
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
用类别扩展叶子节点。
bool Equal(const RegTree &b) const
从用户角度比较两棵树是否相等。相等性仅比较未删除的...
std::vector< FeatureType > const & GetSplitTypes() const
获取所有节点的分割类型。
定义: tree_model.h:543
void CollapseToLeaf(int rid, bst_float value)
将非叶节点折叠为叶节点,删除其子节点
定义: tree_model.h:245
bst_node_t NumValidNodes() const noexcept
获取此树中有效节点的总数。
定义: tree_model.h:422
void ChangeToLeaf(int rid, bst_float value)
将非叶节点更改为叶节点,删除其子节点
定义: tree_model.h:233
const std::vector< RTreeNodeStat > & GetStats() const
获取stats的常量引用
定义: tree_model.h:290
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
设置多目标树的叶子权重。
定义: tree_model.h:453
const std::vector< Node > & GetNodes() const
获取节点的常量引用
定义: tree_model.h:287
void LoadModel(Json const &in) override
从JSON对象加载模型
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
以请求的格式将模型转储为文本字符串
FeatureType NodeSplitType(bst_node_t nidx) const
获取节点的分割类型。
定义: tree_model.h:539
bst_feature_t NumFeatures() const noexcept
获取特征数量。
定义: tree_model.h:414
common::Span< uint32_t const > GetSplitCategories() const
定义: tree_model.h:546
bool HasCategoricalSplit() const
这棵树是否有分类分割。
定义: tree_model.h:395
std::int32_t GetDepth(bst_node_t nid) const
获取当前深度
定义: tree_model.h:439
static constexpr bst_node_t kRoot
定义: tree_model.h:107
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
定义: tree_model.h:559
float SplitCond(bst_node_t nidx) const
定义: tree_model.h:592
int MaxDepth()
获取最大深度
定义: tree_model.h:470
bst_node_t Size() const
定义: tree_model.h:645
span类实现,基于ISO++20 span<T>。接口应相同。
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
定义: span.h:597
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
具有静态类型和维度的张量视图。它实现了索引和切片。
定义: linalg.h:277
集成目标、gbm和评估的学习器接口。这是用户面临的XGB...
Definition: base.h:97
std::int32_t bst_node_t
树节点索引的类型。
定义: base.h:119
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义: base.h:127
std::uint32_t bst_feature_t
数据列(特征)索引的类型。
Definition: base.h:107
float bst_float
浮点类型,用于存储统计信息
Definition: base.h:103
StringView MTNotImplemented()
定义: tree_model.h:739
回归树中使用的节点统计信息
定义: tree_model.h:50
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
定义: tree_model.h:61
bst_float loss_chg
当前分割导致的损失变化
定义: tree_model.h:52
int leaf_child_cnt
目前已知的叶子节点子节点数量
定义: tree_model.h:58
bst_float sum_hess
海森值之和,用于衡量数据覆盖率
定义: tree_model.h:54
bool operator==(const RTreeNodeStat &b) const
定义: tree_model.h:63
bst_float base_weight
当前节点的权重
定义: tree_model.h:56
std::size_t size
定义: tree_model.h:571
std::size_t beg
定义: tree_model.h:570
用于分类分割的CSR类矩阵。
定义: tree_model.h:568
common::Span< uint32_t const > categories
定义: tree_model.h:574
common::Span< Segment const > node_ptr
定义: tree_model.h:575
common::Span< FeatureType const > split_type
定义: tree_model.h:573
可由RegTree接收并可由稀疏特征向量构建的密集特征向量。
定义: tree_model.h:476
void HasMissing(bool has_missing)
定义: tree_model.h:511
void Drop()
填充后丢弃跟踪,必须在填充后调用。
定义: tree_model.h:724
bool HasMissing() const
定义: tree_model.h:736
bool IsMissing(size_t i) const
检查第i个条目是否缺失
定义: tree_model.h:734
size_t Size() const
返回特征向量的大小
定义: tree_model.h:726
void Init(size_t size)
用尺寸向量初始化向量
定义: tree_model.h:707
common::Span< float > Data()
定义: tree_model.h:513
void Fill(SparsePage::Inst const &inst)
用稀疏向量填充向量
定义: tree_model.h:713
bst_float GetFvalue(size_t i) const
获取第i个值
定义: tree_model.h:730
Definition: string_view.h:16
bst_node_t num_deleted
已删除节点的数量。
定义: tree_model.h:34
bst_feature_t num_feature
用于树构建的特征数量。
定义: tree_model.h:36
bool operator==(const TreeParam &b) const
定义: tree_model.h:40
bst_node_t num_nodes
节点数量。
定义: tree_model.h:32
void ToJson(Json *p_out) const
void FromJson(Json const &in)
bst_target_t size_leaf_vector
叶向量大小。由向量叶使用。
定义: tree_model.h:38