xgboost
poll_utils.h
前往此文件的文档。
1 
6 #pragma once
9 
10 #if defined(_WIN32)
11 #include <xgboost/windefs.h>
12 // Socket API
13 #include <winsock2.h>
14 #include <ws2tcpip.h>
15 #else
16 
17 #include <arpa/inet.h>
18 #include <fcntl.h>
19 #include <netdb.h>
20 #include <netinet/in.h>
21 #include <sys/ioctl.h>
22 #include <sys/socket.h>
23 #include <unistd.h>
24 
25 #include <cerrno>
26 
27 #endif // defined(_WIN32)
28 
#include <chrono>
#include <cstdint>       // std::int32_t
#include <cstring>
#include <limits>        // std::numeric_limits
#include <string>
#include <system_error>  // make_error_code, errc
#include <type_traits>   // std::enable_if_t, std::is_integral_v
#include <unordered_map>
#include <vector>
35 
36 #if !defined(_WIN32)
37 
38 #include <poll.h>
39 
40 using SOCKET = int;
41 using sock_size_t = size_t; // NOLINT
42 #endif // !defined(_WIN32)
43 
44 #define IS_MINGW() defined(__MINGW32__)
45 
46 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
47 /*
48  * On later mingw versions poll should be supported (with bugs). See:
49  * https://stackoverflow.com/a/60623080
50  *
 * But right now the mingw distributed with R 3.6 doesn't support it.
 * So we just emit a warning and provide a dummy implementation so that
 * compilation succeeds. Otherwise we would have to provide a stub for
 * RABIT.
55  *
56  * Even on mingw version that has these structures and flags defined,
57  * functions like `send` and `listen` might have unresolved linkage to
58  * their implementation. So supporting mingw is quite difficult at
59  * the time of writing.
60  */
61 #pragma message("mingw上不支持分布式训练。")
62 typedef struct pollfd {
63  SOCKET fd;
64  short events; // NOLINT
65  short revents; // NOLINT
66 } WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
67 
68 // POLLRDNORM | POLLRDBAND
69 #define POLLIN (0x0100 | 0x0200)
70 #define POLLPRI 0x0400
71 // POLLWRNORM
72 #define POLLOUT 0x0010
73 
74 #endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
75 
76 namespace rabit {
77 namespace utils {
78 
/*!
 * \brief Thin portable wrapper over `poll` (POSIX) / `WSAPoll` (Windows).
 *
 * \tparam PollFD Platform poll descriptor type (`pollfd` / `WSAPOLLFD`).
 * \param pfd     Pointer to the first descriptor entry.
 * \param nfds    Number of entries pointed to by \p pfd.
 * \param timeout Timeout in seconds. A negative value means wait forever.
 *                Values too large to be expressed as an `int` number of
 *                milliseconds are saturated instead of overflowing.
 * \return The value returned by the underlying poll call; -1 on MinGW, where
 *         polling is not supported.
 */
template <typename PollFD>
int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
  // For Windows and Linux, negative timeout means infinite timeout. For freebsd,
  // INFTIM(-1) should be used instead.
  //
  // The native APIs take the timeout as an `int` in milliseconds. Convert with
  // saturation: a plain narrowing conversion would wrap for timeouts longer
  // than roughly 24.8 days and could yield a very short or a negative
  // (infinite) wait.
  constexpr std::chrono::seconds::rep kMaxSeconds = std::numeric_limits<int>::max() / 1000;
  [[maybe_unused]] int timeout_ms = -1;
  if (timeout.count() >= 0) {
    timeout_ms = timeout.count() > kMaxSeconds ? std::numeric_limits<int>::max()
                                               : static_cast<int>(timeout.count() * 1000);
  }

#if defined(_WIN32)

#if IS_MINGW()
  xgboost::MingWError();
  return -1;
#else
  return WSAPoll(pfd, nfds, timeout_ms);
#endif  // IS_MINGW()

#else
  return poll(pfd, nfds, timeout_ms);
#endif  // defined(_WIN32)
}
96 
97 template <typename E>
98 std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
99  if ((revents & POLLERR) != 0) {
100  auto err = errno;
101  auto str = strerror(err);
102  return xgboost::system::FailWithCode(std::string{"轮询错误条件:"} + // NOLINT
103  std::string{str} + // NOLINT
104  " 代码:" + std::to_string(err));
105  }
106  if ((revents & POLLNVAL) != 0) {
107  return xgboost::system::FailWithCode("无效的轮询请求。");
108  }
109  if ((revents & POLLHUP) != 0) {
110  // 摘自 Linux 手册:
111  //
112  // 请注意,当从通道(例如管道或流套接字)读取时,此事件仅表示对端关闭了其通道端点。
113  // 只有在通道中所有未处理的数据都被消耗后,后续从通道的读取才会返回 0(文件末尾)。
114  //
115  // 我们通常没有工作节点退出的障碍,一端退出而另一端仍在读取数据是正常的。
116  //
118  }
119 #if defined(POLLRDHUP)
120  // 仅限 Linux 的标志
121  if ((revents & POLLRDHUP) != 0) {
122  return xgboost::system::FailWithCode("轮询到对端挂断。");
123  }
124 #endif // defined(POLLRDHUP)
126 }
127 
128 
131 struct PollHelper {
132  public
137  inline void WatchRead(SOCKET fd) {
138  // 添加文件描述符以监视读取
139  auto& pfd = fds[fd];
140  pfd.fd = fd;
141  pfd.events |= POLLIN;
142  }
143  void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); }
144 
148  inline void WatchWrite(SOCKET fd) {
149  // 添加文件描述符以监视写入
150  auto& pfd = fds[fd];
151  pfd.fd = fd;
152  pfd.events |= POLLOUT;
153  }
155  this->WatchWrite(socket.Handle());
156  }
161  inline void WatchException(SOCKET fd) {
162  // 添加文件描述符以监视异常
163  auto& pfd = fds[fd];
164  pfd.fd = fd;
165  pfd.events |= POLLPRI;
166  }
168  this->WatchException(socket.Handle());
169  }
173  [[nodiscard]] bool CheckRead(SOCKET fd) const {
174  // 检查描述符是否已准备好读取。
175  const auto& pfd = fds.find(fd);
176  return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
177  }
178  [[nodiscard]] bool CheckRead(xgboost::collective::TCPSocket const& socket) const {
179  return this->CheckRead(socket.Handle());
180  }
181 
185  [[nodiscard]] bool CheckWrite(SOCKET fd) const {
186  // 检查描述符是否已准备好写入。
187  const auto& pfd = fds.find(fd);
188  return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
189  }
190  [[nodiscard]] bool CheckWrite(xgboost::collective::TCPSocket const& socket) const {
191  return this->CheckWrite(socket.Handle());
192  }
197  [[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout,
198  bool check_error = true) {
199  // 在定义的集合上执行轮询,检查读取、写入、异常
200  std::vector<pollfd> fdset;
201  fdset.reserve(fds.size());
202  for (auto kv : fds) {
203  fdset.push_back(kv.second);
204  }
205  std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
206  if (ret == 0) {
208  "轮询超时:" + std::to_string(timeout.count()) + " 秒。",
209  std::make_error_code(std::errc::timed_out));
210  } else if (ret < 0) {
211  return xgboost::system::FailWithCode("轮询失败, nfds:" + std::to_string(fdset.size()));
212  }
213 
214  for (auto& pfd : fdset) {
215  auto result = PollError(pfd.revents);
216  if (check_error && !result.OK()) {
217  return result;
218  }
219 
220  auto revents = pfd.revents & pfd.events;
221  fds[pfd.fd].events = revents;
222  }
224  }
225 
226  std::unordered_map<SOCKET, pollfd> fds;
227 };
228 } // namespace utils
229 } // namespace rabit
230 
231 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
232 #undef POLLIN
233 #undef POLLPRI
234 #undef POLLOUT
235 #endif // IS_MINGW()
xgboost::collective::TCPSocket
用于简单通信的 TCP 套接字。
定义: socket.h:267
HandleT const & Handle() const
返回原生套接字文件描述符。
rabit::utils::PollImpl
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
rabit::utils::PollError
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
返回失败。
定义: result.h:124
auto Success() noexcept(true)
返回成功。
xgboost::system::FailWithCode
collective::Result FailWithCode(std::string msg)
SOCKET
int SOCKET
定义: poll_utils.h:41
result.h
rabit::utils::PollHelper
用于执行轮询的辅助数据结构
定义: poll_utils.h:131
void WatchException(SOCKET fd)
添加文件描述符以监视异常
rabit::utils::PollHelper::CheckWrite
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const
rabit::utils::PollHelper::WatchRead
void WatchRead(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:142
xgboost::collective::Result Poll(std::chrono::seconds timeout, bool check_error=true)
在定义的集合上执行轮询,检查读取、写入、异常
rabit::utils::PollHelper::WatchWrite
void WatchWrite(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:153
rabit::utils::PollHelper::CheckRead
bool CheckRead(SOCKET fd) const
检查描述符是否已准备好读取。
定义: poll_utils.h:173
void WatchException(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:166
bool CheckWrite(SOCKET fd) const
检查描述符是否已准备好写入。
定义: poll_utils.h:185
void WatchWrite(SOCKET fd)
定义: poll_utils.h:148
bool CheckRead(xgboost::collective::TCPSocket const &socket) const
定义: poll_utils.h:177
rabit::utils::PollHelper::fds
std::unordered_map< SOCKET, pollfd > fds
void WatchRead(SOCKET fd)
添加文件描述符以监视读取