xgboost
predictor.h
前往此文件的文档。
1 
7 #pragma once
8 #include <xgboost/base.h>
9 #include <xgboost/cache.h> // 用于 DMatrixCache
10 #include <xgboost/context.h> // 用于 Context
11 #include <xgboost/context.h>
12 #include <xgboost/data.h>
14 
15 #include <functional> // 用于 function
16 #include <memory> // 用于 shared_ptr
17 #include <string>
18 #include <vector>
19 
20 // 前向声明
21 namespace xgboost::gbm {
22 struct GBTreeModel;
23 } // namespace xgboost::gbm
24 
25 namespace xgboost {
30  // 用于缓存预测值的存储
32  // 当前缓存的版本,对应于树的层数
33  std::uint32_t version{0};
34 
35  PredictionCacheEntry() = default;
41  void Update(std::uint32_t v) { version += v; }
42  void Reset() { version = 0; }
43 };
44 
48 class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
49  // 我们为所有线程缓存最多 64 个 DMatrix
50  std::size_t static constexpr DefaultSize() { return 64; }
51 
52  public
54  std::shared_ptr<PredictionCacheEntry> Cache(std::shared_ptr<DMatrix> m, DeviceOrd device) {
55  auto p_cache = this->CacheItem(m);
56  if (!device.IsCPU()) {
57  p_cache->predictions.SetDevice(device);
58  }
59  return p_cache;
60  }
61 };
62 
71 class Predictor {
72  protected
73  Context const* ctx_;
74 
75  public
76  explicit Predictor(Context const* ctx) : ctx_{ctx} {}
77 
78  virtual ~Predictor() = default;
79 
85  virtual void Configure(Args const&);
86 
94  virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<float>* out_predt,
95  const gbm::GBTreeModel& model) const;
96 
107  virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
108  gbm::GBTreeModel const& model, bst_tree_t tree_begin,
109  bst_tree_t tree_end = 0) const = 0;
110 
124  virtual bool InplacePredict(std::shared_ptr<DMatrix> p_fmat, const gbm::GBTreeModel& model,
125  float missing, PredictionCacheEntry* out_preds,
126  bst_tree_t tree_begin = 0, bst_tree_t tree_end = 0) const = 0;
127 
138  virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<float>* out_preds,
139  gbm::GBTreeModel const& model, bst_tree_t tree_end = 0) const = 0;
140 
156  virtual void PredictContribution(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
157  gbm::GBTreeModel const& model, bst_tree_t tree_end = 0,
158  std::vector<float> const* tree_weights = nullptr,
159  bool approximate = false, int condition = 0,
160  unsigned condition_feature = 0) const = 0;
161 
163  gbm::GBTreeModel const& model,
164  bst_tree_t tree_end = 0,
165  std::vector<float> const* tree_weights = nullptr,
166  bool approximate = false) const = 0;
167 
174  static Predictor* Create(std::string const& name, Context const* ctx);
175 };
176 
181  : public dmlc::FunctionRegEntryBase<PredictorReg, std::function<Predictor*(Context const*)>> {};
182 
183 #define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \
184  static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \
185  __make_##PredictorReg##_##UniqueId##__ = \
186  ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name)
187 } // namespace xgboost
定义了 xgboost 的配置宏和基本类型。
线程感知的 DMatrix 相关数据的 FIFO 缓存。
定义: cache.h:26
std::shared_ptr< PredictionCacheEntry > CacheItem(std::shared_ptr< DMatrix > m, Args const &... args)
如果缓存中尚不存在,则缓存新的 DMatrix。
定义: cache.h:145
XGBoost 用于保存所有外部数据的内部数据结构。
定义: data.h:549
数据集的元信息,始终位于内存中。
定义: data.h:48
用于管理预测缓存的容器。
定义: predictor.h:48
std::shared_ptr< PredictionCacheEntry > Cache(std::shared_ptr< DMatrix > m, DeviceOrd device)
定义: predictor.h:54
PredictionContainer()
定义: predictor.h:53
对 GBTree 的单个训练实例或实例批次执行预测...
定义: predictor.h:71
virtual void InitOutPredictions(const MetaInfo &info, HostDeviceVector< float > *out_predt, const gbm::GBTreeModel &model) const
初始化输出预测。
virtual void Configure(Args const &)
在预测缓存中配置和注册输入矩阵。
virtual bool InplacePredict(std::shared_ptr< DMatrix > p_fmat, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, bst_tree_t tree_begin=0, bst_tree_t tree_end=0) const =0
原地预测。
Predictor(Context const *ctx)
定义: predictor.h:76
virtual void PredictContribution(DMatrix *dmat, HostDeviceVector< float > *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end=0, std::vector< float > const *tree_weights=nullptr, bool approximate=false, int condition=0, unsigned condition_feature=0) const =0
特征对个体预测的贡献;输出将是一个长度为 (nfeats + 1) *... 的向量
virtual void PredictBatch(DMatrix *dmat, PredictionCacheEntry *out_preds, gbm::GBTreeModel const &model, bst_tree_t tree_begin, bst_tree_t tree_end=0) const =0
为给定的特征矩阵生成批量预测。如果可用,可能会使用缓存的预测...
Context const * ctx_
定义: predictor.h:73
static Predictor * Create(std::string const &name, Context const *ctx)
创建一个新的 Predictor*。
virtual ~Predictor()=default
virtual void PredictLeaf(DMatrix *dmat, HostDeviceVector< float > *out_preds, gbm::GBTreeModel const &model, bst_tree_t tree_end=0) const =0
预测每棵树的叶子索引,输出将是 nsample * ntree 向量,这仅在...中有效
virtual void PredictInteractionContributions(DMatrix *dmat, HostDeviceVector< float > *out_contribs, gbm::GBTreeModel const &model, bst_tree_t tree_end=0, std::vector< float > const *tree_weights=nullptr, bool approximate=false) const =0
xgboost 的输入数据结构。
设备和主机向量抽象层。
定义: linear_updater.h:23
用于多目标树的核心数据结构。
定义: base.h:89
std::vector< std::pair< std::string, std::string > > Args
定义: base.h:316
std::int32_t bst_tree_t
用于索引树的类型。
定义: base.h:127
XGBoost 的运行时上下文。包含线程和设备等信息。
定义: context.h:133
设备序号类型。该类型被打包成 32 位,以便在查看 lin... 等类型时有效使用
定义: context.h:34
bool IsCPU() const
定义: context.h:45
包含指向输入矩阵和相关联的缓存预测的指针。
定义: predictor.h:29
std::uint32_t version
定义: predictor.h:33
HostDeviceVector< float > predictions
定义: predictor.h:31
void Reset()
定义: predictor.h:42
void Update(std::uint32_t v)
按版本号更新缓存条目。
定义: predictor.h:41
预测器的注册表条目。
定义: predictor.h:181