xgboost
socket.h
转到此文件的文档。
1 
4 #pragma once
5 
6 #include <cerrno> // errno, EINTR, EBADF
7 #include <climits> // HOST_NAME_MAX
8 #include <cstddef> // std::size_t
9 #include <cstdint> // std::int32_t, std::uint16_t
10 #include <cstring> // memset
11 #include <string> // std::string
12 #include <system_error> // std::error_code, std::system_category
13 #include <utility> // std::swap
14 
15 #if defined(__linux__)
16 #include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
17 #endif // defined(__linux__)
18 
19 #if defined(_WIN32)
20 // 保护头文件包含。
21 #include <xgboost/windefs.h>
22 // 套接字 API
23 #include <winsock2.h>
24 #include <ws2tcpip.h>
25 
26 using in_port_t = std::uint16_t;
27 
28 #ifdef _MSC_VER
29 #pragma comment(lib, "Ws2_32.lib")
30 #endif // _MSC_VER
31 
32 #if !defined(xgboost_IS_MINGW)
33 using ssize_t = int;
34 #endif // !xgboost_IS_MINGW()
35 
36 #else // UNIX
37 
38 #include <arpa/inet.h> // inet_ntop
39 #include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
40 #include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
41 #include <netinet/in.h> // IPPROTO_TCP
42 #include <netinet/tcp.h> // TCP_NODELAY
43 #include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
44 #include <unistd.h> // close
45 
46 #if defined(__sun) || defined(sun)
47 #include <sys/sockio.h>
48 #endif // defined(__sun) || defined(sun)
49 
50 #endif // defined(_WIN32)
51 
52 #include "xgboost/base.h" // XGBOOST_EXPECT
53 #include "xgboost/collective/result.h" // for Result
54 #include "xgboost/logging.h" // LOG
55 #include "xgboost/string_view.h" // StringView
56 
57 #if !defined(HOST_NAME_MAX)
58 #define HOST_NAME_MAX 256 // macos
59 #endif
60 
61 namespace xgboost {
62 
#if defined(xgboost_IS_MINGW)
// Distributed training is not supported on MinGW; see the mock `poll` implementation in
// rabit for more information.
/// Abort through LOG(FATAL) with a message stating that MinGW cannot run distributed training.
inline void MingWError() {
  LOG(FATAL) << "mingw 不支持分布式训练。";
}
#endif  // defined(xgboost_IS_MINGW)
67 
68 namespace system {
/**
 * @brief Error code left by the most recent failed system or socket call.
 *
 * Reads `errno` on POSIX systems and `WSAGetLastError()` on Windows.
 */
inline std::int32_t LastError() {
#if !defined(_WIN32)
  return errno;
#else
  return WSAGetLastError();
#endif  // !defined(_WIN32)
}
77 
78 [[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
79  return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()});
80 }
81 
82 #if defined(__GLIBC__)
83 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
84  std::int32_t line = __builtin_LINE(),
85  char const *file = __builtin_FILE()) {
86  auto err = std::error_code{errsv, std::system_category()};
87  LOG(FATAL) << "\n"
88  << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
89  << std::endl;
90 }
91 #else
92 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
93  auto err = std::error_code{errsv, std::system_category()};
94  LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
95 }
96 #endif // defined(__GLIBC__)
97 
#if !defined(_WIN32)
// On POSIX systems a socket is a plain file descriptor; -1 is the invalid value.
using SocketT = int;
#define INVALID_SOCKET -1
#else
// Winsock provides its own SOCKET handle type and INVALID_SOCKET constant.
using SocketT = SOCKET;
#endif  // !defined(_WIN32)
104 
#ifndef xgboost_CHECK_SYS_CALL
// Evaluate the system call `exp` once; if its result differs from `expected`, report the
// failure (with the stringified call) through ::xgboost::system::ThrowAtError.
#define xgboost_CHECK_SYS_CALL(exp, expected)                      \
  do {                                                             \
    bool const xgboost_sys_call_failed_ = ((exp) != (expected));   \
    if (XGBOOST_EXPECT(xgboost_sys_call_failed_, false)) {         \
      ::xgboost::system::ThrowAtError(#exp);                       \
    }                                                              \
  } while (false)
#endif  // xgboost_CHECK_SYS_CALL
113 
114 inline std::int32_t CloseSocket(SocketT fd) {
115 #if defined(_WIN32)
116  return closesocket(fd);
117 #else
118  return close(fd);
119 #endif
120 }
121 
122 inline std::int32_t ShutdownSocket(SocketT fd) {
123 #if defined(_WIN32)
124  auto rc = shutdown(fd, SD_BOTH);
125  if (rc != 0 && LastError() == WSANOTINITIALISED) {
126  return 0;
127  }
128 #else
129  auto rc = shutdown(fd, SHUT_RDWR);
130  if (rc != 0 && LastError() == ENOTCONN) {
131  return 0;
132  }
133 #endif
134  return rc;
135 }
136 
/// Whether @p errsv means a non-blocking socket operation could not complete immediately.
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
  return errsv == WSAEWOULDBLOCK;
#else
  // EAGAIN and EWOULDBLOCK may share a value, so compare explicitly instead of switching.
  bool const retry_later = (errsv == EAGAIN) || (errsv == EWOULDBLOCK);
  return retry_later || errsv == EINPROGRESS;
#endif  // _WIN32
}
144 
145 inline bool LastErrorWouldBlock() {
146  int errsv = LastError();
147  return ErrorWouldBlock(errsv);
148 }
149 
/**
 * @brief Initialise the platform socket library; a no-op outside Windows.
 *
 * On Windows this requests Winsock 2.2. A failure of WSAStartup, or a loaded version other
 * than 2.2, is reported fatally.
 */
