xgboost
json_io.h
查看此文件的文档。
1 
4 #ifndef XGBOOST_JSON_IO_H_
5 #define XGBOOST_JSON_IO_H_
6 #include <dmlc/endian.h>
7 #include <xgboost/base.h>
8 #include <xgboost/json.h>
9 
10 #include <cstdint> // 用于 int8_t
11 #include <limits>
12 #include <string>
13 #include <utility>
14 #include <vector>
15 
16 namespace xgboost {
20 class JsonReader {
21  public
22  using Char = std::int8_t;
23 
24  protected
25  size_t constexpr static kMaxNumLength = std::numeric_limits<double>::max_digits10 + 1;
26 
27  struct SourceLocation {
28  private
29  std::size_t pos_{0}; // current position in raw_str_
30 
31  public
32  SourceLocation() = default;
33  size_t Pos() const { return pos_; }
34 
35  void Forward() { pos_++; }
36  void Forward(uint32_t n) { pos_ += n; }
38 
40 
41  protected
42  void SkipSpaces();
43 
45  if (XGBOOST_EXPECT((cursor_.Pos() == raw_str_.size()), false)) {
46  return -1;
47  }
48  char ch = raw_str_[cursor_.Pos()];
49  cursor_.Forward();
50  return ch;
51  }
52 
54  if (cursor_.Pos() == raw_str_.size()) {
55  return -1;
56  }
57  Char ch = raw_str_[cursor_.Pos()];
58  return ch;
59  }
60 
61  /* \brief 跳过空格并消耗下一个字符。*/
63  SkipSpaces();
64  return GetNextChar();
65  }
66  /* \brief 在不先跳过空字符的情况下消耗下一个字符,当下一个*/
67  * 字符不是预期的字符时抛出异常。*/
68  */
69  Char GetConsecutiveChar(char expected_char) {
70  Char result = GetNextChar();
71  if (XGBOOST_EXPECT(result != expected_char, false)) { Expect(expected_char, result); }
72  return result;
73  }
74 
75  void Error(std::string msg) const;
76 
77  // 报告预期的字符
78  void Expect(Char c, Char got) {
79  std::string msg = "期望: \"";
80  msg += c;
81  msg += "\", 得到: \"";
82  if (got == EOF) {
83  msg += "EOF\"";
84  } else if (got == 0) {
85  msg += "\\0\"";
86  } else {
87  msg += std::to_string(got) + " \"";
88  }
89  Error(msg);
90  }
91 
92  virtual Json ParseString();
93  virtual Json ParseObject();
94  virtual Json ParseArray();
95  virtual Json ParseNumber();
96  virtual Json ParseBoolean();
97  virtual Json ParseNull();
98 
100 
101  public
102  explicit JsonReader(StringView str)
103  raw_str_{str} {}
104 
105  virtual ~JsonReader() = default;
106 
107  virtual Json Load();
108 };
109 
110 class JsonWriter {
111  template <typename T, std::enable_if_t<!std::is_same_v<Json, T>>* = nullptr>
112  void Save(T const& v) {
113  this->Save(Json{v});
114  }
115  template <typename Array, typename Fn>
116  void WriteArray(Array const* arr, Fn&& fn) {
117  stream_->emplace_back('[');
118  auto const& vec = arr->GetArray();
119  size_t size = vec.size();
120  for (size_t i = 0; i < size; ++i) {
121  auto const& value = vec[i];
122  this->Save(fn(value));
123  if (i != size - 1) {
124  stream_->emplace_back(',');
125  }
126  }
127  stream_->emplace_back(']');
128  }
129 
130  protected
131  std::vector<char>* stream_;
132 
133  public
134  explicit JsonWriter(std::vector<char>* stream) : stream_{stream} {}
135 
136  virtual ~JsonWriter() = default;
137 
138  virtual void Save(Json json);
139 
140  virtual void Visit(JsonArray const* arr);
141  virtual void Visit(F32Array const* arr);
142  virtual void Visit(F64Array const*) { LOG(FATAL) << "只有 UBJSON 格式才能处理 f64 数组。"; }
143  virtual void Visit(I8Array const* arr);
144  virtual void Visit(U8Array const* arr);
145  virtual void Visit(I16Array const* arr);
146  virtual void Visit(I32Array const* arr);
147  virtual void Visit(I64Array const* arr);
148  virtual void Visit(JsonObject const* obj);
149  virtual void Visit(JsonNumber const* num);
150  virtual void Visit(JsonInteger const* num);
151  virtual void Visit(JsonNull const* null);
152  virtual void Visit(JsonString const* str);
153  virtual void Visit(JsonBoolean const* boolean);
154 };
155 
156 #if defined(__GLIBC__)
157 template <typename T>
158 T BuiltinBSwap(T v);
159 
160 template <>
161 inline uint16_t BuiltinBSwap(uint16_t v) {
162  return __builtin_bswap16(v);
163 }
164 
165 template <>
166 inline uint32_t BuiltinBSwap(uint32_t v) {
167  return __builtin_bswap32(v);
168 }
169 
170 template <>
171 inline uint64_t BuiltinBSwap(uint64_t v) {
172  return __builtin_bswap64(v);
173 }
174 #else
175 template <typename T>
176 T BuiltinBSwap(T v) {
177  dmlc::ByteSwap(&v, sizeof(v), 1);
178  return v;
179 }
180 #endif // defined(__GLIBC__)
181 
182 template <typename T, std::enable_if_t<sizeof(T) == 1>* = nullptr>
183 inline T ToBigEndian(T v) {
184  return v;
185 }
186 
187 template <typename T, std::enable_if_t<sizeof(T) != 1>* = nullptr>
188 inline T ToBigEndian(T v) {
189  static_assert(std::is_pod<T>::value, "仅支持 POD 类型。");
190 #if DMLC_LITTLE_ENDIAN
191  auto constexpr kS = sizeof(T);
192  std::conditional_t<kS == 2, uint16_t, std::conditional_t<kS == 4, uint32_t, uint64_t>> u;
193  std::memcpy(&u, &v, sizeof(u));
194  u = BuiltinBSwap(u);
195  std::memcpy(&v, &u, sizeof(u));
196 #endif // DMLC_LITTLE_ENDIAN
197  return v;
198 }
199 
203 class UBJReader : public JsonReader {
204  Json Parse();
205 
206  template <typename T>
207  T ReadStream() {
208  auto ptr = this->raw_str_.c_str() + cursor_.Pos();
209  T v{0};
210  std::memcpy(&v, ptr, sizeof(v));
211  cursor_.Forward(sizeof(v));
212  return v;
213  }
214 
215  template <typename T>
216  T ReadPrimitive() {
217  auto v = ReadStream<T>();
218  v = ToBigEndian(v);
219  return v;
220  }
221 
222  template <typename TypedArray>
223  auto ParseTypedArray(std::int64_t n) {
224  TypedArray results{static_cast<size_t>(n)};
225  for (int64_t i = 0; i < n; ++i) {
226  auto v = this->ReadPrimitive<typename TypedArray::Type>();
227  results.Set(i, v);
228  }
229  return Json{std::move(results)};
230  }
231 
232  std::string DecodeStr();
233 
234  Json ParseArray() override;
235  Json ParseObject() override;
236 
237  public
239  Json Load() override;
240 };
241 
245 class UBJWriter : public JsonWriter {
246  void Visit(JsonArray const* arr) override;
247  void Visit(F32Array const* arr) override;
248  void Visit(F64Array const* arr) override;
249  void Visit(I8Array const* arr) override;
250  void Visit(U8Array const* arr) override;
251  void Visit(I16Array const* arr) override;
252  void Visit(I32Array const* arr) override;
253  void Visit(I64Array const* arr) override;
254  void Visit(JsonObject const* obj) override;
255  void Visit(JsonNumber const* num) override;
256  void Visit(JsonInteger const* num) override;
257  void Visit(JsonNull const* null) override;
258  void Visit(JsonString const* str) override;
259  void Visit(JsonBoolean const* boolean) override;
260 
261  public
263  void Save(Json json) override;
264 };
265 } // namespace xgboost
266 
267 #endif // XGBOOST_JSON_IO_H_
为 xgboost 定义配置宏和基本类型。
#define XGBOOST_EXPECT(cond, ret)
定义: base.h:55
定义: json.h:114
std::vector< Json > const & GetArray() &&
定义: json.h:132
描述真值和假值。
定义: json.h:336
定义: json.h:281
定义: json.h:319
定义: json.h:243
定义: json.h:205
一个 JSON 读取器,当前错误检查和 UTF-8 支持不完善。
定义: json_io.h:20
virtual Json ParseString()
virtual Json Load()
virtual Json ParseNumber()
virtual Json ParseArray()
StringView raw_str_
定义: json_io.h:39
Char PeekNextChar()
定义: json_io.h:53
struct xgboost::JsonReader::SourceLocation cursor_
virtual Json ParseBoolean()
Char GetNextChar()
定义: json_io.h:44
void Expect(Char c, Char got)
定义: json_io.h:78
virtual Json ParseObject()
Char GetNextNonSpaceChar()
定义: json_io.h:62
Char GetConsecutiveChar(char expected_char)
定义: json_io.h:69
std::int8_t Char
定义: json_io.h:22
void Error(std::string msg) const
virtual ~JsonReader()=default
constexpr static size_t kMaxNumLength
定义: json_io.h:25
JsonReader(StringView str)
定义: json_io.h:102
virtual Json ParseNull()
定义: json.h:87
用于 Universal Binary JSON 的类型化数组。
定义: json.h:151
定义: json_io.h:110
std::vector< char > * stream_
定义: json_io.h:131
virtual void Save(Json json)
virtual void Visit(F64Array const *)
定义: json_io.h:142
JsonWriter(std::vector< char > *stream)
定义: json_io.h:134
virtual void Visit(JsonInteger const *num)
virtual void Visit(JsonNull const *null)
virtual void Visit(U8Array const *arr)
virtual void Visit(JsonArray const *arr)
virtual void Visit(I8Array const *arr)
virtual void Visit(F32Array const *arr)
virtual void Visit(JsonNumber const *num)
virtual ~JsonWriter()=default
virtual void Visit(I64Array const *arr)
virtual void Visit(JsonBoolean const *boolean)
virtual void Visit(I32Array const *arr)
virtual void Visit(JsonObject const *obj)
virtual void Visit(I16Array const *arr)
virtual void Visit(JsonString const *str)
表示 JSON 格式的数据结构。
定义: json.h:378
UBJSON 读取器 https://ubjson.org/。
定义: json_io.h:203
Json Load() override
UBJSON 写入器 https://ubjson.org/。
定义: json_io.h:245
void Save(Json json) override
Core data structure for multi-target trees.
定义: base.h:89
T BuiltinBSwap(T v)
定义: json_io.h:176
T ToBigEndian(T v)
定义: json_io.h:183
size_t Pos() const
定义: json_io.h:33
void Forward()
定义: json_io.h:35
void Forward(uint32_t n)
定义: json_io.h:36
定义: string_view.h:16
value_type const * c_str() const
定义: string_view.h:50
constexpr std::size_t size() const
定义: string_view.h:43