xgboost
poll_utils.h
前往此文件文档。
1 
6 #pragma once
9 
10 #if defined(_WIN32)
11 #include <xgboost/windefs.h>
12 // 套接字 API
13 #include <winsock2.h>
14 #include <ws2tcpip.h>
15 #else
16 
17 #include <arpa/inet.h>
18 #include <fcntl.h>
19 #include <netdb.h>
20 #include <netinet/in.h>
21 #include <sys/ioctl.h>
22 #include <sys/socket.h>
23 #include <unistd.h>
24 
25 #include <cerrno>
26 
27 #endif // defined(_WIN32)
28 
29 #include <chrono>
30 #include <cstring>
31 #include <string>
32 #include <system_error> // make_error_code, errc
33 #include <unordered_map>
34 #include <vector>
35 
36 #if !defined(_WIN32)
37 
38 #include <poll.h>
39 
40 using SOCKET = int;
41 using sock_size_t = size_t; // NOLINT
42 #endif // !defined(_WIN32)
43 
44 #define IS_MINGW() defined(__MINGW32__)
45 
46 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
47 /*
48  * 在较新的 mingw 版本上,poll 应该受支持(有 bug)。请参阅:
49  * https://stackoverflow.com/a/60623080
50  *
51  * 但目前 R 3.6 分发的 mingw 不支持它。
52  * 所以我们只给出一个警告并提供虚拟实现以通过编译。
53  * 否则我们将不得不为 RABIT 提供一个存根。
54  * 即使在定义了这些结构和标志的 mingw 版本上,
55  *
56  * 诸如 `send` 和 `listen` 等函数也可能存在未解析的链接到
57  * 它们的实现。因此,在撰写本文时支持 mingw 相当困难。
58  *
59  *
60  */
61 #pragma message("mingw 不支持分布式训练。")
62 typedef struct pollfd {
63  SOCKET fd;
64  short events; // NOLINT
65  short revents; // NOLINT
66 } WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
67 
68 // POLLRDNORM | POLLRDBAND
69 #define POLLIN (0x0100 | 0x0200)
70 #define POLLPRI 0x0400
71 // POLLWRNORM
72 #define POLLOUT 0x0010
73 
74 #endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
75 
76 namespace rabit {
77 namespace utils {
78 
79 template <typename PollFD>
80 int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
81  // 对于 Windows 和 Linux,负超时表示无限超时。对于 freebsd,
82  // 应该使用 INFTIM(-1)。
83 #if defined(_WIN32)
84 
85 #if IS_MINGW()
86  xgboost::MingWError();
87  return -1;
88 #else
89  return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
90 #endif // IS_MINGW()
91 
92 #else
93  return poll(pfd, nfds, timeout.count() < 0 ? -1 : std::chrono::milliseconds(timeout).count());
94 #endif // IS_MINGW()
95 }
96 
97 template <typename E>
98 std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
99  if ((revents & POLLERR) != 0) {
100  auto err = errno;
101  auto str = strerror(err);
102  return xgboost::system::FailWithCode(std::string{"轮询错误条件:"} + // NOLINT
103  std::string{str} + // NOLINT
104  " 代码:" + std::to_string(err));
105  }
106  if ((revents & POLLNVAL) != 0) {
107  return xgboost::system::FailWithCode("无效的轮询请求。");
108  }
109  if ((revents & POLLHUP) != 0) {
110  // 摘自 Linux 手册:
111  //
112  // 请注意,从管道或流套接字等通道读取时,此事件
113  // 仅表示对等端关闭了其通道端。后续从
114  // 通道读取将仅在通道中所有未决数据
115  // 被消耗后返回 0(文件末尾)。
116  //
117  // 我们通常没有退出工作器的屏障,一端
118  // 退出而另一端仍在读取数据是很正常的。
120  }
121 #if defined(POLLRDHUP)
122  // 仅限 Linux 标志
123  if ((revents & POLLRDHUP) != 0) {
124  return xgboost::system::FailWithCode("轮询挂起在另一端。");
125  }
126 #endif // defined(POLLRDHUP)
128 }
129 
131 struct PollHelper {
132  public
137  inline void WatchRead(SOCKET fd) {
138  auto& pfd = fds[fd];
139  pfd.fd = fd;
140  pfd.events |= POLLIN;
141  }
142  void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); }
143 
148  inline void WatchWrite(SOCKET fd) {
149  auto& pfd = fds[fd];
150  pfd.fd = fd;
151  pfd.events |= POLLOUT;
152  }
154  this->WatchWrite(socket.Handle());
155  }
156 
161  inline void WatchException(SOCKET fd) {
162  auto& pfd = fds[fd];
163  pfd.fd = fd;
164  pfd.events |= POLLPRI;
165  }
167  this->WatchException(socket.Handle());
168  }
173  [[nodiscard]] bool CheckRead(SOCKET fd) const {
174  const auto& pfd = fds.find(fd);
175  return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
176  }
177  [[nodiscard]] bool CheckRead(xgboost::collective::TCPSocket const& socket) const {
178  return this->CheckRead(socket.Handle());
179  }
180 
185  [[nodiscard]] bool CheckWrite(SOCKET fd) const {
186  const auto& pfd = fds.find(fd);
187  return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
188  }
189  [[nodiscard]] bool CheckWrite(xgboost::collective::TCPSocket const& socket) const {
190  return this->CheckWrite(socket.Handle());
191  }
197  [[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout,
198  bool check_error = true) {
199  std::vector<pollfd> fdset;
200  fdset.reserve(fds.size());
201  for (auto kv : fds) {
202  fdset.push_back(kv.second);
203  }
204  std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
205  if (ret == 0) {
207  "轮询超时:" + std::to_string(timeout.count()) + " 秒。",
208  std::make_error_code(std::errc::timed_out));
209  } else if (ret < 0) {
210  return xgboost::system::FailWithCode("轮询失败,nfds:" + std::to_string(fdset.size()));
211  }
212 
213  for (auto& pfd : fdset) {
214  auto result = PollError(pfd.revents);
215  if (check_error && !result.OK()) {
216  return result;
217  }
218 
219  auto revents = pfd.revents & pfd.events;
220  fds[pfd.fd].events = revents;
221  }
223  }
224 
225  std::unordered_map<SOCKET, pollfd> fds;
226 };
227 } // namespace utils
228 } // namespace rabit
229 
230 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
231 #undef POLLIN
232 #undef POLLPRI
233 #undef POLLOUT
234 #endif // IS_MINGW()
用于简单通信的 TCP 套接字。
定义: socket.h:267
HandleT const & Handle() const
返回原生套接字文件描述符。
定义: socket.h:539
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
定义: poll_utils.h:80
std::enable_if_t< std::is_integral_v< E >, xgboost::collective::Result > PollError(E const &revents)
定义: poll_utils.h:98
定义: poll_utils.h:76
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
返回失败。
定义: result.h:124
auto Success() noexcept(true)
返回成功。
定义: result.h:120
collective::Result FailWithCode(std::string msg)
定义: socket.h:78
int SOCKET
定义: poll_utils.h:40
size_t sock_size_t
定义: poll_utils.h:41
用于执行轮询的辅助数据结构
定义: poll_utils.h:131
void WatchException(SOCKET fd)
添加文件描述符以监视异常
定义: poll_utils.h:161
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const
定义: poll_utils.h:189
void WatchRead(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:142
xgboost::collective::Result Poll(std::chrono::seconds timeout, bool check_error=true)
对定义的集合执行轮询,包括读、写、异常
定义: poll_utils.h:197
void WatchWrite(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:153
bool CheckRead(SOCKET fd) const
检查描述符是否准备好读取。
定义: poll_utils.h:173
void WatchException(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:166
bool CheckWrite(SOCKET fd) const
检查描述符是否准备好写入。
定义: poll_utils.h:185
void WatchWrite(SOCKET fd)
添加文件描述符以监视写入
定义: poll_utils.h:148
bool CheckRead(xgboost::collective::TCPSocket const &socket) const
定义: poll_utils.h:177
std::unordered_map< SOCKET, pollfd > fds
定义: poll_utils.h:225
void WatchRead(SOCKET fd)
添加文件描述符以监视读取
定义: poll_utils.h:137
一种比抛出 dmlc 异常更容易处理的错误类型。我们可以记录并传播 s...
定义: result.h:67