6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
29 std::vector<bst_node_t> left_;
30 std::vector<bst_node_t> right_;
31 std::vector<bst_node_t> parent_;
32 std::vector<bst_feature_t> split_index_;
33 std::vector<std::uint8_t> default_left_;
34 std::vector<float> split_conds_;
35 std::vector<float> weights_;
42 [[nodiscard]] linalg::VectorView<float> NodeWeight(
bst_node_t nidx) {
44 auto v = common::Span<float>{weights_}.subspan(beg, this->
NumTarget());
76 [[nodiscard]] std::size_t
Size()
const;
89 return this->NodeWeight(nidx);
表示 JSON 格式的数据结构。
定义: json.h:378
用于多目标模型的树结构。
定义: multi_target_tree_model.h:23
bool IsLeaf(bst_node_t nidx) const
定义: multi_target_tree_model.h:62
bst_feature_t SplitIndex(bst_node_t nidx) const
定义: multi_target_tree_model.h:67
bst_node_t Parent(bst_node_t nidx) const
定义: multi_target_tree_model.h:63
bst_node_t RightChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:65
bst_target_t NumTarget() const
void SaveModel(Json *out) const override
将模型配置保存到 JSON 对象
static constexpr bst_node_t InvalidNodeId()
定义: multi_target_tree_model.h:25
MultiTargetTree(TreeParam const *param)
bool DefaultLeft(bst_node_t nidx) const
定义: multi_target_tree_model.h:69
bst_node_t LeftChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:64
bst_node_t DefaultChild(bst_node_t nidx) const
定义: multi_target_tree_model.h:70
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
设置叶节点的权重。
float SplitCond(bst_node_t nidx) const
定义: multi_target_tree_model.h:68
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
定义: multi_target_tree_model.h:87
bst_node_t Depth(bst_node_t nidx) const
定义: multi_target_tree_model.h:78
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
将叶节点扩展为分裂节点。
void LoadModel(Json const &in) override
从 JSON 对象加载模型
span 类实现,基于 ISO++20 span<T>。接口应相同。
定义: span.h:431
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
定义: span.h:597
具有静态类型和维度的张量视图。它实现了索引和切片操作。
定义: linalg.h:294
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
自动类型推断的构造函数。
定义: linalg.h:581
多目标树的核心数据结构。
定义: base.h:89
std::int32_t bst_node_t
树节点索引的类型。
定义: base.h:111
std::uint32_t bst_target_t
用于索引输出目标的类型。
定义: base.h:119
std::uint32_t bst_feature_t
数据列(特征)索引的类型。
定义: base.h:99
constexpr static auto CPU()
CPU 的构造函数。
定义: context.h:64