xgboost
multi_target_tree_model.h
前往此文件文档。
1 
6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
8 
9 #include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
10 #include <xgboost/context.h> // for Context
11 #include <xgboost/host_device_vector.h> // for HostDeviceVector
12 #include <xgboost/linalg.h> // for VectorView, MatrixView
13 #include <xgboost/model.h> // for Model
14 #include <xgboost/span.h> // for Span
15 
16 #include <cstddef> // for size_t
17 #include <cstdint> // for uint8_t
18 #include <mutex> // for mutex
19 #include <vector> // for vector
20 
21 namespace xgboost {
22 struct TreeParam;
// Sentinel node index (-1). IsLeaf() compares a node's left-child entry against this value.
27  static bst_node_t constexpr InvalidNodeId() { return -1; }
28 
// Per-node child indices, read as left[nidx] / right[nidx] by LeftChild() and RightChild().
// Raw pointers to const data; the accessors in this struct only read through them.
29  bst_node_t const* left;
30  bst_node_t const* right;
32 
// Per-node flag, converted to bool by DefaultLeft().
34  std::uint8_t const* default_left;
// Per-node split condition, returned as-is by SplitCond().
35  float const* split_conds;
36 
37  // The number of nodes
38  std::size_t n{0};
39 
41 
// Returns true when the left-child entry of node `nidx` equals InvalidNodeId().
// `nidx` is used as a raw array index; no bounds check is performed here.
42  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf(bst_node_t nidx) const {
43  return left[nidx] == InvalidNodeId();
44  }
45 
// Returns the left-child index stored for node `nidx` (unchecked index into `left`).
46  [[nodiscard]] XGBOOST_DEVICE bst_node_t LeftChild(bst_node_t nidx) const { return left[nidx]; }
// Returns the right-child index stored for node `nidx` (unchecked index into `right`).
47  [[nodiscard]] XGBOOST_DEVICE bst_node_t RightChild(bst_node_t nidx) const { return right[nidx]; }
// Returns the `split_index` entry for node `nidx` (unchecked index).
// The `split_index` member is declared on a line that is not part of this listing.
48  [[nodiscard]] XGBOOST_DEVICE bst_feature_t SplitIndex(bst_node_t nidx) const {
49  return split_index[nidx];
50  }
// Returns the `split_conds` entry for node `nidx` (unchecked index).
51  [[nodiscard]] XGBOOST_DEVICE float SplitCond(bst_node_t nidx) const { return split_conds[nidx]; }
// Returns the `default_left` entry for node `nidx` as a bool.
// The stored std::uint8_t is implicitly converted, so any non-zero value yields true.
52  [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft(bst_node_t nidx) const {
53  return default_left[nidx];
54  }
// Returns the child selected by the node's default direction:
// LeftChild(nidx) when DefaultLeft(nidx) is true, otherwise RightChild(nidx).
55  [[nodiscard]] XGBOOST_DEVICE bst_node_t DefaultChild(bst_node_t nidx) const {
56  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
57  }
59  return this->weights.Slice(nidx, linalg::All());
60  }
61 
// Number of targets, taken from the extent of axis 1 of `weights`.
62  [[nodiscard]] bst_target_t NumTargets() const { return this->weights.Shape(1); }
// Returns the node count `n`.
// NOTE(review): `n` is std::size_t while the return type is bst_node_t, so the value is
// narrowed on return -- presumably node counts always fit; confirm.
63  [[nodiscard]] bst_node_t Size() const { return this->n; }
64 };
65 
// Tree structure for multi-target models. Node attributes are kept in separate
// HostDeviceVector members and looked up by node index.
// NOTE(review): several declarations of this class are not present in this listing
// (the line numbering skips), so the comments below cover only what is visible.
69 class MultiTargetTree : public Model {
70  public
72 
73  private
74  TreeParam const* param_;
79  HostDeviceVector<std::uint8_t> default_left_;
80  HostDeviceVector<float> split_conds_;
// Flattened node weights: NodeWeight() reads NumTargets() consecutive values per node.
81  HostDeviceVector<float> weights_;
82 
// Declared mutable so that const member functions can lock it.
// NOTE(review): presumably guards the creation of the tree view -- confirm in the .cc file.
83  mutable std::mutex tree_view_lock_;
84 
// Read-only view of the weights of node `nidx`: NumTargets() values starting at offset
// nidx * NumTargets() of the host copy of weights_, wrapped as a CPU vector view.
// NOTE(review): the offset is computed from a 32-bit signed index and a 32-bit unsigned
// target count, i.e. in 32-bit unsigned arithmetic; consider widening to std::size_t if
// node count times target count can exceed 2^32 -- confirm.
85  [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
86  auto beg = nidx * this->NumTargets();
87  auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumTargets());
88  return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
89  }
// Mutable counterpart of the overload above; it goes through HostSpan() instead of
// ConstHostSpan(). The same offset computation (and review note) applies.
90  [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
91  auto beg = nidx * this->NumTargets();
92  auto v = this->weights_.HostSpan().subspan(beg, this->NumTargets());
93  return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
94  }
95 
96  public
// Constructs a tree from a pointer to the tree parameters; defined out of line.
97  explicit MultiTargetTree(TreeParam const* param);
// Copy assignment is disabled.
99  MultiTargetTree& operator=(MultiTargetTree const& that) = delete;
102 
// Expands a leaf node into a split node (description taken from the generated index).
// Only part of the parameter list is present in this listing.
110  void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
113  linalg::VectorView<float const> right_weight);
114 
// True when the left-child entry of `nidx` equals InvalidNodeId().
// Uses operator[] on the host vector, so `nidx` is not range-checked here.
115  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
116  return left_.ConstHostVector()[nidx] == InvalidNodeId();
117  }
// Parent / LeftChild / RightChild use std::vector::at(), so an out-of-range `nidx`
// throws instead of reading out of bounds.
118  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
119  return parent_.ConstHostVector().at(nidx);
120  }
121  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
122  return left_.ConstHostVector().at(nidx);
123  }
124  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
125  return right_.ConstHostVector().at(nidx);
126  }
127 
// SplitIndex / SplitCond / DefaultLeft index the host vectors with operator[] (unchecked).
// NOTE(review): this differs from the checked .at() accessors above -- confirm whether the
// difference is intentional.
128  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
129  return split_index_.ConstHostVector()[nidx];
130  }
131  [[nodiscard]] float SplitCond(bst_node_t nidx) const {
132  return split_conds_.ConstHostVector()[nidx];
133  }
// The stored std::uint8_t is implicitly converted to bool.
134  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
135  return default_left_.ConstHostVector()[nidx];
136  }
// Child chosen by the default direction: left when DefaultLeft(nidx), otherwise right.
137  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
138  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
139  }
140 
// Number of targets; defined out of line.
141  [[nodiscard]] bst_target_t NumTargets() const;
142 
// Size of the tree; defined out of line.
143  [[nodiscard]] std::size_t Size() const;
144 
// Number of parent links between `nidx` and the node whose parent is InvalidNodeId().
// That node itself has depth 0. Parent() is called twice per step.
145  [[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
146  bst_node_t depth{0};
147  while (Parent(nidx) != InvalidNodeId()) {
148  ++depth;
149  nidx = Parent(nidx);
150  }
151  return depth;
152  }
153 
// Body of LeafValue (its signature line is missing from this listing): CHECKs that `nidx`
// is a leaf, then returns the node's weight view.
155  CHECK(IsLeaf(nidx));
156  return this->NodeWeight(nidx);
157  }
// Returns a view of the tree for the given context; defined out of line.
163  [[nodiscard]] MultiTargetTreeView View(Context const* ctx) const;
164 
// Model interface overrides: load from / save to a JSON object.
165  void LoadModel(Json const& in) override;
166  void SaveModel(Json* out) const override;
167 
// Memory cost of the tree in bytes, judging by the name; defined out of line.
168  [[nodiscard]] std::size_t MemCostBytes() const;
169 };
170 } // namespace xgboost
171 #endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_
为 xgboost 定义配置宏和基本类型。
#define XGBOOST_DEVICE
将函数标记为可在设备端使用。
定义: base.h:64
common::Span< T const > ConstHostSpan() const
定义: host_device_vector.h:116
const std::vector< T > & ConstHostVector() const
common::Span< T > HostSpan()
定义: host_device_vector.h:114
表示JSON格式的数据结构。
定义: json.h:392
多目标模型的树结构。
定义: multi_target_tree_model.h:69
bool IsLeaf(bst_node_t nidx) const
定义: multi_target_tree_model.h:115
bst_feature_t SplitIndex(bst_node_t nidx) const
定义: multi_target_tree_model.h:128
bst_node_t Parent(bst_node_t nidx) const
定义: multi_target_tree_model.h:118
std::size_t Size() const
bst_node_t RightChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:124
MultiTargetTree & operator=(MultiTargetTree &&that)=delete
MultiTargetTree(MultiTargetTree const &that)
bst_target_t NumTargets() const
MultiTargetTree & operator=(MultiTargetTree const &that)=delete
void SaveModel(Json *out) const override
将模型配置保存到JSON对象
static constexpr bst_node_t InvalidNodeId()
定义: multi_target_tree_model.h:71
MultiTargetTree(TreeParam const *param)
bool DefaultLeft(bst_node_t nidx) const
定义: multi_target_tree_model.h:134
bst_node_t LeftChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:121
bst_node_t DefaultChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:137
MultiTargetTreeView View(Context const *ctx) const
获取树的视图。
MultiTargetTree(MultiTargetTree &&that)=delete
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
设置叶子的权重。
float SplitCond(bst_node_t nidx) const
定义: multi_target_tree_model.h:131
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
定义: multi_target_tree_model.h:154
bst_node_t Depth(bst_node_t nidx) const
定义: multi_target_tree_model.h:145
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
将叶子节点展开为分裂节点。
void LoadModel(Json const &in) override
从JSON对象加载模型
std::size_t MemCostBytes() const
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
定义: span.h:597
具有静态类型和维度的张量视图。它实现了索引和切片。
定义: linalg.h:277
LINALG_HD auto Shape() const
定义: linalg.h:506
LINALG_HD auto Slice(S &&...slices) const
切片张量。返回的张量具有推断的维度和形状。不支持标量结果。
定义: linalg.h:493
设备与主机向量抽象层。
线性代数相关工具。
定义 XGBoost 中不同组件的抽象接口。
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
用于自动类型推断的构造函数。
定义: linalg.h:564
constexpr detail::AllTag All()
指定切片轴中的所有元素。
定义: linalg.h:249
集成目标函数、gbm 和评估指标的学习器接口。这是面向用户的 XGB...
定义: base.h:97
std::int32_t bst_node_t
树节点索引的类型。
定义: base.h:119
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义: base.h:127
std::uint32_t bst_feature_t
数据列(特征)索引的类型。
定义: base.h:107
XGBoost的运行时上下文。包含线程和设备等信息。
定义: context.h:133
constexpr static auto CPU()
CPU 的构造函数。
定义: context.h:64
定义: model.h:14
适合主机和设备的多目标树视图。
定义: multi_target_tree_model.h:26
bst_node_t Size() const
定义: multi_target_tree_model.h:63
XGBOOST_DEVICE float SplitCond(bst_node_t nidx) const
定义: multi_target_tree_model.h:51
std::size_t n
定义: multi_target_tree_model.h:38
linalg::MatrixView< float const > weights
定义: multi_target_tree_model.h:40
bst_target_t NumTargets() const
定义: multi_target_tree_model.h:62
bst_node_t const * parent
定义: multi_target_tree_model.h:31
XGBOOST_DEVICE bst_node_t RightChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:47
bst_node_t const * left
定义: multi_target_tree_model.h:29
static constexpr bst_node_t InvalidNodeId()
定义: multi_target_tree_model.h:27
XGBOOST_DEVICE bst_node_t DefaultChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:55
XGBOOST_DEVICE bool IsLeaf(bst_node_t nidx) const
定义: multi_target_tree_model.h:42
XGBOOST_DEVICE bst_node_t LeftChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:46
XGBOOST_DEVICE linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
定义: multi_target_tree_model.h:58
XGBOOST_DEVICE bst_feature_t SplitIndex(bst_node_t nidx) const
定义: multi_target_tree_model.h:48
std::uint8_t const * default_left
定义: multi_target_tree_model.h:34
float const * split_conds
定义: multi_target_tree_model.h:35
bst_feature_t const * split_index
定义: multi_target_tree_model.h:33
XGBOOST_DEVICE bool DefaultLeft(bst_node_t nidx) const
定义: multi_target_tree_model.h:52
bst_node_t const * right
定义: multi_target_tree_model.h:30
树的元参数
定义: tree_model.h:30