xgboost
socket.h
前往此文件文档。
1 
4 #pragma once
5 
6 #include <cerrno> // errno, EINTR, EBADF
7 #include <climits> // HOST_NAME_MAX
8 #include <cstddef> // std::size_t
9 #include <cstdint> // std::int32_t, std::uint16_t
10 #include <cstring> // memset
11 #include <string> // std::string
12 #include <system_error> // std::error_code, std::system_category
13 #include <utility> // std::swap
14 
15 #if defined(__linux__)
16 #include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
17 #endif // defined(__linux__)
18 
19 #if defined(_WIN32)
20 // 守护 include。
21 #include <xgboost/windefs.h>
22 // Socket API
23 #include <winsock2.h>
24 #include <ws2tcpip.h>
25 
26 using in_port_t = std::uint16_t;
27 
28 #ifdef _MSC_VER
29 #pragma comment(lib, "Ws2_32.lib")
30 #endif // _MSC_VER
31 
32 #if !defined(xgboost_IS_MINGW)
33 using ssize_t = int;
34 #endif // !xgboost_IS_MINGW()
35 
36 #else // UNIX
37 
38 #include <arpa/inet.h> // inet_ntop
39 #include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
40 #include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
41 #include <netinet/in.h> // IPPROTO_TCP
42 #include <netinet/tcp.h> // TCP_NODELAY
43 #include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
44 #include <unistd.h> // close
45 
46 #if defined(__sun) || defined(sun)
47 #include <sys/sockio.h>
48 #endif // defined(__sun) || defined(sun)
49 
50 #endif // defined(_WIN32)
51 
52 #include "xgboost/base.h" // XGBOOST_EXPECT
53 #include "xgboost/collective/result.h" // for Result
54 #include "xgboost/logging.h" // LOG
55 #include "xgboost/string_view.h" // StringView
56 
57 #if !defined(HOST_NAME_MAX)
58 #define HOST_NAME_MAX 256 // macos
59 #endif
60 
61 namespace xgboost {
62 
63 #if defined(xgboost_IS_MINGW)
64 // see the dummy implementation of `poll` in rabit for more info.
65 inline void MingWError() { LOG(FATAL) << "mingw不支持分布式训练。"; }
66 #endif // defined(xgboost_IS_MINGW)
67 
68 namespace system {
69 inline std::int32_t LastError() {
70 #if defined(_WIN32)
71  return WSAGetLastError();
72 #else
73  int errsv = errno;
74  return errsv;
75 #endif
76 }
77 
78 [[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
79  return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()});
80 }
81 
82 #if defined(__GLIBC__)
83 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
84  std::int32_t line = __builtin_LINE(),
85  char const *file = __builtin_FILE()) {
86  auto err = std::error_code{errsv, std::system_category()};
87  LOG(FATAL) << "\n"
88  << file << "(" << line << "): 调用 `" << fn_name << "` 失败: " << err.message()
89  << std::endl;
90 }
91 #else
92 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
93  auto err = std::error_code{errsv, std::system_category()};
94  LOG(FATAL) << "调用 `" << fn_name << "` 失败: " << err.message() << std::endl;
95 }
96 #endif // defined(__GLIBC__)
97 
98 #if defined(_WIN32)
99 using SocketT = SOCKET;
100 #else
101 using SocketT = int;
102 #define INVALID_SOCKET -1
103 #endif // defined(_WIN32)
104 
105 #if !defined(xgboost_CHECK_SYS_CALL)
106 #define xgboost_CHECK_SYS_CALL(exp, expected) \
107  do { \
108  if (XGBOOST_EXPECT((exp) != (expected), false)) { \
109  ::xgboost::system::ThrowAtError(#exp); \
110  } \
111  } while (false)
112 #endif // !defined(xgboost_CHECK_SYS_CALL)
113 
114 inline std::int32_t CloseSocket(SocketT fd) {
115 #if defined(_WIN32)
116  return closesocket(fd);
117 #else
118  return close(fd);
119 #endif
120 }
121 
122 inline std::int32_t ShutdownSocket(SocketT fd) {
123 #if defined(_WIN32)
124  auto rc = shutdown(fd, SD_BOTH);
125  if (rc != 0 && LastError() == WSANOTINITIALISED) {
126  return 0;
127  }
128 #else
129  auto rc = shutdown(fd, SHUT_RDWR);
130  if (rc != 0 && LastError() == ENOTCONN) {
131  return 0;
132  }
133 #endif
134  return rc;
135 }
136 
137 inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
138 #ifdef _WIN32
139  return errsv == WSAEWOULDBLOCK;
140 #else
141  return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
142 #endif // _WIN32
143 }
144 
145 inline bool LastErrorWouldBlock() {
146  int errsv = LastError();
147  return ErrorWouldBlock(errsv);
148 }
149 
150 inline void SocketStartup() {
151 #if defined(_WIN32)
152  WSADATA wsa_data;
153  if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
154  ThrowAtError("WSAStartup");
155  }
156  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
157  WSACleanup();
158  LOG(FATAL) << "找不到可用的 Winsock.dll 版本";
159  }
160 #endif // defined(_WIN32)
161 }
162 
163 inline void SocketFinalize() {
164 #if defined(_WIN32)
165  WSACleanup();
166 #endif // defined(_WIN32)
167 }
168 
169 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
170 // 为旧 mysys32 的虚拟定义。
171 inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
172  MingWError();
173  return nullptr;
174 }
175 #else
176 using ::inet_ntop;
177 #endif // defined(_WIN32) && defined(xgboost_IS_MINGW)
178 
179 } // namespace system
180 
181 namespace collective {
182 class SockAddress;
183 
184 enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
185 
190 SockAddress MakeSockAddress(StringView host, in_port_t port);
191 
192 class SockAddrV6 {
193  sockaddr_in6 addr_;
194 
195  public
196  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
197  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
198 
201 
202  in_port_t Port() const { return ntohs(addr_.sin6_port); }
203 
204  std::string Addr() const {
205  char buf[INET6_ADDRSTRLEN];
206  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
207  buf, INET6_ADDRSTRLEN);
208  if (s == nullptr) {
209  system::ThrowAtError("inet_ntop");
210  }
211  return {buf};
212  }
213  sockaddr_in6 const &Handle() const { return addr_; }
214 };
215 
216 class SockAddrV4 {
217  private
218  sockaddr_in addr_;
219 
220  public
221  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
222  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
223 
226 
227  [[nodiscard]] in_port_t Port() const { return ntohs(addr_.sin_port); }
228 
229  [[nodiscard]] std::string Addr() const {
230  char buf[INET_ADDRSTRLEN];
231  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
232  buf, INET_ADDRSTRLEN);
233  if (s == nullptr) {
234  system::ThrowAtError("inet_ntop");
235  }
236  return {buf};
237  }
238  [[nodiscard]] sockaddr_in const &Handle() const { return addr_; }
239 };
240 
244 class SockAddress {
245  private
246  SockAddrV6 v6_;
247  SockAddrV4 v4_;
248  SockDomain domain_{SockDomain::kV4};
249 
250  public
251  SockAddress() = default;
252  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
253  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
254 
255  [[nodiscard]] auto Domain() const { return domain_; }
256 
257  [[nodiscard]] bool IsV4() const { return Domain() == SockDomain::kV4; }
258  [[nodiscard]] bool IsV6() const { return !IsV4(); }
259 
260  [[nodiscard]] auto const &V4() const { return v4_; }
261  [[nodiscard]] auto const &V6() const { return v6_; }
262 };
263 
267 class TCPSocket {
268  public
270 
271  private
272  HandleT handle_{InvalidSocket()};
273  bool non_blocking_{false};
274  // 在 macOS 上,不先绑定套接字就无法可靠地从套接字中提取域。
275  // socket on macos.
276 #if defined(__APPLE__)
277  SockDomain domain_{SockDomain::kV4};
278 #endif
279 
280  constexpr static HandleT InvalidSocket() { return INVALID_SOCKET; }
281 
282  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
283 
284  public
285  TCPSocket() = default;
289  [[nodiscard]] auto Domain() const -> SockDomain {
290  auto ret_iafamily = [](std::int32_t domain) {
291  switch (domain) {
292  case AF_INET
293  return SockDomain::kV4;
294  case AF_INET6
295  return SockDomain::kV6;
296  default: {
297  LOG(FATAL) << "未知 IA 家族。";
298  }
299  }
300  return SockDomain::kV4;
301  };
302 
303 #if defined(_WIN32)
304  WSAPROTOCOL_INFOW info;
305  socklen_t len = sizeof(info);
307  getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
308  0);
309  return ret_iafamily(info.iAddressFamily);
310 #elif defined(__APPLE__)
311  return domain_;
312 #elif defined(__unix__)
313 #ifndef __PASE__
314  std::int32_t domain;
315  socklen_t len = sizeof(domain);
317  getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
318  0);
319  return ret_iafamily(domain);
320 #else
321  struct sockaddr sa;
322  socklen_t sizeofsa = sizeof(sa);
323  xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
324  if (sizeofsa < sizeof(uchar_t) * 2) {
325  return ret_iafamily(AF_INET);
326  }
327  return ret_iafamily(sa.sa_family);
328 #endif // __PASE__
329 #else
330  LOG(FATAL) << "未知平台。";
331  return ret_iafamily(AF_INET);
332 #endif // platforms
333  }
334 
335  [[nodiscard]] bool IsClosed() const { return handle_ == InvalidSocket(); }
336 
338  [[nodiscard]] Result GetSockError() const {
339  std::int32_t optval = 0;
340  socklen_t len = sizeof(optval);
341  auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
342  if (ret != 0) {
343  auto errc = std::error_code{system::LastError(), std::system_category()};
344  return Fail("获取套接字错误失败。", std::move(errc));
345  }
346  if (optval != 0) {
347  auto errc = std::error_code{optval, std::system_category()};
348  return Fail("套接字错误。", std::move(errc));
349  }
350  return Success();
351  }
352 
354  [[nodiscard]] bool BadSocket() const {
355  if (IsClosed()) {
356  return true;
357  }
358  auto err = GetSockError();
359  if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
360  err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
361  return true;
362  }
363  return false;
364  }
365 
366  [[nodiscard]] Result NonBlocking(bool non_block) {
367 #if defined(_WIN32)
368  u_long mode = non_block ? 1 : 0;
369  if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
370  return system::FailWithCode("设置套接字为非阻塞失败。");
371  }
372 #else
373  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
374  auto rc = flag;
375  if (rc == -1) {
376  return system::FailWithCode("获取套接字标志失败。");
377  }
378  if (non_block) {
379  flag |= O_NONBLOCK;
380  } else {
381  flag &= ~O_NONBLOCK;
382  }
383  rc = fcntl(handle_, F_SETFL, flag);
384  if (rc == -1) {
385  return system::FailWithCode("设置套接字为非阻塞失败。");
386  }
387 #endif // _WIN32
388  non_blocking_ = non_block;
389  return Success();
390  }
391  [[nodiscard]] bool NonBlocking() const { return non_blocking_; }
392  [[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
393  // https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout
394 #if defined(_WIN32)
395  DWORD tv = timeout.count() * 1000;
396  auto rc =
397  setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), sizeof(tv));
398 #else
399  struct timeval tv;
400  tv.tv_sec = timeout.count();
401  tv.tv_usec = 0;
402  auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
403  sizeof(tv));
404 #endif
405  if (rc != 0) {
406  return system::FailWithCode("设置接收超时失败。");
407  }
408  return Success();
409  }
410 
411  [[nodiscard]] Result SetBufSize(std::int32_t n_bytes) {
412  auto rc = setsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(&n_bytes),
413  sizeof(n_bytes));
414  if (rc != 0) {
415  return system::FailWithCode("设置发送缓冲区大小失败。");
416  }
417  rc = setsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(&n_bytes),
418  sizeof(n_bytes));
419  if (rc != 0) {
420  return system::FailWithCode("设置接收缓冲区大小失败。");
421  }
422  return Success();
423  }
424 
425  [[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
426  socklen_t optlen;
427  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
428  &optlen);
429  if (rc != 0 || optlen != sizeof(std::int32_t)) {
430  return system::FailWithCode("getsockopt");
431  }
432  return Success();
433  }
434  [[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
435  socklen_t optlen;
436  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
437  &optlen);
438  if (rc != 0 || optlen != sizeof(std::int32_t)) {
439  return system::FailWithCode("getsockopt");
440  }
441  return Success();
442  }
443 #if defined(__linux__)
444  [[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
445  return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success()
446  : system::FailWithCode("ioctl");
447  }
448  [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
449  return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success()
450  : system::FailWithCode("ioctl");
451  }
452 #endif // defined(__linux__)
453 
454  [[nodiscard]] Result SetKeepAlive() {
455  std::int32_t keepalive = 1;
456  auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
457  sizeof(keepalive));
458  if (rc != 0) {
459  return system::FailWithCode("设置 TCP keepalive 失败。");
460  }
461  return Success();
462  }
463 
464  [[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
465  auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
466  sizeof(no_delay));
467  if (rc != 0) {
468  return system::FailWithCode("设置 TCP 无延迟失败。");
469  }
470  return Success();
471  }
472 
477  SockAddress addr;
478  TCPSocket newsock;
479  auto rc = this->Accept(&newsock, &addr);
480  SafeColl(rc);
481  return newsock;
482  }
483 
484  [[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
485 #if defined(_WIN32)
486  auto interrupt = WSAEINTR;
487 #else
488  auto interrupt = EINTR;
489 #endif
490  if (this->Domain() == SockDomain::kV4) {
491  struct sockaddr_in caddr;
492  socklen_t caddr_len = sizeof(caddr);
493  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
494  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
495  return system::FailWithCode("接受失败。");
496  }
497  *addr = SockAddress{SockAddrV4{caddr}};
498  *out = TCPSocket{newfd};
499  } else {
500  struct sockaddr_in6 caddr;
501  socklen_t caddr_len = sizeof(caddr);
502  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
503  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
504  return system::FailWithCode("接受失败。");
505  }
506  *addr = SockAddress{SockAddrV6{caddr}};
507  *out = TCPSocket{newfd};
508  }
509  // 在 MacOS 上,如果父套接字是异步的,则此操作会自动设置为异步套接字
510  // We make sure all socket are blocking by default.
511  //
512  // 在 Windows 上,关闭套接字会在关闭期间返回。我们在设置非阻塞时会对此进行防护。
513  // setting non-blocking.
514  if (!out->IsClosed()) {
515  return out->NonBlocking(false);
516  }
517  return Success();
518  }
519 
521  if (!IsClosed()) {
522  auto rc = this->Close();
523  if (!rc.OK()) {
524  LOG(WARNING) << rc.Report();
525  }
526  }
527  }
528 
529  TCPSocket(TCPSocket const &that) = delete;
530  TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
531  TCPSocket &operator=(TCPSocket const &that) = delete;
532  TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
533  std::swap(this->handle_, that.handle_);
534  return *this;
535  }
539  [[nodiscard]] HandleT const &Handle() const { return handle_; }
545  [[nodiscard]] Result Listen(std::int32_t backlog = 256);
549  [[nodiscard]] Result BindHost(std::int32_t* p_out) {
550  // 为了一致性,使用 int32 而不是 in_port_t。我们从参数中获取端口作为参数
551  // 使用其他语言的用户,端口通常作为整数存储和传递。
552  如果 (() == 套接字域::kV6) {
553  自动 地址 = SockAddrV6::InaddrAny();
554  自动 句柄 = reinterpret_cast<sockaddr const *>(&地址.句柄());
555  如果 (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(地址.句柄())>)) != 0) {
556  返回 系统::FailWithCode("绑定失败。");
557  }
558 
559  sockaddr_in6 res_addr;
560  socklen_t addrlen = sizeof(res_addr);
561  如果 (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
562  返回 系统::FailWithCode("获取套接字名称失败。");
563  }
564  *p_out = ntohs(res_addr.sin6_port);
565  } 否则 {
566  自动 地址 = SockAddrV4::InaddrAny();
567  自动 句柄 = reinterpret_cast<sockaddr const *>(&地址.句柄());
568  如果 (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(地址.句柄())>)) != 0) {
569  返回 系统::FailWithCode("绑定失败。");
570  }
571 
572  sockaddr_in res_addr;
573  socklen_t addrlen = sizeof(res_addr);
574  如果 (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
575  返回 系统::FailWithCode("获取套接字名称失败。");
576  }
577  *p_out = ntohs(res_addr.sin_port);
578  }
579 
580  返回 成功();
581  }
582 
583  [[nodiscard]] auto 端口() const {
584  如果 (this->() == 套接字域::kV4) {
585  sockaddr_in res_addr;
586  socklen_t addrlen = sizeof(res_addr);
587  自动 代码 = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
588  如果 (代码 != 0) {
589  返回 std::make_pair(系统::FailWithCode("获取套接字名称"), std::int32_t{0});
590  }
591  返回 std::make_pair(成功(), std::int32_t{ntohs(res_addr.sin_port)});
592  } 否则 {
593  sockaddr_in6 res_addr;
594  socklen_t addrlen = sizeof(res_addr);
595  自动 代码 = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
596  如果 (代码 != 0) {
597  返回 std::make_pair(系统::FailWithCode("获取套接字名称"), std::int32_t{0});
598  }
599  返回 std::make_pair(成功(), std::int32_t{ntohs(res_addr.sin6_port)});
600  }
601  }
608  [[nodiscard]] 结果 绑定(StringView ip, std::int32_t *端口) {
609  // 将套接字 handle_ 绑定到 ip
610  自动 地址 = MakeSockAddress(ip, *端口);
611  std::int32_t errc{0};
612  如果 (地址.IsV4()) {
613  自动 句柄 = reinterpret_cast<sockaddr const *>(&地址.V4().句柄());
614  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(地址.V4().句柄())>));
615  } 否则 {
616  自动 句柄 = reinterpret_cast<sockaddr const *>(&地址.V6().句柄());
617  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(地址.V6().句柄())>));
618  }
619  如果 (errc != 0) {
620  返回 系统::FailWithCode("绑定套接字失败。");
621  }
622  自动 [rc, new_port] = this->端口();
623  如果 (!rc.OK()) {
624  返回 std::move(rc);
625  }
626  如果 (*端口 == 0) {
627  *端口 = new_port;
628  返回 成功();
629  }
630  如果 (*端口 != new_port) {
631  返回 失败("从绑定获取的端口无效。");
632  }
633  返回 成功();
634  }
635 
639  [[nodiscard]] 结果 SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
640  char const *_buf = reinterpret_cast<const char *>(buf);
641  std::size_t &ndone = *n_sent;
642  ndone = 0;
643  (ndone < len) {
644  ssize_t ret = send(handle_, _buf, len - ndone, 0);
645  如果 (ret == -1) {
646  如果 (系统::LastErrorWouldBlock()) {
647  返回 成功();
648  }
649  返回 系统::FailWithCode("发送");
650  }
651  _buf += ret;
652  ndone += ret;
653  }
654  返回 成功();
655  }
659  [[nodiscard]] 结果 RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
660  char *_buf = reinterpret_cast<char *>(buf);
661  std::size_t &ndone = *n_recv;
662  ndone = 0;
663  (ndone < len) {
664  ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
665  如果 (ret == -1) {
666  如果 (系统::LastErrorWouldBlock()) {
667  返回 成功();
668  }
669  返回 系统::FailWithCode("接收");
670  }
671  如果 (ret == 0) {
672  返回 成功();
673  }
674  _buf += ret;
675  ndone += ret;
676  }
677  返回 成功();
678  }
686  自动 发送(const void *buf_, std::size_t len, std::int32_t flags = 0) {
687  const char *buf = reinterpret_cast<const char *>(buf_);
688  返回 send(handle_, buf, len, flags);
689  }
697  自动 接收(void *buf, std::size_t len, std::int32_t flags = 0) {
698  char *_buf = static_cast<char *>(buf);
699  // 请参阅 https://github.com/llvm/llvm-project/issues/104241 以了解跳过的整洁分析
700  // NOLINTBEGIN(clang-analyzer-unix.BlockInCriticalSection)
701  返回 recv(handle_, _buf, len, flags);
702  // NOLINTEND(clang-analyzer-unix.BlockInCriticalSection)
703  }
707  std::size_t 发送(StringView str);
711  [[nodiscard]] 结果 接收(std::string *p_str);
715  [[nodiscard]] 结果 关闭() {
716  如果 (InvalidSocket() != handle_) {
717  自动 rc = 系统::CloseSocket(handle_);
718 #if defined(_WIN32)
719  // 由于分离线程,我们可能会在完成 WSA 后关闭 TCP 套接字。
720  如果 (rc != 0 && 系统::LastError() != WSANOTINITIALISED) {
721  返回 系统::FailWithCode("关闭套接字失败。");
722  }
723 #else
724  如果 (rc != 0) {
725  返回 系统::FailWithCode("关闭套接字失败。");
726  }
727 #endif
728  handle_ = InvalidSocket();
729  }
730  返回 成功();
731  }
735  [[nodiscard]] 结果 Shutdown() {
736  如果 (this->已关闭()) {
737  返回 成功();
738  }
739  自动 rc = 系统::ShutdownSocket(this->句柄());
740 #if defined(_WIN32)
741  // 如果套接字未连接,Windows 无法关闭套接字。
742  如果 (rc == -1 && 系统::LastError() == WSAENOTCONN) {
743  返回 成功();
744  }
745 #endif
746  如果 (rc != 0) {
747  返回 系统::FailWithCode("关闭套接字失败。");
748  }
749  返回 成功();
750  }
751 
755  static TCPSocket 创建(SockDomain domain) {
756 #if defined(xgboost_IS_MINGW)
757  MingWError();
758  返回 {};
759 #else
760  自动 fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
761  如果 (fd == InvalidSocket()) {
762  系统::ThrowAtError("套接字");
763  }
764 
765  TCPSocket 套接字{fd};
766 #if defined(__APPLE__)
767  socket.domain_ = domain;
768 #endif // defined(__APPLE__)
769  返回 套接字;
770 #endif // defined(xgboost_IS_MINGW)
771  }
772 
774 #if defined(xgboost_IS_MINGW)
775  MingWError();
776  返回 nullptr;
777 #else
778  自动 fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
779  如果 (fd == InvalidSocket()) {
780  系统::ThrowAtError("套接字");
781  }
782  自动 套接字 = new TCPSocket{fd};
783 
784 #if defined(__APPLE__)
785  socket->domain_ = domain;
786 #endif // defined(__APPLE__)
787  返回 套接字;
788 #endif // defined(xgboost_IS_MINGW)
789  }
790 };
791 
804 [[nodiscard]] 结果 连接(xgboost::StringView host, std::int32_t port, std::int32_t retry,
805  std::chrono::seconds timeout,
807 
811 [[nodiscard]] 结果 GetHostName(std::string *p_out);
812 
816 模板 <typename H>
817 结果 INetNToP(H const &host, std::string *p_out) {
818  std::string &ip = *p_out;
819  根据 (host->h_addrtype) {
820  案例 AF_INET: {
821  自动 地址 = reinterpret_cast<struct in_addr *>(host->h_addr_list[0]);
822  char str[INET_ADDRSTRLEN];
823  inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
824  ip = str;
825  中断;
826  }
827  案例 AF_INET6: {
828  自动 地址 = reinterpret_cast<struct in6_addr *>(host->h_addr_list[0]);
829  char str[INET6_ADDRSTRLEN];
830  inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
831  ip = str;
832  中断;
833  }
834  默认: {
835  返回 失败("无效地址类型。");
836  }
837  }
838  返回 成功();
839 }
840 } // 命名空间 collective
841 } // 命名空间 xgboost
842 
843 #undef xgboost_CHECK_SYS_CALL
为 xgboost 定义配置宏和基本类型。
SockAddrV4(sockaddr_in addr)
定义: socket.h:221
static SockAddrV4 InaddrAny()
in_port_t Port() const
定义: socket.h:227
sockaddr_in const & Handle() const
定义: socket.h:238
std::string Addr() const
定义: socket.h:229
static SockAddrV4 Loopback()
SockAddrV4()
定义: socket.h:222
static SockAddrV6 InaddrAny()
SockAddrV6()
定义: socket.h:197
sockaddr_in6 const & Handle() const
定义: socket.h:213
in_port_t Port() const
定义: socket.h:202
SockAddrV6(sockaddr_in6 addr)
定义: socket.h:196
std::string Addr() const
定义: socket.h:204
static SockAddrV6 Loopback()
TCP套接字的地址,可以是IPv4或IPv6。
定义: socket.h:244
bool IsV6() const
定义: socket.h:258
auto const & V6() const
定义: socket.h:261
bool IsV4() const
定义: socket.h:257
auto Domain() const
定义: socket.h:255
SockAddress(SockAddrV4 const &addr)
定义: socket.h:253
auto const & V4() const
定义: socket.h:260
SockAddress(SockAddrV6 const &addr)
定义: socket.h:252
用于简单通信的TCP套接字。
定义: socket.h:267
Result GetSockError() const
获取上次错误代码(如果有)
定义: socket.h:338
Result Recv(std::string *p_str)
接收字符串,格式与RABIT中的Python套接字包装器匹配。
Result RecvTimeout(std::chrono::seconds timeout)
定义: socket.h:392
HandleT const & Handle() const
返回原生套接字文件描述符。
定义: socket.h:539
Result SetNoDelay(std::int32_t no_delay=1)
定义: socket.h:464
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
将套接字绑定到INADDR_ANY,返回操作系统选择的端口。
定义: socket.h:549
Result Listen(std::int32_t backlog=256)
监听传入请求。应在绑定后调用。
static TCPSocket * CreatePtr(SockDomain domain)
定义: socket.h:773
Result Shutdown()
关闭套接字。
定义: socket.h:735
system::SocketT HandleT
定义: socket.h:269
Result Bind(StringView ip, std::int32_t *port)
将套接字绑定到地址。
定义: socket.h:608
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
定义: socket.h:532
auto Port() const
定义: socket.h:583
Result SetKeepAlive()
定义: socket.h:454
Result Accept(TCPSocket *out, SockAddress *addr)
定义: socket.h:484
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
定义: socket.h:434
Result SendBufSize(std::int32_t *n_bytes)
定义: socket.h:425
std::size_t Send(StringView str)
发送字符串,格式与RABIT中的Python套接字包装器匹配。
auto Domain() const -> SockDomain
返回套接字域。
定义: socket.h:289
static TCPSocket Create(SockDomain domain)
在指定域上创建TCP套接字。
定义: socket.h:755
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
接收数据,无错误则应接收所有数据。
定义: socket.h:659
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
使用套接字接收数据
定义: socket.h:697
bool NonBlocking() const
定义: socket.h:391
bool IsClosed() const
定义: socket.h:335
Result NonBlocking(bool non_block)
定义: socket.h:366
TCPSocket Accept()
接受新连接,为新连接返回一个新的TCP套接字。
定义: socket.h:476
~TCPSocket()
定义: socket.h:520
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
使用套接字发送数据。
定义: socket.h:686
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
发送数据,无错误则应发送所有数据。
定义: socket.h:639
bool BadSocket() const
检查是否发生异常
定义: socket.h:354
TCPSocket(TCPSocket &&that) noexcept(true)
定义: socket.h:530
Result SetBufSize(std::int32_t n_bytes)
定义: socket.h:411
Result Close()
关闭套接字,如果套接字未关闭,则在析构函数中自动调用。
定义: socket.h:715
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
定义: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
连接到远程地址,如果失败则返回错误代码。
Result GetHostName(std::string *p_out)
获取本地主机名。
SockAddress MakeSockAddress(StringView host, in_port_t port)
解析主机地址并返回SockAddress实例。支持IPv4和IPv6主机。
SockDomain
定义: socket.h:184
void SafeColl(Result const &rc, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
返回失败。
定义: result.h:124
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
定义: socket.h:817
auto Success() noexcept(true)
返回成功。
定义: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
定义: socket.h:137
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
定义: socket.h:92
void SocketStartup()
定义: socket.h:150
std::int32_t CloseSocket(SocketT fd)
定义: socket.h:114
bool LastErrorWouldBlock()
定义: socket.h:145
std::int32_t LastError()
定义: socket.h:69
void SocketFinalize()
定义: socket.h:163
std::int32_t ShutdownSocket(SocketT fd)
定义: socket.h:122
collective::Result FailWithCode(std::string msg)
定义: socket.h:78
int SocketT
定义: socket.h:101
集成目标、gbm和评估的学习器接口。这是用户面临的XGB...
Definition: base.h:97
int SOCKET
定义: poll_utils.h:40
#define __builtin_LINE()
定义: result.h:57
#define __builtin_FILE()
定义: result.h:56
#define INVALID_SOCKET
定义: socket.h:102
#define xgboost_CHECK_SYS_CALL(exp, expected)
定义: socket.h:106
Definition: string_view.h:16
一种比抛出dmlc异常更容易处理的错误类型。我们可以记录并传播s...
定义: result.h:67