xgboost
data.h
前往此文件文档。
1 
7 #ifndef XGBOOST_DATA_H_
8 #define XGBOOST_DATA_H_
9 
10 #include <dmlc/base.h>
11 #include <dmlc/io.h> // for Stream
12 #include <dmlc/serializer.h> // for Handler
13 #include <xgboost/base.h>
15 #include <xgboost/linalg.h>
16 #include <xgboost/span.h>
17 #include <xgboost/string_view.h>
18 
19 #include <algorithm>
20 #include <cstdint> // for int32_t, uint8_t
21 #include <limits>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 namespace xgboost {
28 // 提前声明 dmatrix。
29 class DMatrix;
30 struct Context;
31 
33 enum class DataType : uint8_t {
34  kFloat32 = 1,
35  kDouble = 2,
36  kUInt32 = 3,
37  kUInt64 = 4,
38  kStr = 5
39 };
40 
41 enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
42 
43 enum class DataSplitMode : int { kRow = 0, kCol = 1 };
44 
45 // 元信息使用的容器的提前声明。
46 class CatContainer;
47 
51 class MetaInfo {
52  public
54  static constexpr uint64_t kNumField = 13;
55 
57  bst_idx_t num_row_{0}; // NOLINT
59  uint64_t num_col_{0}; // NOLINT
61  uint64_t num_nonzero_{0}; // NOLINT
70  std::vector<bst_group_t> group_ptr_; // NOLINT
87 
91  std::vector<std::string> feature_type_names;
95  std::vector<std::string> feature_names;
96  /*
97  * \brief 每个特征的类型。在指定 feature_type_names 时自动设置。
98  */
100  /*
101  * \brief 每个特征的权重,用于在使用列抽样时定义每个特征被选中的概率。
102  * selected when using column sampling.
103  */
105 
107  MetaInfo(MetaInfo&& that) = default;
108  MetaInfo(MetaInfo const& that) = delete;
109  MetaInfo& operator=(MetaInfo&& that) = default;
110  MetaInfo& operator=(MetaInfo const& that) = delete;
111 
115  void Validate(DeviceOrd device) const;
125 
126  MetaInfo Copy() const;
130  bool IsDense() const { return num_col_ * num_row_ == num_nonzero_; }
136  inline bst_float GetWeight(size_t i) const {
137  return weights_.Size() != 0 ? weights_.HostVector()[i] : 1.0f;
138  }
140  const std::vector<size_t>& LabelAbsSort(Context const* ctx) const;
142  void Clear();
147  void LoadBinary(dmlc::Stream* fi);
152  void SaveBinary(dmlc::Stream* fo) const;
158  void SetInfo(Context const& ctx, StringView key, StringView interface_str);
159 
160  void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
161  const void** out_dptr) const;
162 
163  void SetFeatureInfo(const char *key, const char **info, const bst_ulong size);
164  void GetFeatureInfo(const char *field, std::vector<std::string>* out_str_vecs) const;
165 
177  void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);
185  void SynchronizeNumberOfColumns(Context const* ctx, DataSplitMode split_mode);
186 
188  [[nodiscard]] bool IsRowSplit() const { return data_split_mode == DataSplitMode::kRow; }
190  [[nodiscard]] bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; }
192  [[nodiscard]] bool IsRanking() const { return !group_ptr_.empty(); }
193 
198  [[nodiscard]] bool IsVerticalFederated() const;
199 
206  bool ShouldHaveLabels() const;
210  bool HasCategorical() const { return has_categorical_; }
214  [[nodiscard]] CatContainer const* Cats() const;
215  [[nodiscard]] CatContainer* Cats();
216  [[nodiscard]] std::shared_ptr<CatContainer const> CatsShared() const;
220  void Cats(std::shared_ptr<CatContainer> cats);
221 
222  private
223  void SetInfoFromHost(Context const* ctx, StringView key, Json arr);
224  void SetInfoFromCUDA(Context const* ctx, StringView key, Json arr);
225 
227  mutable std::vector<size_t> label_order_cache_;
228  bool has_categorical_{false};
229 
230  std::shared_ptr<CatContainer> cats_;
231 };
232 
234 struct Entry {
240  Entry() = default;
248  inline static bool CmpValue(const Entry& a, const Entry& b) {
249  return a.fvalue < b.fvalue;
250  }
251  static bool CmpIndex(Entry const& a, Entry const& b) {
252  return a.index < b.index;
253  }
254  inline bool operator==(const Entry& other) const {
255  return (this->index == other.index && this->fvalue == other.fvalue);
256  }
257 };
258 
262 struct BatchParam {
275  bool regen{false};
279  bool forbid_regen{false};
283  double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
289  bool prefetch_copy{true};
293  std::int32_t n_prefetch_batches{3};
297  BatchParam() = default;
310  : max_bin{max_bin}, hess{hessian}, regen{regenerate} {}
311 
312  [[nodiscard]] bool ParamNotEqual(BatchParam const& other) const {
313  // 检查非浮点参数。
314  bool cond = max_bin != other.max_bin;
315  // 检查稀疏阈值。
316  bool l_nan = std::isnan(sparse_thresh);
317  bool r_nan = std::isnan(other.sparse_thresh);
318  bool st_chg = (l_nan != r_nan) || (!l_nan && !r_nan && (sparse_thresh != other.sparse_thresh));
319  cond |= st_chg;
320 
321  return cond;
322  }
323  [[nodiscard]] bool Initialized() const { return max_bin != 0; }
327  [[nodiscard]] BatchParam MakeCache() const {
328  auto p = *this;
329  // 这些参数与梯度索引最初是如何生成的无关。
330  // first place.
331  p.regen = false;
332  p.forbid_regen = false;
333  return p;
334  }
335 };
336 
339 
342 
343  [[nodiscard]] Inst operator[](std::size_t i) const {
344  auto size = *(offset.data() + i + 1) - *(offset.data() + i);
345  return {data.data() + *(offset.data() + i), static_cast<Inst::index_type>(size)};
346  }
347 
348  [[nodiscard]] size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; }
349 };
350 
354 class SparsePage {
355  public
356  // 每行的偏移量。
360 
361  size_t base_rowid {0};
362 
365 
366  [[nodiscard]] HostSparsePageView GetView() const {
367  return {offset.ConstHostSpan(), data.ConstHostSpan()};
368  }
369 
372  this->Clear();
373  }
374 
375  SparsePage(SparsePage const& that) = delete;
376  SparsePage(SparsePage&& that) = default;
377  SparsePage& operator=(SparsePage const& that) = delete;
378  SparsePage& operator=(SparsePage&& that) = default;
379  virtual ~SparsePage() = default;
380 
382  [[nodiscard]] size_t Size() const {
383  return offset.Size() == 0 ? 0 : offset.Size() - 1;
384  }
385 
387  [[nodiscard]] size_t MemCostBytes() const {
388  return offset.Size() * sizeof(size_t) + data.Size() * sizeof(Entry);
389  }
390 
392  inline void Clear() {
393  base_rowid = 0;
394  auto& offset_vec = offset.HostVector();
395  offset_vec.clear();
396  offset_vec.push_back(0);
397  data.HostVector().clear();
398  }
399 
401  inline void SetBaseRowId(size_t row_id) {
402  base_rowid = row_id;
403  }
404 
405  [[nodiscard]] SparsePage GetTranspose(int num_columns, int32_t n_threads) const;
406 
410  void SortIndices(int32_t n_threads);
414  [[nodiscard]] bool IsIndicesSorted(int32_t n_threads) const;
418  void Reindex(uint64_t feature_offset, int32_t n_threads);
419 
420  void SortRows(int32_t n_threads);
421 
432  template <typename AdapterBatchT>
433  bst_idx_t Push(AdapterBatchT const& batch, float missing, std::int32_t nthread);
434 
439  void Push(const SparsePage &batch);
444  void PushCSC(const SparsePage& batch);
445 };
446 
447 class CSCPage: public SparsePage {
448  public
450  explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
451 };
452 
458  public
459  std::shared_ptr<SparsePage const> page;
460  explicit ExtSparsePage(std::shared_ptr<SparsePage const> p) : page{std::move(p)} {}
461 };
462 
463 class SortedCSCPage : public SparsePage {
464  public
466  explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
467 };
468 
469 class EllpackPage;
470 class GHistIndexMatrix;
471 
472 template<typename T>
474  public
475  using iterator_category = std::forward_iterator_tag; // NOLINT
476  virtual ~BatchIteratorImpl() = default;
477  virtual const T& operator*() const = 0;
479  [[nodiscard]] virtual bool AtEnd() const = 0;
480  virtual std::shared_ptr<T const> Page() const = 0;
481 };
482 
483 template<typename T>
485  public
486  using iterator_category = std::forward_iterator_tag; // NOLINT
487  explicit BatchIterator(BatchIteratorImpl<T>* impl) { impl_.reset(impl); }
488  explicit BatchIterator(std::shared_ptr<BatchIteratorImpl<T>> impl) { impl_ = impl; }
489 
491  CHECK(impl_ != nullptr);
492  ++(*impl_);
493  return *this;
494  }
495 
496  const T& operator*() const {
497  CHECK(impl_ != nullptr);
498  return *(*impl_);
499  }
500 
501  [[nodiscard]] bool operator!=(const BatchIterator&) const { return !this->AtEnd(); }
502 
503  [[nodiscard]] bool AtEnd() const {
504  CHECK(impl_ != nullptr);
505  return impl_->AtEnd();
506  }
507 
508  [[nodiscard]] std::shared_ptr<T const> Page() const {
509  return impl_->Page();
510  }
511 
512  private
513  std::shared_ptr<BatchIteratorImpl<T>> impl_;
514 };
515 
516 template<typename T>
517 class BatchSet {
518  public
519  explicit BatchSet(BatchIterator<T> begin_iter) : begin_iter_(std::move(begin_iter)) {}
520  BatchIterator<T> begin() { return begin_iter_; } // NOLINT
521  BatchIterator<T> end() { return BatchIterator<T>(nullptr); } // NOLINT
522 
523  private
524  BatchIterator<T> begin_iter_;
525 };
526 
527 struct XGBAPIThreadLocalEntry;
528 
529 // 外部内存 DMatrix 的配置。
530 struct ExtMemConfig {
531  // 缓存前缀,如果缓存位于主机内存中,则不使用。(on_host 为 true)
532  std::string cache;
533  // Ellpack 页是否存储在主机内存中。
534  bool on_host;
535  // GPU 实现的主机缓存/总缓存。
537  // 缓存中每个 ellpack 页的最小字节数。仅用于主机内 ExtMemQdm。
538  // ExtMemQdm.
539  std::int64_t min_cache_page_bytes;
540  // 缺失值。
541  float missing;
542  // CPU 线程数。
543  std::int32_t n_threads{0};
544  // 缓存可压缩的比例。用于测试。
545  float hw_decomp_ratio{std::numeric_limits<float>::quiet_NaN()};
546  // 回退到使用 nvcomp。用于测试。
548 
549  ExtMemConfig() = delete;
550  ExtMemConfig(std::string cache, bool on_host, float h_ratio, std::int64_t min_cache,
551  float missing, std::int32_t n_threads)
552  : cache{std::move(cache)},
553  on_host{on_host},
554  cache_host_ratio{h_ratio},
555  min_cache_page_bytes{min_cache},
556  missing{missing},
557  n_threads{n_threads} {}
558 
559  ExtMemConfig& SetParamsForTest(float _hw_decomp_ratio, bool _allow_decomp_fallback) {
560  this->hw_decomp_ratio = _hw_decomp_ratio;
561  this->allow_decomp_fallback = _allow_decomp_fallback;
562  return *this;
563  }
564 };
565 
573 class DMatrix {
574  public
576  DMatrix() = default;
578  [[nodiscard]] virtual MetaInfo& Info() = 0;
579  virtual void SetInfo(const char* key, std::string const& interface_str) {
580  auto const& ctx = *this->Ctx();
581  this->Info().SetInfo(ctx, key, StringView{interface_str});
582  }
584  [[nodiscard]] virtual const MetaInfo& Info() const = 0;
585 
587  [[nodiscard]] XGBAPIThreadLocalEntry& GetThreadLocal() const;
592  [[nodiscard]] virtual Context const* Ctx() const = 0;
593 
597  template <typename T>
599  template <typename T>
601  template <typename T>
602  BatchSet<T> GetBatches(Context const* ctx, const BatchParam& param);
603  template <typename T>
604  [[nodiscard]] bool PageExists() const;
605 
611  [[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
612  [[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }
613 
614  virtual ~DMatrix();
615 
619  [[nodiscard]] bool IsDense() const { return this->Info().IsDense(); }
620 
629  static DMatrix* Load(const std::string& uri, bool silent = true,
630  DataSplitMode data_split_mode = DataSplitMode::kRow);
631 
644  template <typename AdapterT>
645  static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
646  const std::string& cache_prefix = "",
647  DataSplitMode data_split_mode = DataSplitMode::kRow);
648 
668  template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
669  typename XGDMatrixCallbackNext>
670  static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
671  DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
672  std::int32_t nthread, bst_bin_t max_bin, std::int64_t max_quantile_blocks);
673 
690  template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
691  typename XGDMatrixCallbackNext>
693  XGDMatrixCallbackNext* next, ExtMemConfig const& config);
694 
702  template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
703  typename XGDMatrixCallbackNext>
704  static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
706  bst_bin_t max_bin, std::int64_t max_quantile_blocks,
707  ExtMemConfig const& config);
708 
710 
718  virtual DMatrix* SliceCol(int num_slices, int slice_id) = 0;
722  [[nodiscard]] CatContainer const* Cats() const { return this->CatsShared().get(); }
723  [[nodiscard]] std::shared_ptr<CatContainer const> CatsShared() const {
724  return this->Info().CatsShared();
725  }
726 
727  protected
729  virtual BatchSet<CSCPage> GetColumnBatches(Context const* ctx) = 0;
731  virtual BatchSet<EllpackPage> GetEllpackBatches(Context const* ctx, BatchParam const& param) = 0;
733  BatchParam const& param) = 0;
734  virtual BatchSet<ExtSparsePage> GetExtBatches(Context const* ctx, BatchParam const& param) = 0;
735 
736  [[nodiscard]] virtual bool EllpackExists() const = 0;
737  [[nodiscard]] virtual bool GHistIndexExists() const = 0;
738  [[nodiscard]] virtual bool SparsePageExists() const = 0;
739 };
740 
741 template <>
743  return GetRowBatches();
744 }
745 
746 template <>
747 inline bool DMatrix::PageExists<EllpackPage>() const {
748  return this->EllpackExists();
749 }
750 
751 template <>
752 inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
753  return this->GHistIndexExists();
754 }
755 
756 template <>
757 inline bool DMatrix::PageExists<SparsePage>() const {
758  return this->SparsePageExists();
759 }
760 
761 template <>
763  return GetRowBatches();
764 }
765 
766 template <>
767 inline BatchSet<CSCPage> DMatrix::GetBatches(Context const* ctx) {
768  return GetColumnBatches(ctx);
769 }
770 
771 template <>
772 inline BatchSet<SortedCSCPage> DMatrix::GetBatches(Context const* ctx) {
773  return GetSortedColumnBatches(ctx);
774 }
775 
776 template <>
778  return GetEllpackBatches(ctx, param);
779 }
780 
781 template <>
782 inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
783  return GetGradientIndex(ctx, param);
784 }
785 
786 template <>
787 inline BatchSet<ExtSparsePage> DMatrix::GetBatches(Context const* ctx, BatchParam const& param) {
788  return GetExtBatches(ctx, param);
789 }
790 } // namespace xgboost
791 
793 
794 namespace dmlc {
796 
797 namespace serializer {
798 
799 template <>
800 struct Handler<xgboost::Entry> {
801  inline static void Write(Stream* strm, const xgboost::Entry& data) {
802  strm->Write(data.index);
803  strm->Write(data.fvalue);
804  }
805 
806  inline static bool Read(Stream* strm, xgboost::Entry* data) {
807  return strm->Read(&data->index) && strm->Read(&data->fvalue);
808  }
809 };
810 
811 } // namespace serializer
812 } // namespace dmlc
813 #endif // XGBOOST_DATA_H_
为 xgboost 定义配置宏和基本类型。
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
Definition: data.h:473
virtual BatchIteratorImpl & operator++()=0
std::forward_iterator_tag iterator_category
Definition: data.h:475
virtual std::shared_ptr< T const > Page() const =0
virtual bool AtEnd() const =0
virtual const T & operator*() const =0
virtual ~BatchIteratorImpl()=default
Definition: data.h:484
BatchIterator(std::shared_ptr< BatchIteratorImpl< T >> impl)
Definition: data.h:488
std::forward_iterator_tag iterator_category
Definition: data.h:486
BatchIterator(BatchIteratorImpl< T > *impl)
Definition: data.h:487
const T & operator*() const
Definition: data.h:496
std::shared_ptr< T const > Page() const
Definition: data.h:508
BatchIterator & operator++()
Definition: data.h:490
bool operator!=(const BatchIterator &) const
Definition: data.h:501
bool AtEnd() const
Definition: data.h:503
Definition: data.h:517
BatchSet(BatchIterator< T > begin_iter)
Definition: data.h:519
BatchIterator< T > begin()
Definition: data.h:520
BatchIterator< T > end()
Definition: data.h:521
Definition: data.h:447
CSCPage()
Definition: data.h:449
CSCPage(SparsePage page)
Definition: data.h:450
内部数据结构,由XGBoost用于保存所有外部数据。
Definition: data.h:573
CatContainer const * Cats() const
用于类别字符串表示的访问器。
Definition: data.h:722
virtual BatchSet< EllpackPage > GetEllpackBatches(Context const *ctx, BatchParam const &param)=0
static DMatrix * Load(const std::string &uri, bool silent=true, DataSplitMode data_split_mode=DataSplitMode::kRow)
从URI加载DMatrix。
virtual BatchSet< SparsePage > GetRowBatches()=0
virtual BatchSet< GHistIndexMatrix > GetGradientIndex(Context const *ctx, BatchParam const &param)=0
static DMatrix * Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, ExtMemConfig const &config)
使用回调创建外部内存DMatrix。
virtual void SetInfo(const char *key, std::string const &interface_str)
Definition: data.h:579
virtual BatchSet< ExtSparsePage > GetExtBatches(Context const *ctx, BatchParam const &param)=0
bool PageExists() const
BatchSet< T > GetBatches(Context const *ctx)
virtual ~DMatrix()
virtual MetaInfo & Info()=0
数据集的元信息
static DMatrix * Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr< DMatrix > ref, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, float missing, std::int32_t nthread, bst_bin_t max_bin, std::int64_t max_quantile_blocks)
创建用于基于直方图的算法的新的基于分位数的DMatrix。
static DMatrix * Create(AdapterT *adapter, float missing, int nthread, const std::string &cache_prefix="", DataSplitMode data_split_mode=DataSplitMode::kRow)
从外部数据适配器创建新的DMatrix。
virtual DMatrix * SliceCol(int num_slices, int slice_id)=0
按列切分DMatrix。
virtual bool GHistIndexExists() const =0
XGBAPIThreadLocalEntry & GetThreadLocal() const
获取DMatrix返回数据的线程局部内存。
virtual bool SparsePageExists() const =0
virtual DMatrix * Slice(common::Span< int32_t const > ridxs)=0
virtual Context const * Ctx() const =0
获取此DMatrix的上下文对象。上下文是在DMatrix使用...构造时创建的
BatchSet< T > GetBatches()
获取批次。使用基于范围的for循环遍历BatchSet以访问单个批次。
virtual const MetaInfo & Info() const =0
数据集的元信息
virtual bool EllpackExists() const =0
bool SingleColBlock() const
Definition: data.h:611
virtual BatchSet< CSCPage > GetColumnBatches(Context const *ctx)=0
virtual BatchSet< SortedCSCPage > GetSortedColumnBatches(Context const *ctx)=0
BatchSet< T > GetBatches(Context const *ctx, const BatchParam &param)
std::shared_ptr< CatContainer const > CatsShared() const
Definition: data.h:723
static DMatrix * Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr< DMatrix > ref, DataIterResetCallback *reset, XGDMatrixCallbackNext *next, bst_bin_t max_bin, std::int64_t max_quantile_blocks, ExtMemConfig const &config)
使用回调创建外部内存分位数DMatrix。
virtual std::int32_t NumBatches() const
Definition: data.h:612
bool IsDense() const
矩阵是否是稠密的。
Definition: data.h:619
DMatrix()=default
默认构造函数
用于导出DMatrix的稀疏页面。与SparsePage相同,只是类型不同,以防止在...中使用
Definition: data.h:457
ExtSparsePage(std::shared_ptr< SparsePage const > p)
Definition: data.h:460
std::shared_ptr< SparsePage const > page
Definition: data.h:459
std::size_t Size() const
common::Span< T const > ConstHostSpan() const
Definition: host_device_vector.h:116
std::vector< T > & HostVector()
表示JSON格式的数据结构。
Definition: json.h:392
数据集的元信息,始终存储在内存中。
Definition: data.h:51
std::vector< std::string > feature_names
每个特征的名称。
Definition: data.h:95
std::shared_ptr< CatContainer const > CatsShared() const
MetaInfo(MetaInfo &&that)=default
HostDeviceVector< bst_float > labels_upper_bound_
标签的上限,用于生存分析(截尾回归)
Definition: data.h:86
MetaInfo Slice(Context const *ctx, common::Span< bst_idx_t const > ridxs, bst_idx_t nnz) const
切分元信息。
void SynchronizeNumberOfColumns(Context const *ctx, DataSplitMode split_mode)
在所有工作器之间同步列数。
uint64_t num_col_
数据中的列数
Definition: data.h:59
std::vector< std::string > feature_type_names
用户提供的每个特征类型的名称。例如:“int”/“float”/“i”/“q”。
Definition: data.h:91
HostDeviceVector< bst_float > weights_
每个实例的权重,可选
Definition: data.h:72
CatContainer * Cats()
bool IsVerticalFederated() const
检查我们是否正在进行垂直联邦学习的便捷方法,这需要一些特殊的...
MetaInfo & operator=(MetaInfo const &that)=delete
void GetInfo(char const *key, bst_ulong *out_len, DataType dtype, const void **out_dptr) const
bool IsColumnSplit() const
数据是否按列拆分。
Definition: data.h:190
bst_float GetWeight(size_t i) const
获取每个实例的权重。
Definition: data.h:136
HostDeviceVector< FeatureType > feature_types
Definition: data.h:99
DataSplitMode data_split_mode
数据拆分模式
Definition: data.h:65
void LoadBinary(dmlc::Stream *fi)
从二进制流加载元信息。
std::vector< bst_group_t > group_ptr_
当学习任务是排名时,组的开始和结束索引。
Definition: data.h:70
void Validate(DeviceOrd device) const
验证所有元信息。
HostDeviceVector< float > feature_weights
Definition: data.h:104
void GetFeatureInfo(const char *field, std::vector< std::string > *out_str_vecs) const
bst_idx_t num_row_
数据中的行数
Definition: data.h:57
bool HasCategorical() const
DMatrix是否具有分类特征的标志。
Definition: data.h:210
void Cats(std::shared_ptr< CatContainer > cats)
类别的设置器。
CatContainer const * Cats() const
类别的获取器。
MetaInfo & operator=(MetaInfo &&that)=default
bool IsDense() const
矩阵是否是稠密的。
Definition: data.h:130
void Clear()
清除所有信息
void Extend(MetaInfo const &that, bool accumulate_rows, bool check_column)
扩展其他元信息。
bool IsRanking() const
这是否是排名学习数据。
Definition: data.h:192
uint64_t num_nonzero_
数据中非零条目数
Definition: data.h:61
MetaInfo Copy() const
linalg::Tensor< float, 2 > labels
每个实例的标签
Definition: data.h:63
void SaveBinary(dmlc::Stream *fo) const
将元信息保存到二进制流。
MetaInfo(MetaInfo const &that)=delete
bool ShouldHaveLabels() const
检查MetaInfo是否应包含标签的便捷方法。
void SetInfo(Context const &ctx, StringView key, StringView interface_str)
使用数组接口设置元信息中的信息。
static constexpr uint64_t kNumField
MetaInfo中的数据字段数量
Definition: data.h:54
bool IsRowSplit() const
数据是否按行拆分。
Definition: data.h:188
linalg::Matrix< float > base_margin_
初始化的边距,如果指定,xgboost将从该初始边距开始,可用于指定在...
Definition: data.h:78
HostDeviceVector< bst_float > labels_lower_bound_
标签的下限,用于生存分析(截尾回归)
Definition: data.h:82
const std::vector< size_t > & LabelAbsSort(Context const *ctx) const
按绝对值获取标签的排序索引(argsort)(由cox损失使用)
void SetFeatureInfo(const char *key, const char **info, const bst_ulong size)
Definition: data.h:463
SortedCSCPage(SparsePage page)
Definition: data.h:466
SortedCSCPage()
Definition: data.h:465
稀疏批次的内存存储单元,以CSR格式存储。
Definition: data.h:354
void Push(const SparsePage &batch)
推送稀疏页面。
SparsePage()
构造函数
Definition: data.h:371
SparsePage GetTranspose(int num_columns, int32_t n_threads) const
void SetBaseRowId(size_t row_id)
设置此页面的基本行ID。
Definition: data.h:401
void Reindex(uint64_t feature_offset, int32_t n_threads)
使用偏移量重新索引列索引。
void PushCSC(const SparsePage &batch)
推送以CSC格式存储的SparsePage。
bool IsIndicesSorted(int32_t n_threads) const
检查列索引是否已排序。
virtual ~SparsePage()=default
void SortIndices(int32_t n_threads)
排序列索引。
HostDeviceVector< Entry > data
段的数据
Definition: data.h:359
HostSparsePageView GetView() const
Definition: data.h:366
SparsePage & operator=(SparsePage const &that)=delete
size_t MemCostBytes() const
Definition: data.h:387
void Clear()
清除页面
Definition: data.h:392
SparsePage(SparsePage const &that)=delete
size_t Size() const
Definition: data.h:382
void SortRows(int32_t n_threads)
bst_idx_t Push(AdapterBatchT const &batch, float missing, std::int32_t nthread)
将外部数据批次推送到此页面。
HostDeviceVector< bst_idx_t > offset
Definition: data.h:357
SparsePage & operator=(SparsePage &&that)=default
SparsePage(SparsePage &&that)=default
size_t base_rowid
Definition: data.h:361
span类实现,基于ISO++20 span<T>。接口应相同。
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
std::size_t index_type
Definition: span.h:435
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode)
void * DMatrixHandle
DMatrix 句柄
定义: c_api.h:50
int XGDMatrixCallbackNext(DataIterHandle iter)
获取下一批数据的回调函数原型。
Definition: c_api.h:470
void * DataIterHandle
外部数据迭代器的句柄
定义: c_api.h:368
void DataIterResetCallback(DataIterHandle handle)
重置外部迭代器的回调函数原型。
Definition: c_api.h:475
设备与主机向量抽象层。
线性代数相关工具。
Definition: data.h:794
DMLC_DECLARE_TRAITS(is_pod, xgboost::Entry, true)
Definition: intrusive_ptr.h:207
集成目标、gbm和评估的学习器接口。这是用户面临的XGB...
Definition: base.h:97
FeatureType
Definition: data.h:41
std::int32_t bst_bin_t
直方图 bin 索引的类型。我们有时使用 -1 表示无效 bin。
Definition: base.h:111
std::uint64_t bst_idx_t
数据行索引(样本)的类型。
Definition: base.h:115
DataSplitMode
Definition: data.h:43
std::uint32_t bst_feature_t
数据列(特征)索引的类型。
Definition: base.h:107
DataType
xgboost接口接受的数据类型
Definition: data.h:33
std::uint64_t bst_ulong
无符号长整数
Definition: base.h:101
float bst_float
浮点类型,用于存储统计信息
Definition: base.h:103
static void Write(Stream *strm, const xgboost::Entry &data)
Definition: data.h:801
static bool Read(Stream *strm, xgboost::Entry *data)
Definition: data.h:806
用于构造直方图索引批次的参数。
Definition: data.h:262
bool forbid_regen
禁止重新生成梯度索引。用于内部验证。
Definition: data.h:279
bst_bin_t max_bin
每个特征的最大 bin 数量,用于直方图。
Definition: data.h:266
common::Span< float const > hess
Hessian,用于未来近似实现中的草图。
Definition: data.h:270
bool regen
我们是否应该强制DMatrix重新生成批次。仅用于GHistIndex。
Definition: data.h:275
bool ParamNotEqual(BatchParam const &other) const
Definition: data.h:312
BatchParam()=default
精确或其他不需要直方图的。
bool prefetch_copy
用于GPU外部内存。是否将数据复制到设备。
Definition: data.h:289
double sparse_thresh
用于生成直方图列矩阵的参数。
Definition: data.h:283
bool Initialized() const
Definition: data.h:323
BatchParam(bst_bin_t max_bin, common::Span< float const > hessian, bool regenerate)
由近似树方法使用。
Definition: data.h:309
BatchParam MakeCache() const
为DMatrix制作一个自身的副本,以描述其现有索引是如何生成的。
Definition: data.h:327
BatchParam(bst_bin_t max_bin, double sparse_thresh)
由直方图树方法使用。
Definition: data.h:301
std::int32_t n_prefetch_batches
外部内存预取的批次数。
Definition: data.h:293
XGBoost的运行时上下文。包含线程和设备等信息。
Definition: context.h:133
设备序号的类型。该类型被打包成32位,以便在查看类型(如lin...)时高效使用
Definition: context.h:34
稀疏向量中的元素。
Definition: data.h:234
XGBOOST_DEVICE Entry(bst_feature_t index, bst_float fvalue)
带索引和值的构造函数
Definition: data.h:246
Entry()=default
默认构造函数
bst_feature_t index
特征索引
Definition: data.h:236
static bool CmpIndex(Entry const &a, Entry const &b)
Definition: data.h:251
bst_float fvalue
特征值
Definition: data.h:238
bool operator==(const Entry &other) const
Definition: data.h:254
static bool CmpValue(const Entry &a, const Entry &b)
反向比较特征值
Definition: data.h:248
Definition: data.h:530
ExtMemConfig(std::string cache, bool on_host, float h_ratio, std::int64_t min_cache, float missing, std::int32_t n_threads)
Definition: data.h:550
ExtMemConfig & SetParamsForTest(float _hw_decomp_ratio, bool _allow_decomp_fallback)
Definition: data.h:559
std::int32_t n_threads
Definition: data.h:543
std::int64_t min_cache_page_bytes
Definition: data.h:539
bool on_host
Definition: data.h:534
bool allow_decomp_fallback
Definition: data.h:547
float cache_host_ratio
Definition: data.h:536
float missing
Definition: data.h:541
std::string cache
Definition: data.h:532
float hw_decomp_ratio
Definition: data.h:545
Definition: data.h:337
size_t Size() const
Definition: data.h:348
Inst operator[](std::size_t i) const
Definition: data.h:343
common::Span< bst_idx_t const > offset
Definition: data.h:340
common::Span< Entry const > data
Definition: data.h:341
Definition: string_view.h:16