xgboost
multi_target_tree_model.h
转到此文件的文档。
1 
6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
8 #include <xgboost/base.h> // 用于 bst_node_t, bst_target_t, bst_feature_t
9 #include <xgboost/context.h> // 用于 Context
10 #include <xgboost/linalg.h> // 用于 VectorView
11 #include <xgboost/model.h> // 用于 Model
12 #include <xgboost/span.h> // 用于 Span
13 
14 #include <cinttypes> // 用于 uint8_t
15 #include <cstddef> // 用于 size_t
16 #include <vector> // 用于 vector
17 
18 namespace xgboost {
19 struct TreeParam;
23 class MultiTargetTree : public Model {
24  public
25  static bst_node_t constexpr InvalidNodeId() { return -1; }
26 
27  private
28  TreeParam const* param_;
29  std::vector<bst_node_t> left_;
30  std::vector<bst_node_t> right_;
31  std::vector<bst_node_t> parent_;
32  std::vector<bst_feature_t> split_index_;
33  std::vector<std::uint8_t> default_left_;
34  std::vector<float> split_conds_;
35  std::vector<float> weights_;
36 
37  [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
38  auto beg = nidx * this->NumTarget();
39  auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
40  return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
41  }
42  [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
43  auto beg = nidx * this->NumTarget();
44  auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
45  return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
46  }
47 
48  public
49  explicit MultiTargetTree(TreeParam const* param);
57  void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
60  linalg::VectorView<float const> right_weight);
61 
62  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
63  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
64  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
65  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }
66 
67  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
68  [[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
69  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
70  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
71  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
72  }
73 
74  [[nodiscard]] bst_target_t NumTarget() const;
75 
76  [[nodiscard]] std::size_t Size() const;
77 
78  [[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
79  bst_node_t depth{0};
80  while (Parent(nidx) != InvalidNodeId()) {
81  ++depth;
82  nidx = Parent(nidx);
83  }
84  return depth;
85  }
86 
88  CHECK(IsLeaf(nidx));
89  return this->NodeWeight(nidx);
90  }
91 
92  void LoadModel(Json const& in) override;
93  void SaveModel(Json* out) const override;
94 };
95 } // 命名空间 xgboost
96 #endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_
定义 xgboost 的配置宏和基本类型。
表示 JSON 格式的数据结构。
定义: json.h:378
用于多目标模型的树结构。
定义: multi_target_tree_model.h:23
bool IsLeaf(bst_node_t nidx) const
定义: multi_target_tree_model.h:62
bst_feature_t SplitIndex(bst_node_t nidx) const
定义: multi_target_tree_model.h:67
bst_node_t Parent(bst_node_t nidx) const
定义: multi_target_tree_model.h:63
std::size_t Size() const
bst_node_t RightChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:65
bst_target_t NumTarget() const
void SaveModel(Json *out) const override
将模型配置保存到 JSON 对象
static constexpr bst_node_t InvalidNodeId()
定义: multi_target_tree_model.h:25
MultiTargetTree(TreeParam const *param)
bool DefaultLeft(bst_node_t nidx) const
定义: multi_target_tree_model.h:69
bst_node_t LeftChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:64
bst_node_t DefaultChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:70
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
设置叶节点的权重。
float SplitCond(bst_node_t nidx) const
定义: multi_target_tree_model.h:68
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
定义: multi_target_tree_model.h:87
bst_node_t Depth(bst_node_t nidx) const
定义: multi_target_tree_model.h:78
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
将叶节点扩展为分裂节点。
void LoadModel(Json const &in) override
从 JSON 对象加载模型
span 类实现,基于 ISO++20 span<T>。接口应相同。
定义: span.h:431
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
定义: span.h:597
具有静态类型和维度的张量视图。它实现了索引和切片操作。
定义: linalg.h:294
线性代数相关工具。
定义 XGBoost 中不同组件的抽象接口。
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
自动类型推断的构造函数。
定义: linalg.h:581
多目标树的核心数据结构。
定义: base.h:89
std::int32_t bst_node_t
树节点索引的类型。
定义: base.h:111
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义: base.h:119
std::uint32_t bst_feature_t
数据列(特征)索引的类型。
定义: base.h:99
constexpr static auto CPU()
CPU 的构造函数。
定义: context.h:64
定义: model.h:17
树的元参数
定义: tree_model.h:34