17 #include <arpa/inet.h>
20 #include <netinet/in.h>
21 #include <sys/ioctl.h>
22 #include <sys/socket.h>
32 #include <system_error>
33 #include <unordered_map>
44 #define IS_MINGW() defined(__MINGW32__)
46 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
61 #pragma message("mingw上不支持分布式训练。")
62 typedef struct pollfd {
66 } WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
69 #define POLLIN (0x0100 | 0x0200)
70 #define POLLPRI 0x0400
72 #define POLLOUT 0x0010
79 template <
typename PollFD>
80 int PollImpl(PollFD* pfd,
int nfds, std::chrono::seconds timeout) noexcept(
true) {
86 xgboost::MingWError();
89 return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count());
93 return poll(pfd, nfds, timeout.count() < 0 ? -1 : std::chrono::milliseconds(timeout).count());
99 if ((revents & POLLERR) != 0) {
101 auto str = strerror(err);
104 " 代码:" + std::to_string(err));
106 if ((revents & POLLNVAL) != 0) {
109 if ((revents & POLLHUP) != 0) {
119 #if defined(POLLRDHUP)
121 if ((revents & POLLRDHUP) != 0) {
141 pfd.events |= POLLIN;
152 pfd.events |= POLLOUT;
165 pfd.events |= POLLPRI;
175 const auto& pfd =
fds.find(fd);
176 return pfd !=
fds.end() && ((pfd->second.events & POLLIN) != 0);
187 const auto& pfd =
fds.find(fd);
188 return pfd !=
fds.end() && ((pfd->second.events & POLLOUT) != 0);
198 bool check_error =
true) {
200 std::vector<pollfd> fdset;
201 fdset.reserve(
fds.size());
202 for (
auto kv :
fds) {
203 fdset.push_back(kv.second);
205 std::int32_t ret =
PollImpl(fdset.data(), fdset.size(), timeout);
208 "轮询超时:" + std::to_string(timeout.count()) +
" 秒。",
209 std::make_error_code(std::errc::timed_out));
210 }
else if (ret < 0) {
214 for (
auto& pfd : fdset) {
216 if (check_error && !result.OK()) {
220 auto revents = pfd.revents & pfd.events;
221 fds[pfd.fd].events = revents;
226 std::unordered_map<SOCKET, pollfd>
fds;
231 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
xgboost::collective::TCPSocket
用于简单通信的 TCP 套接字。
定义: socket.h:267
HandleT const & Handle() const
返回原生套接字文件描述符。
rabit::utils::PollImpl
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
定义: poll_utils.h:98
rabit
定义: poll_utils.h:76
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
返回失败。
定义: result.h:124
auto Success() noexcept(true)
返回成功。
xgboost::system::FailWithCode
collective::Result FailWithCode(std::string msg)
定义: poll_utils.h:41
result.h
rabit::utils::PollHelper
用于执行轮询的辅助数据结构
定义: poll_utils.h:131
void WatchException(SOCKET fd)
添加需要监视异常事件的文件描述符。
rabit::utils::PollHelper::CheckWrite
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const
rabit::utils::PollHelper::WatchRead
void WatchRead(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:142
xgboost::collective::Result Poll(std::chrono::seconds timeout, bool check_error=true)
在已注册的文件描述符集合上执行轮询，检查可读、可写及异常事件。
rabit::utils::PollHelper::WatchWrite
void WatchWrite(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:153
rabit::utils::PollHelper::CheckRead
bool CheckRead(SOCKET fd) const
检查描述符是否已准备好读取。
定义: poll_utils.h:173
void WatchException(xgboost::collective::TCPSocket const &socket)
定义: poll_utils.h:166
bool CheckWrite(SOCKET fd) const
检查描述符是否已准备好写入。
定义: poll_utils.h:185
void WatchWrite(SOCKET fd)
定义: poll_utils.h:148
bool CheckRead(xgboost::collective::TCPSocket const &socket) const
定义: poll_utils.h:177
rabit::utils::PollHelper::fds
std::unordered_map< SOCKET, pollfd > fds
void WatchRead(SOCKET fd)
添加需要监视可读事件的文件描述符。