inline void SocketStartup() {
#if defined(_WIN32)
  WSADATA wsa_data;
  // WSAStartup returns 0 on success and the error code itself on failure (it does not use
  // -1/WSAGetLastError), so test for non-zero and forward the returned code.
  auto const rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (rc != 0) {
    ThrowAtError("WSAStartup", rc);
  }
  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
    WSACleanup();
    LOG(FATAL) << "Could not find a usable version of Winsock.dll";
  }
#endif  // defined(_WIN32)
}
162 
/// Release the Winsock library (WSACleanup) on Windows; nothing to do on other platforms.
inline void SocketFinalize() {
#if !defined(_WIN32)
  // POSIX sockets need no library-level teardown.
#else
  WSACleanup();
#endif  // !defined(_WIN32)
}
168 
#if !(defined(_WIN32) && defined(xgboost_IS_MINGW))
// Expose the platform inet_ntop as system::inet_ntop.
using ::inet_ntop;
#else
// Stand-in for older msys32 toolchains: reports MingWError() and returns nullptr.
inline const char *inet_ntop(int, const void *, char *, socklen_t) {  // NOLINT
  MingWError();
  return nullptr;
}
#endif  // !(defined(_WIN32) && defined(xgboost_IS_MINGW))
178 
179 } // namespace system
180 
181 namespace collective {
182 class SockAddress;
183 
/// Socket address family; enumerator values equal the native AF_* constants.
enum class SockDomain : std::int32_t {
  kV4 = AF_INET,   // IPv4
  kV6 = AF_INET6,  // IPv6
};
185 
190 SockAddress MakeSockAddress(StringView host, in_port_t port);
191 
192 class SockAddrV6 {
193  sockaddr_in6 addr_;
194 
195  public
196  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
197  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
198 
201 
202  in_port_t Port() const { return ntohs(addr_.sin6_port); }
203 
204  std::string Addr() const {
205  char buf[INET6_ADDRSTRLEN];
206  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
207  buf, INET6_ADDRSTRLEN);
208  if (s == nullptr) {
209  system::ThrowAtError("inet_ntop");
210  }
211  return {buf};
212  }
213  sockaddr_in6 const &Handle() const { return addr_; }
214 };
215 
216 class SockAddrV4 {
217  private
218  sockaddr_in addr_;
219 
220  public
221  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
222  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
223 
226 
227  [[nodiscard]] in_port_t Port() const { return ntohs(addr_.sin_port); }
228 
229  [[nodiscard]] std::string Addr() const {
230  char buf[INET_ADDRSTRLEN];
231  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
232  buf, INET_ADDRSTRLEN);
233  if (s == nullptr) {
234  system::ThrowAtError("inet_ntop");
235  }
236  return {buf};
237  }
238  [[nodiscard]] sockaddr_in const &Handle() const { return addr_; }
239 };
240 
244 /// 类 SockAddress;
245  private
246  SockAddrV6 v6_;
247  SockAddrV4 v4_;
248  SockDomain domain_{SockDomain::kV4};
249 
250  public
251  SockAddress() = default;
252  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
253  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
254 
255  [[nodiscard]] auto Domain() const { return domain_; }
256 
257  [[nodiscard]] bool IsV4() const { return Domain() == SockDomain::kV4; }
258  [[nodiscard]] bool IsV6() const { return !IsV4(); }
259 
260  [[nodiscard]] auto const &V4() const { return v4_; }
261  [[nodiscard]] auto const &V6() const { return v6_; }
262 };
263 
267 /// TCP 套接字封装器。
268  public
270 
271  private
272  HandleT handle_{InvalidSocket()};
273  bool non_blocking_{false};
274  // 在 macOS 上,无法可靠地在不先绑定套接字的情况下从套接字中提取域。
275  //
276 #if defined(__APPLE__)
277  SockDomain domain_{SockDomain::kV4};
278 #endif
279 
280  constexpr static HandleT InvalidSocket() { return INVALID_SOCKET; }
281 
282  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
283 
284  public
/// Construct a socket that owns no handle: its handle is InvalidSocket(), so IsClosed()
/// returns true until a handle is assigned (e.g. via move assignment).
TCPSocket() = default;
289  [[nodiscard]] auto Domain() const -> SockDomain {
290  auto ret_iafamily = [](std::int32_t domain) {
291  switch (domain) {
292  case AF_INET
293  return SockDomain::kV4;
294  case AF_INET6
295  return SockDomain::kV6;
296  default: {
297  LOG(FATAL) << "Unknown IA family.";
298  }
299  }
300  return SockDomain::kV4;
301  };
302 
303 #if defined(_WIN32)
304  WSAPROTOCOL_INFOA info;
305  socklen_t len = sizeof(info);
307  getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
308  0);
309  return ret_iafamily(info.iAddressFamily);
310 #elif defined(__APPLE__)
311  return domain_;
312 #elif defined(__unix__)
313 #ifndef __PASE__
314  std::int32_t domain;
315  socklen_t len = sizeof(domain);
317  getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
318  0);
319  return ret_iafamily(domain);
320 #else
321  struct sockaddr sa;
322  socklen_t sizeofsa = sizeof(sa);
323  xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
324  if (sizeofsa < sizeof(uchar_t) * 2) {
325  return ret_iafamily(AF_INET);
326  }
327  return ret_iafamily(sa.sa_family);
328 #endif // __PASE__
329 #else
330  LOG(FATAL) << "Unknown platform.";
331  return ret_iafamily(AF_INET);
332 #endif // platforms
333  }
334 
335 /// 检查套接字是否已关闭。
336  [[nodiscard]] bool IsClosed() const { return handle_ == InvalidSocket(); }
338 /// 获取套接字错误码。
339  [[nodiscard]] Result GetSockError() const {
340  std::int32_t optval = 0;
341  socklen_t len = sizeof(optval);
342  auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
343  if (ret != 0) {
344  auto errc = std::error_code{system::LastError(), std::system_category()};
345  return Fail("Failed to retrieve socket error.", std::move(errc));
346  }
347  if (optval != 0) {
348  auto errc = std::error_code{optval, std::system_category()};
349  return Fail("Socket error.", std::move(errc));
350  }
351  return Success();
352  }
354 /// 检查套接字是否坏或已关闭。
355  [[nodiscard]] bool BadSocket() const {
356  if (IsClosed()) {
357  return true;
358  }
359  auto err = GetSockError();
360  if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
361  err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
362  return true;
363  }
364  return false;
365  }
366 /// 设置非阻塞模式。
367  [[nodiscard]] Result NonBlocking(bool non_block) {
368 #if defined(_WIN32)
369  u_long mode = non_block ? 1 : 0;
370  if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
371  return system::FailWithCode("Failed to set socket to non-blocking.");
372  }
373 #else
374  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
375  auto rc = flag;
376  if (rc == -1) {
377  return system::FailWithCode("Failed to get socket flag.");
378  }
379  if (non_block) {
380  flag |= O_NONBLOCK;
381  } else {
382  flag &= ~O_NONBLOCK;
383  }
384  rc = fcntl(handle_, F_SETFL, flag);
385  if (rc == -1) {
386  return system::FailWithCode("Failed to set socket to non-blocking.");
387  }
388 #endif // _WIN32
389  non_blocking_ = non_block;
390  return Success();
391  [[nodiscard]] bool NonBlocking() const { return non_blocking_; }
392  [[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
393  // https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout
394 #if defined(_WIN32)
395  DWORD tv = timeout.count() * 1000;
396  auto rc =
397  setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), sizeof(tv));
398 #else
399  struct timeval tv;
400  tv.tv_sec = timeout.count();
401  tv.tv_usec = 0;
402  auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
403  sizeof(tv));
404 #endif
405  if (rc != 0) {
406  return system::FailWithCode("Failed to set timeout on recv.");
407  }
408  return Success();
409  }
410 
411  [[nodiscard]] Result SetBufSize(std::int32_t n_bytes) {
412  auto rc = setsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(&n_bytes),
413  sizeof(n_bytes));
414  if (rc != 0) {
415  return system::FailWithCode("Failed to set send buffer size.");
416  }
417  rc = setsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(&n_bytes),
418  sizeof(n_bytes));
419  if (rc != 0) {
420  return system::FailWithCode("Failed to set recv buffer size.");
421  }
422  return Success();
423  }
424 
425  [[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
426  socklen_t optlen;
427  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
428  &optlen);
429  if (rc != 0 || optlen != sizeof(std::int32_t)) {
430  return system::FailWithCode("getsockopt");
431  }
432  return Success();
433  }
434  [[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
435  socklen_t optlen;
436  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
437  &optlen);
438  if (rc != 0 || optlen != sizeof(std::int32_t)) {
439  return system::FailWithCode("getsockopt");
440  }
441  return Success();
442  }
443 #if defined(__linux__)
444  [[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
445  return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success()
446  : system::FailWithCode("ioctl");
447  }
448  [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
449  return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success()
450  : system::FailWithCode("ioctl");
451  }
452 #endif // defined(__linux__)
453 
454  [[nodiscard]] Result SetKeepAlive() {
455  std::int32_t keepalive = 1;
456  auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
457  sizeof(keepalive));
458  if (rc != 0) {
459  return system::FailWithCode("Failed to set TCP keeaplive.");
460  }
461  return Success();
462  }
463 
464  [[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
465  auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
466  sizeof(no_delay));
467  if (rc != 0) {
468  return system::FailWithCode("Failed to set TCP no delay.");
469  }
470  return Success();
471  }
472 
476 /// 接受连接并返回新创建的套接字。
477  TCPSocket Accept() {
478  SockAddress addr;
479  TCPSocket newsock;
480  auto rc = this->Accept(&newsock, &addr);
481  SafeColl(rc);
482  return newsock;
483  }
484  [[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
485 #if defined(_WIN32)
486  auto interrupt = WSAEINTR;
487 #else
488  auto interrupt = EINTR;
489 #endif
490  if (this->Domain() == SockDomain::kV4) {
491  struct sockaddr_in caddr;
492  socklen_t caddr_len = sizeof(caddr);
493  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
494  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
495  return system::FailWithCode("Failed to accept.");
496  }
497  *addr = SockAddress{SockAddrV4{caddr}};
498  *out = TCPSocket{newfd};
499  } else {
500  struct sockaddr_in6 caddr;
501  socklen_t caddr_len = sizeof(caddr);
502  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
503  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
504  return system::FailWithCode("Failed to accept.");
505  }
506  *addr = SockAddress{SockAddrV6{caddr}};
507  *out = TCPSocket{newfd};
508  }
509  // 在 MacOS 上,如果父套接字是异步的,则此项会自动设置为异步套接字
510  // 我们确保所有套接字默认是阻塞的。
511  //
512  // 在 Windows 上,关闭时会返回一个关闭的套接字。我们在设置非阻塞时会对此进行防护。
513  //
514  if (!out->IsClosed()) {
515  return out->NonBlocking(false);
516  }
517  return Success();
518  }
519 
521  if (!IsClosed()) {
522  auto rc = this->Close();
523  if (!rc.OK()) {
524  LOG(WARNING) << rc.Report();
525  }
526  }
527  }
528 
529  TCPSocket(TCPSocket const &that) = delete;
530  TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
531  TCPSocket &operator=(TCPSocket const &that) = delete;
532  TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
533  std::swap(this->handle_, that.handle_);
534  return *this;
535  }
539 /// 获取套接字句柄。
545 /// 在套接字上监听。
549 /// 绑定到主机并返回绑定的端口号。
550  // 为了保持一致性,使用 int32 而不是 in_port_t。我们从使用其他语言的用户那里获取端口作为参数,
551  // 端口通常存储并作为 int 传递。
552  if (Domain() == SockDomain::kV6) {
553  auto addr = SockAddrV6::InaddrAny();
554  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
555  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
556  return system::FailWithCode("bind failed.");
557  }
558 
559  sockaddr_in6 res_addr;
560  socklen_t addrlen = sizeof(res_addr);
561  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
562  return system::FailWithCode("getsockname failed.");
563  }
564  *p_out = ntohs(res_addr.sin6_port);
565  } else {
566  auto addr = SockAddrV4::InaddrAny();
567  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
568  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
569  return system::FailWithCode("bind failed.");
570  }
571 
572  sockaddr_in res_addr;
573  socklen_t addrlen = sizeof(res_addr);
574  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
575  return system::FailWithCode("getsockname failed.");
576  }
577  *p_out = ntohs(res_addr.sin_port);
578  }
579 
580  return Success();
581  }
582 
583  [[nodiscard]] auto Port() const {
584  if (this->Domain() == SockDomain::kV4) {
585  sockaddr_in res_addr;
586  socklen_t addrlen = sizeof(res_addr);
587  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
588  if (code != 0) {
589  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
590  }
591  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin_port)});
592  } else {
593  sockaddr_in6 res_addr;
594  socklen_t addrlen = sizeof(res_addr);
595  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
596  if (code != 0) {
597  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
598  }
599  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin6_port)});
600  }
601  }
608  [[nodiscard]] Result Bind(StringView ip, std::int32_t *port) {
609  // 将 socket handle_ 绑定到 ip
610  auto addr = MakeSockAddress(ip, *port);
611  std::int32_t errc{0};
612  if (addr.IsV4()) {
613  auto handle = reinterpret_cast<sockaddr const *>(&addr.V4().Handle());
614  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V4().Handle())>));
615  } else {
616  auto handle = reinterpret_cast<sockaddr const *>(&addr.V6().Handle());
617  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V6().Handle())>));
618  }
619  if (errc != 0) {
620  return system::FailWithCode("Failed to bind socket.");
621  }
622  auto [rc, new_port] = this->Port();
623  if (!rc.OK()) {
624  return std::move(rc);
625  }
626  if (*port == 0) {
627  *port = new_port;
628  return Success();
629  }
630  if (*port != new_port) {
631  return Fail("Got an invalid port from bind.");
632  }
633  return Success();
634  }
635 
639  [[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
640  char const *_buf = reinterpret_cast<const char *>(buf);
641  std::size_t &ndone = *n_sent;
642  ndone = 0;
643  while (ndone < len) {
644  ssize_t ret = send(handle_, _buf, len - ndone, 0);
645  if (ret == -1) {
647  return Success();
648  }
649  return system::FailWithCode("send");
650  }
651  _buf += ret;
652  ndone += ret;
653  }
654  return Success();
655  }
659  [[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
660  char *_buf = reinterpret_cast<char *>(buf);
661  std::size_t &ndone = *n_recv;
662  ndone = 0;
663  while (ndone < len) {
664  ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
665  if (ret == -1) {
667  return Success();
668  }
669  return system::FailWithCode("recv");
670  }
671  if (ret == 0) {
672  return Success();
673  }
674  _buf += ret;
675  ndone += ret;
676  }
677  return Success();
678  }
686  auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
687  const char *buf = reinterpret_cast<const char *>(buf_);
688  return send(handle_, buf, len, flags);
689  }
697  auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
698  char *_buf = static_cast<char *>(buf);
699  // 有关跳过的 tidy 分析,请参阅 https://github.com/llvm/llvm-project/issues/104241
700  // NOLINTBEGIN(clang-analyzer-unix.BlockInCriticalSection)
701  return recv(handle_, _buf, len, flags);
702  // NOLINTEND(clang-analyzer-unix.BlockInCriticalSection)
703  }
707  std::size_t Send(StringView str);
711  [[nodiscard]] Result Recv(std::string *p_str);
715  [[nodiscard]] Result Close() {
716  if (InvalidSocket() != handle_) {
717  auto rc = system::CloseSocket(handle_);
718 #if defined(_WIN32)
719  // 可能是由于分离线程,我们在完成 WSA 后关闭了 TCP socket。
720  if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
721  return system::FailWithCode("Failed to close the socket.");
722  }
723 #else
724  if (rc != 0) {
725  return system::FailWithCode("Failed to close the socket.");
726  }
727 #endif
728  handle_ = InvalidSocket();
729  }
730  return Success();
731  }
735  [[nodiscard]] Result Shutdown() {
736  if (this->IsClosed()) {
737  return Success();
738  }
739  auto rc = system::ShutdownSocket(this->Handle());
740 #if defined(_WIN32)
741  // 如果 socket 未连接,Windows 无法关闭它。
742  if (rc == -1 && system::LastError() == WSAENOTCONN) {
743  return Success();
744  }
745 #endif
746  if (rc != 0) {
747  return system::FailWithCode("Failed to shutdown socket.");
748  }
749  return Success();
750  }
751 
755  static TCPSocket Create(SockDomain domain) {
756 #if defined(xgboost_IS_MINGW)
757  MingWError();
758  return {};
759 #else
760  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
761  if (fd == InvalidSocket()) {
762  system::ThrowAtError("socket");
763  }
764 
765  TCPSocket socket{fd};
766 #if defined(__APPLE__)
767  socket.domain_ = domain;
768 #endif // defined(__APPLE__)
769  return socket;
770 #endif // defined(xgboost_IS_MINGW)
771  }
772 
773  static TCPSocket *CreatePtr(SockDomain domain) {
774 #if defined(xgboost_IS_MINGW)
775  MingWError();
776  return nullptr;
777 #else
778  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
779  if (fd == InvalidSocket()) {
780  system::ThrowAtError("socket");
781  }
782  auto socket = new TCPSocket{fd};
783 
784 #if defined(__APPLE__)
785  socket->domain_ = domain;
786 #endif // defined(__APPLE__)
787  return socket;
788 #endif // defined(xgboost_IS_MINGW)
789  }
790 };
791 
804 [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
805  std::chrono::seconds timeout,
807 
811 [[nodiscard]] Result GetHostName(std::string *p_out);
812 
816 template <typename H>
817 Result INetNToP(H const &host, std::string *p_out) {
818  std::string &ip = *p_out;
819  switch (host->h_addrtype) {
820  case AF_INET: {
821  auto addr = reinterpret_cast<struct in_addr *>(host->h_addr_list[0]);
822  char str[INET_ADDRSTRLEN];
823  inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
824  ip = str;
825  break;
826  }
827  case AF_INET6: {
828  auto addr = reinterpret_cast<struct in6_addr *>(host->h_addr_list[0]);
829  char str[INET6_ADDRSTRLEN];
830  inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
831  ip = str;
832  break;
833  }
834  default: {
835  return Fail("Invalid address type.");
836  }
837  }
838  return Success();
839 }
840 } // namespace collective
841 } // namespace xgboost
842 
843 #undef xgboost_CHECK_SYS_CALL
定义 xgboost 的配置宏和基本类型。
定义于: socket.h:216
SockAddrV4(sockaddr_in addr)
定义于: socket.h:221
static SockAddrV4 InaddrAny()
in_port_t Port() const
定义于: socket.h:227
sockaddr_in const & Handle() const
定义于: socket.h:238
std::string Addr() const
定义于: socket.h:229
static SockAddrV4 Loopback()
SockAddrV4()
定义于: socket.h:222
定义于: socket.h:192
static SockAddrV6 InaddrAny()
SockAddrV6()
定义于: socket.h:197
sockaddr_in6 const & Handle() const
定义于: socket.h:213
in_port_t Port() const
定义于: socket.h:202
SockAddrV6(sockaddr_in6 addr)
定义于: socket.h:196
std::string Addr() const
定义于: socket.h:204
static SockAddrV6 Loopback()
TCP socket 地址,可以是 IPv4 或 IPv6。
定义于: socket.h:244
bool IsV6() const
定义于: socket.h:258
auto const & V6() const
定义于: socket.h:261
bool IsV4() const
定义于: socket.h:257
auto Domain() const
定义于: socket.h:255
SockAddress(SockAddrV4 const &addr)
定义于: socket.h:253
auto const & V4() const
定义于: socket.h:260
SockAddress(SockAddrV6 const &addr)
定义于: socket.h:252
用于简单通信的 TCP socket。
定义于: socket.h:267
Result GetSockError() const
获取最后的错误码(如果有)
定义于: socket.h:338
Result Recv(std::string *p_str)
接收字符串,格式与 RABIT 中的 Python socket 封装器匹配。
Result RecvTimeout(std::chrono::seconds timeout)
定义于: socket.h:392
HandleT const & Handle() const
返回原生 socket 文件描述符。
定义于: socket.h:539
Result SetNoDelay(std::int32_t no_delay=1)
定义于: socket.h:464
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
将 socket 绑定到 INADDR_ANY,返回操作系统选择的端口。
定义于: socket.h:549
Result Listen(std::int32_t backlog=256)
监听传入请求。应在绑定后调用。
static TCPSocket * CreatePtr(SockDomain domain)
定义于: socket.h:773
Result Shutdown()
在 socket 上调用 shutdown。
定义于: socket.h:735
system::SocketT HandleT
定义于: socket.h:269
Result Bind(StringView ip, std::int32_t *port)
将 socket 绑定到指定地址。
定义于: socket.h:608
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
定义于: socket.h:532
auto Port() const
定义于: socket.h:583
Result SetKeepAlive()
定义于: socket.h:454
Result Accept(TCPSocket *out, SockAddress *addr)
定义于: socket.h:484
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
定义于: socket.h:434
Result SendBufSize(std::int32_t *n_bytes)
定义于: socket.h:425
std::size_t Send(StringView str)
发送字符串,格式与 RABIT 中的 Python socket 封装器匹配。
auto Domain() const -> SockDomain
返回 socket 域。
定义于: socket.h:289
static TCPSocket Create(SockDomain domain)
在指定的域上创建一个 TCP socket。
定义于: socket.h:755
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
接收数据,如果没有错误,则应接收所有数据。
定义于: socket.h:659
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
使用 socket 接收数据
定义于: socket.h:697
bool NonBlocking() const
定义于: socket.h:391
bool IsClosed() const
定义于: socket.h:335
Result NonBlocking(bool non_block)
定义于: socket.h:366
TCPSocket Accept()
接受新连接,为新连接返回一个新的 TCP socket。
定义于: socket.h:476
~TCPSocket()
定义于: socket.h:520
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
使用 socket 发送数据。
定义于: socket.h:686
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
发送数据,如果没有错误,则应发送所有数据。
定义于: socket.h:639
bool BadSocket() const
检查是否发生异常
定义于: socket.h:354
TCPSocket(TCPSocket &&that) noexcept(true)
定义于: socket.h:530
Result SetBufSize(std::int32_t n_bytes)
定义于: socket.h:411
Result Close()
关闭 socket,如果在析构函数中 socket 未关闭,则会自动调用。
定义于: socket.h:715
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
定义于: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
连接到远程地址,如果失败则返回错误码。
Result GetHostName(std::string *p_out)
获取本地主机名。
SockAddress MakeSockAddress(StringView host, in_port_t port)
解析主机地址并返回 SockAddress 实例。支持 IPv4 和 IPv6 主机。
SockDomain
定义于: socket.h:184
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
返回失败。
定义于: result.h:124
void SafeColl(Result const &rc)
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
定义于: socket.h:817
auto Success() noexcept(true)
返回成功。
定义于: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
定义于: socket.h:137
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
定义于: socket.h:92
void SocketStartup()
定义于: socket.h:150
std::int32_t CloseSocket(SocketT fd)
定义于: socket.h:114
bool LastErrorWouldBlock()
定义于: socket.h:145
std::int32_t LastError()
定义于: socket.h:69
void SocketFinalize()
定义于: socket.h:163
std::int32_t ShutdownSocket(SocketT fd)
定义于: socket.h:122
collective::Result FailWithCode(std::string msg)
定义于: socket.h:78
int SocketT
定义于: socket.h:101
多目标树的核心数据结构。
定义于: base.h:89
int SOCKET
定义于: poll_utils.h:40
#define __builtin_LINE()
定义于: result.h:57
#define __builtin_FILE()
定义于: result.h:56
#define INVALID_SOCKET
定义于: socket.h:102
#define xgboost_CHECK_SYS_CALL(exp, expected)
定义于: socket.h:106
定义于: string_view.h:16
一种比抛出 dmlc 异常更容易处理的错误类型。我们可以记录和传播 ...
定义于: result.h:67