7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
14 #include <xgboost/logging.h>
74 std::unique_ptr<T> ptr_{
nullptr};
81 ptr_ = std::make_unique<T>(*that);
84 T*
get() const noexcept {
return ptr_.get(); }
92 explicit operator bool()
const {
return static_cast<bool>(ptr_); }
94 void reset(T* ptr) { ptr_.reset(ptr); }
114 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
"Node: 64 位对齐");
116 Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind,
float split_cond,
118 : parent_{parent}, cleft_{cleft}, cright_{cright} {
120 this->
SetSplit(split_ind, split_cond, default_left);
133 static_assert(!std::is_signed_v<bst_feature_t>);
134 return sindex_ & ((1U << 31) - 1U);
173 bool default_left =
false) {
174 if (default_left) split_index |= (1U << 31);
175 this->sindex_ = split_index;
176 (this->info_).split_cond = split_cond;
185 (this->info_).leaf_value = value;
187 this->cright_ = right;
199 if (is_left_child) pidx |= (1U << 31);
200 this->parent_ = pidx;
203 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
204 cright_ == b.cright_ && sindex_ == b.sindex_ &&
205 info_.leaf_value == b.info_.leaf_value;
236 this->DeleteNode(nodes_[rid].
LeftChild());
238 nodes_[rid].SetLeaf(value);
246 if (nodes_[rid].
IsLeaf())
return;
260 split_categories_segments_.resize(param_.
num_nodes);
261 for (
int i = 0; i < param_.
num_nodes; i++) {
262 nodes_[i].SetLeaf(0.0f);
287 [[nodiscard]]
const std::vector<Node>&
GetNodes()
const {
return nodes_; }
290 [[nodiscard]]
const std::vector<RTreeNodeStat>&
GetStats()
const {
return stats_; }
305 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
306 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
313 template <
typename Func>
void WalkTree(Func func)
const {
314 std::stack<bst_node_t> nodes;
317 while (!nodes.empty()) {
318 auto nidx = nodes.top();
323 auto left =
self.LeftChild(nidx);
324 auto right =
self.RightChild(nidx);
359 bool default_left,
bst_float base_weight,
361 bst_float loss_change,
float sum_hess,
float left_sum,
391 float left_sum,
float right_sum);
399 [[nodiscard]]
bool IsMultiTarget()
const {
return static_cast<bool>(p_mt_tree_); }
409 return p_mt_tree_.get();
441 return this->p_mt_tree_->Depth(nid);
444 while (!nodes_[nid].
IsRoot()) {
446 nid = nodes_[nid].Parent();
455 return this->p_mt_tree_->SetLeaf(nidx, weight);
463 if (nodes_[nid].
IsLeaf())
return 0;
481 void Init(
size_t size);
497 [[nodiscard]]
size_t Size()
const;
509 [[nodiscard]]
bool IsMissing(
size_t i)
const;
511 void HasMissing(
bool has_missing) { this->has_missing_ = has_missing; }
521 std::vector<float> data_;
533 std::string format)
const;
547 return split_categories_;
555 auto segment = node_ptr[nidx];
556 auto node_cats = categories.
subspan(segment.beg, segment.size);
588 return this->p_mt_tree_->SplitIndex(nidx);
590 return (*
this)[nidx].SplitIndex();
594 return this->p_mt_tree_->SplitCond(nidx);
596 return (*
this)[nidx].SplitCond();
600 return this->p_mt_tree_->DefaultLeft(nidx);
602 return (*
this)[nidx].DefaultLeft();
609 return nidx ==
kRoot;
611 return (*
this)[nidx].IsRoot();
615 return this->p_mt_tree_->IsLeaf(nidx);
617 return (*
this)[nidx].IsLeaf();
621 return this->p_mt_tree_->Parent(nidx);
623 return (*
this)[nidx].Parent();
627 return this->p_mt_tree_->LeftChild(nidx);
629 return (*
this)[nidx].LeftChild();
633 return this->p_mt_tree_->RightChild(nidx);
635 return (*
this)[nidx].RightChild();
639 CHECK_NE(nidx,
kRoot);
640 auto p = this->p_mt_tree_->Parent(nidx);
641 return nidx == this->p_mt_tree_->LeftChild(p);
643 return (*
this)[nidx].IsLeftChild();
647 return this->p_mt_tree_->Size();
649 return this->nodes_.size();
653 template <
bool typed>
654 void LoadCategoricalSplit(
Json const& in);
655 void SaveCategoricalSplit(
Json* p_out)
const;
659 std::vector<Node> nodes_;
661 std::vector<int> deleted_nodes_;
663 std::vector<RTreeNodeStat> stats_;
664 std::vector<FeatureType> split_types_;
667 std::vector<uint32_t> split_categories_;
669 std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
676 int nid = deleted_nodes_.back();
677 deleted_nodes_.pop_back();
683 CHECK_LT(param_.
num_nodes, std::numeric_limits<int>::max())
684 <<
"number of nodes in the tree exceed 2^31";
688 split_categories_segments_.resize(param_.
num_nodes);
692 void DeleteNode(
int nid) {
694 auto pid = (*this)[nid].Parent();
701 deleted_nodes_.push_back(nid);
702 nodes_[nid].MarkDelete();
709 std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
714 auto p_data = inst.
data();
715 auto p_out = data_.data();
717 for (std::size_t i = 0, n = inst.
size(); i < n; ++i) {
718 auto const& entry = p_data[i];
719 p_out[entry.index] = entry.fvalue;
721 has_missing_ = data_.size() != inst.
size();
740 return " support for multi-target tree is not yet implemented.";
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
用于定义包含唯一指针的可复制数据结构的辅助类。
定义: tree_model.h:73
T const * operator->() const noexcept
定义: tree_model.h:90
T * get() const noexcept
定义: tree_model.h:84
bool operator!() const
定义: tree_model.h:93
CopyUniquePtr(CopyUniquePtr const &that)
定义: tree_model.h:78
T * operator->() noexcept
定义: tree_model.h:87
T & operator*()
定义: tree_model.h:86
T const & operator*() const
定义: tree_model.h:89
void reset(T *ptr)
定义: tree_model.h:94
特征映射数据结构,用于辅助文本模型转储。TODO(tqchen) 考虑使其更轻量级...
定义: feature_map.h:22
表示JSON格式的数据结构。
Definition: json.h:392
多目标模型的树结构。
定义: multi_target_tree_model.h:69
static constexpr bst_node_t InvalidNodeId()
定义: multi_target_tree_model.h:71
XGBOOST_DEVICE int Parent() const
获取节点的父节点
定义: tree_model.h:145
XGBOOST_DEVICE void MarkDelete()
标记此节点已删除
定义: tree_model.h:190
XGBOOST_DEVICE bool IsRoot() const
当前节点是否为根节点
定义: tree_model.h:151
XGBOOST_DEVICE int RightChild() const
右子节点索引
定义: tree_model.h:126
XGBOOST_DEVICE float LeafValue() const
定义: tree_model.h:141
XGBOOST_DEVICE Node()
定义: tree_model.h:112
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
定义: tree_model.h:198
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
设置节点的叶子值
定义: tree_model.h:184
XGBOOST_DEVICE bool IsLeftChild() const
当前节点是否为左子节点
定义: tree_model.h:147
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
设置当前节点的分割条件
定义: tree_model.h:172
XGBOOST_DEVICE void SetLeftChild(int nid)
设置左子节点
定义: tree_model.h:156
XGBOOST_DEVICE bst_feature_t SplitIndex() const
分割条件的特征索引
定义: tree_model.h:132
XGBOOST_DEVICE bool IsDeleted() const
此节点是否已删除
定义: tree_model.h:149
XGBOOST_DEVICE bool IsLeaf() const
当前节点是否为叶子节点
定义: tree_model.h:139
bool operator==(const Node &b) const
定义: tree_model.h:202
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
定义: tree_model.h:116
XGBOOST_DEVICE void Reuse()
重用此已删除节点。
定义: tree_model.h:194
XGBOOST_DEVICE void SetRightChild(int nid)
设置右子节点
定义: tree_model.h:163
XGBOOST_DEVICE bool DefaultLeft() const
当特征未知时,是否走向左子节点
定义: tree_model.h:137
XGBOOST_DEVICE int LeftChild() const
左子节点索引
定义: tree_model.h:124
XGBOOST_DEVICE int DefaultChild() const
当特征缺失时,默认子节点的索引
定义: tree_model.h:128
XGBOOST_DEVICE SplitCondT SplitCond() const
定义: tree_model.h:143
将回归树定义为最常见的树模型。
定义: tree_model.h:102
int MaxDepth(int nid) const
获取最大深度
定义: tree_model.h:462
void SaveModel(Json *out) const override
将模型配置保存到JSON对象
bst_target_t NumTargets() const
叶子权重的尺寸。
定义: tree_model.h:403
void WalkTree(Func func) const
定义: tree_model.h:313
bool IsLeaf(bst_node_t nidx) const
定义: tree_model.h:613
bool operator==(const RegTree &b) const
定义: tree_model.h:304
const RTreeNodeStat & Stat(int nid) const
获取给定nid的节点统计信息
定义: tree_model.h:297
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
将叶子节点扩展为多目标树的两个额外的叶子节点。
bst_node_t Parent(bst_node_t nidx) const
定义: tree_model.h:619
bst_node_t NumNodes() const noexcept
获取此树中包括已删除节点在内的总节点数。
定义: tree_model.h:418
const Node & operator[](int nid) const
获取给定nid的节点
定义: tree_model.h:282
bst_node_t DefaultChild(bst_node_t nidx) const
定义: tree_model.h:604
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
将叶子节点扩展为两个额外的叶子节点。
Node & operator[](int nid)
获取给定nid的节点
定义: tree_model.h:278
RegTree()
定义: tree_model.h:256
static constexpr bst_node_t kInvalidNodeId
定义: tree_model.h:105
bst_feature_t SplitIndex(bst_node_t nidx) const
定义: tree_model.h:586
bool IsRoot(bst_node_t nidx) const
定义: tree_model.h:607
static constexpr uint32_t kDeletedNodeMarker
定义: tree_model.h:106
bool IsMultiTarget() const
这是否是多目标树。
定义: tree_model.h:399
bst_node_t NumExtraNodes() const noexcept
除根节点外的额外节点数
定义: tree_model.h:428
bool DefaultLeft(bst_node_t nidx) const
定义: tree_model.h:598
auto GetMultiTargetTree() const
获取多目标树的底层实现。
定义: tree_model.h:407
bst_node_t LeftChild(bst_node_t nidx) const
定义: tree_model.h:625
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
构造函数,用形状初始化树模型。
定义: tree_model.h:269
bst_node_t RightChild(bst_node_t nidx) const
定义: tree_model.h:631
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
获取类别的位存储。
定义: tree_model.h:552
bool IsLeftChild(bst_node_t nidx) const
定义: tree_model.h:637
CategoricalSplitMatrix GetCategoriesMatrix() const
定义: tree_model.h:578
RTreeNodeStat & Stat(int nid)
获取给定nid的节点统计信息
定义: tree_model.h:293
bst_float SplitCondT
定义: tree_model.h:104
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
用类别扩展叶子节点。
bool Equal(const RegTree &b) const
从用户角度比较两棵树是否相等。相等性仅比较未删除的...
std::vector< FeatureType > const & GetSplitTypes() const
获取所有节点的分割类型。
定义: tree_model.h:543
void CollapseToLeaf(int rid, bst_float value)
将非叶节点折叠为叶节点,删除其子节点
定义: tree_model.h:245
bst_node_t NumValidNodes() const noexcept
获取此树中有效节点的总数。
定义: tree_model.h:422
void ChangeToLeaf(int rid, bst_float value)
将非叶节点更改为叶节点,删除其子节点
定义: tree_model.h:233
const std::vector< RTreeNodeStat > & GetStats() const
获取stats的常量引用
定义: tree_model.h:290
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
设置多目标树的叶子权重。
定义: tree_model.h:453
const std::vector< Node > & GetNodes() const
获取节点的常量引用
定义: tree_model.h:287
void LoadModel(Json const &in) override
从JSON对象加载模型
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
以请求的格式将模型转储为文本字符串
FeatureType NodeSplitType(bst_node_t nidx) const
获取节点的分割类型。
定义: tree_model.h:539
bst_feature_t NumFeatures() const noexcept
获取特征数量。
定义: tree_model.h:414
common::Span< uint32_t const > GetSplitCategories() const
定义: tree_model.h:546
bool HasCategoricalSplit() const
这棵树是否有分类分割。
定义: tree_model.h:395
std::int32_t GetDepth(bst_node_t nid) const
获取当前深度
定义: tree_model.h:439
static constexpr bst_node_t kRoot
定义: tree_model.h:107
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
定义: tree_model.h:559
float SplitCond(bst_node_t nidx) const
定义: tree_model.h:592
int MaxDepth()
获取最大深度
定义: tree_model.h:470
bst_node_t Size() const
定义: tree_model.h:645
span类实现,基于ISO++20 span<T>。接口应相同。
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
定义: span.h:597
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
具有静态类型和维度的张量视图。它实现了索引和切片。
定义: linalg.h:277
集成目标、gbm和评估的学习器接口。这是用户面临的XGB...
Definition: base.h:97
std::int32_t bst_node_t
树节点索引的类型。
定义: base.h:119
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义: base.h:127
std::uint32_t bst_feature_t
数据列(特征)索引的类型。
Definition: base.h:107
float bst_float
浮点类型,用于存储统计信息
Definition: base.h:103
StringView MTNotImplemented()
定义: tree_model.h:739
回归树中使用的节点统计信息
定义: tree_model.h:50
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
定义: tree_model.h:61
bst_float loss_chg
当前分割导致的损失变化
定义: tree_model.h:52
int leaf_child_cnt
目前已知的叶子节点子节点数量
定义: tree_model.h:58
bst_float sum_hess
海森值之和,用于衡量数据覆盖率
定义: tree_model.h:54
bool operator==(const RTreeNodeStat &b) const
定义: tree_model.h:63
bst_float base_weight
当前节点的权重
定义: tree_model.h:56
std::size_t size
定义: tree_model.h:571
std::size_t beg
定义: tree_model.h:570
用于分类分割的CSR类矩阵。
定义: tree_model.h:568
common::Span< uint32_t const > categories
定义: tree_model.h:574
common::Span< Segment const > node_ptr
定义: tree_model.h:575
common::Span< FeatureType const > split_type
定义: tree_model.h:573
可由RegTree接收并可由稀疏特征向量构建的密集特征向量。
定义: tree_model.h:476
void HasMissing(bool has_missing)
定义: tree_model.h:511
void Drop()
填充后丢弃跟踪,必须在填充后调用。
定义: tree_model.h:724
bool HasMissing() const
定义: tree_model.h:736
bool IsMissing(size_t i) const
检查第i个条目是否缺失
定义: tree_model.h:734
size_t Size() const
返回特征向量的大小
定义: tree_model.h:726
void Init(size_t size)
用尺寸向量初始化向量
定义: tree_model.h:707
common::Span< float > Data()
定义: tree_model.h:513
void Fill(SparsePage::Inst const &inst)
用稀疏向量填充向量
定义: tree_model.h:713
bst_float GetFvalue(size_t i) const
获取第i个值
定义: tree_model.h:730
Definition: string_view.h:16
bst_node_t num_deleted
已删除节点的数量。
定义: tree_model.h:34
bst_feature_t num_feature
用于树构建的特征数量。
定义: tree_model.h:36
bool operator==(const TreeParam &b) const
定义: tree_model.h:40
bst_node_t num_nodes
节点数量。
定义: tree_model.h:32
void ToJson(Json *p_out) const
void FromJson(Json const &in)
bst_target_t size_leaf_vector
叶向量大小。由向量叶使用。
定义: tree_model.h:38