12 #include <system_error>
15 #if defined(__linux__)
16 #include <sys/ioctl.h>
26 using in_port_t = std::uint16_t;
29 #pragma comment(lib, "Ws2_32.lib")
32 #if !defined(xgboost_IS_MINGW)
38 #include <arpa/inet.h>
40 #include <netinet/in.h>
41 #include <netinet/in.h>
42 #include <netinet/tcp.h>
43 #include <sys/socket.h>
46 #if defined(__sun) || defined(sun)
47 #include <sys/sockio.h>
54 #include "xgboost/logging.h"
57 #if !defined(HOST_NAME_MAX)
58 #define HOST_NAME_MAX 256
63 #if defined(xgboost_IS_MINGW)
65 inline void MingWError() { LOG(FATAL) <<
"mingw不支持分布式训练。"; }
71 return WSAGetLastError();
82 #if defined(__GLIBC__)
86 auto err = std::error_code{errsv, std::system_category()};
88 << file <<
"(" << line <<
"): 调用 `" << fn_name <<
"` 失败: " << err.message()
93 auto err = std::error_code{errsv, std::system_category()};
94 LOG(FATAL) <<
"调用 `" << fn_name <<
"` 失败: " << err.message() << std::endl;
102 #define INVALID_SOCKET -1
105 #if !defined(xgboost_CHECK_SYS_CALL)
106 #define xgboost_CHECK_SYS_CALL(exp, expected) \
108 if (XGBOOST_EXPECT((exp) != (expected), false)) { \
109 ::xgboost::system::ThrowAtError(#exp); \
116 return closesocket(fd);
124 auto rc = shutdown(fd, SD_BOTH);
125 if (rc != 0 &&
LastError() == WSANOTINITIALISED) {
129 auto rc = shutdown(fd, SHUT_RDWR);
130 if (rc != 0 &&
LastError() == ENOTCONN) {
139 return errsv == WSAEWOULDBLOCK;
141 return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
153 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
156 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
158 LOG(FATAL) <<
"找不到可用的 Winsock.dll 版本";
169 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
171 inline const char *inet_ntop(
int,
const void *,
char *, socklen_t) {
181 namespace collective {
202 in_port_t
Port()
const {
return ntohs(addr_.sin6_port); }
205 char buf[INET6_ADDRSTRLEN];
206 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV6), &addr_.sin6_addr,
207 buf, INET6_ADDRSTRLEN);
213 sockaddr_in6
const &
Handle()
const {
return addr_; }
227 [[nodiscard]] in_port_t
Port()
const {
return ntohs(addr_.sin_port); }
229 [[nodiscard]] std::string
Addr()
const {
230 char buf[INET_ADDRSTRLEN];
231 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV4), &addr_.sin_addr,
232 buf, INET_ADDRSTRLEN);
238 [[nodiscard]] sockaddr_in
const &
Handle()
const {
return addr_; }
255 [[nodiscard]]
auto Domain()
const {
return domain_; }
258 [[nodiscard]]
bool IsV6()
const {
return !
IsV4(); }
260 [[nodiscard]]
auto const &
V4()
const {
return v4_; }
261 [[nodiscard]]
auto const &
V6()
const {
return v6_; }
272 HandleT handle_{InvalidSocket()};
273 bool non_blocking_{
false};
276 #if defined(__APPLE__)
290 auto ret_iafamily = [](std::int32_t domain) {
297 LOG(FATAL) <<
"未知 IA 家族。";
304 WSAPROTOCOL_INFOW info;
305 socklen_t len =
sizeof(info);
307 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO,
reinterpret_cast<char *
>(&info), &len),
309 return ret_iafamily(info.iAddressFamily);
310 #elif defined(__APPLE__)
312 #elif defined(__unix__)
315 socklen_t len =
sizeof(domain);
317 getsockopt(this->
Handle(), SOL_SOCKET, SO_DOMAIN,
reinterpret_cast<char *
>(&domain), &len),
319 return ret_iafamily(domain);
322 socklen_t sizeofsa =
sizeof(sa);
324 if (sizeofsa <
sizeof(uchar_t) * 2) {
325 return ret_iafamily(AF_INET);
327 return ret_iafamily(sa.sa_family);
330 LOG(FATAL) <<
"未知平台。";
331 return ret_iafamily(AF_INET);
335 [[nodiscard]]
bool IsClosed()
const {
return handle_ == InvalidSocket(); }
339 std::int32_t optval = 0;
340 socklen_t len =
sizeof(optval);
341 auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *
>(&optval), &len);
344 return Fail(
"获取套接字错误失败。", std::move(errc));
347 auto errc = std::error_code{optval, std::system_category()};
348 return Fail(
"套接字错误。", std::move(errc));
359 if (err.Code() == std::error_code{EBADF, std::system_category()} ||
360 err.Code() == std::error_code{EINTR, std::system_category()}) {
368 u_long mode = non_block ? 1 : 0;
369 if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
373 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
383 rc = fcntl(handle_, F_SETFL, flag);
388 non_blocking_ = non_block;
395 DWORD tv = timeout.count() * 1000;
397 setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char *
>(&tv),
sizeof(tv));
400 tv.tv_sec = timeout.count();
402 auto rc = setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char const *
>(&tv),
412 auto rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<char *
>(&n_bytes),
417 rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<char *
>(&n_bytes),
427 auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<char *
>(n_bytes),
429 if (rc != 0 || optlen !=
sizeof(std::int32_t)) {
436 auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<char *
>(n_bytes),
438 if (rc != 0 || optlen !=
sizeof(std::int32_t)) {
443 #if defined(__linux__)
444 [[nodiscard]]
Result PendingSendSize(std::int32_t *n_bytes)
const {
445 return ioctl(this->
Handle(), TIOCOUTQ, n_bytes) == 0 ?
Success()
448 [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes)
const {
449 return ioctl(this->
Handle(), FIONREAD, n_bytes) == 0 ?
Success()
455 std::int32_t keepalive = 1;
456 auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char *
>(&keepalive),
465 auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<char *
>(&no_delay),
479 auto rc = this->
Accept(&newsock, &addr);
486 auto interrupt = WSAEINTR;
488 auto interrupt = EINTR;
491 struct sockaddr_in caddr;
492 socklen_t caddr_len =
sizeof(caddr);
493 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
500 struct sockaddr_in6 caddr;
501 socklen_t caddr_len =
sizeof(caddr);
502 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
522 auto rc = this->
Close();
524 LOG(WARNING) << rc.Report();
554 自动 句柄 =
reinterpret_cast<sockaddr
const *
>(&地址.句柄());
555 如果 (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(地址.句柄())>)) != 0) {
559 sockaddr_in6 res_addr;
560 socklen_t addrlen =
sizeof(res_addr);
561 如果 (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
564 *p_out = ntohs(res_addr.sin6_port);
567 自动 句柄 =
reinterpret_cast<sockaddr
const *
>(&地址.句柄());
568 如果 (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(地址.句柄())>)) != 0) {
572 sockaddr_in res_addr;
573 socklen_t addrlen =
sizeof(res_addr);
574 如果 (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
577 *p_out = ntohs(res_addr.sin_port);
583 [[nodiscard]]
auto 端口()
const {
585 sockaddr_in res_addr;
586 socklen_t addrlen =
sizeof(res_addr);
587 自动 代码 = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
591 返回 std::make_pair(
成功(), std::int32_t{ntohs(res_addr.sin_port)});
593 sockaddr_in6 res_addr;
594 socklen_t addrlen =
sizeof(res_addr);
595 自动 代码 = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
599 返回 std::make_pair(
成功(), std::int32_t{ntohs(res_addr.sin6_port)});
611 std::int32_t errc{0};
613 自动 句柄 =
reinterpret_cast<sockaddr
const *
>(&地址.V4().句柄());
614 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(地址.V4().句柄())>));
616 自动 句柄 =
reinterpret_cast<sockaddr
const *
>(&地址.V6().句柄());
617 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(地址.V6().句柄())>));
622 自动 [rc, new_port] = this->
端口();
630 如果 (*端口 != new_port) {
631 返回 失败(
"从绑定获取的端口无效。");
639 [[nodiscard]]
结果 SendAll(
void const *buf, std::size_t len, std::size_t *n_sent) {
640 char const *_buf =
reinterpret_cast<const char *
>(buf);
641 std::size_t &ndone = *n_sent;
644 ssize_t ret = send(handle_, _buf, len - ndone, 0);
659 [[nodiscard]]
结果 RecvAll(
void *buf, std::size_t len, std::size_t *n_recv) {
660 char *_buf =
reinterpret_cast<char *
>(buf);
661 std::size_t &ndone = *n_recv;
664 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
686 自动 发送(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
687 const char *buf =
reinterpret_cast<const char *
>(buf_);
688 返回 send(handle_, buf, len, flags);
697 自动 接收(
void *buf, std::size_t len, std::int32_t flags = 0) {
698 char *_buf =
static_cast<char *
>(buf);
701 返回 recv(handle_, _buf, len, flags);
711 [[nodiscard]]
结果 接收(std::string *p_str);
716 如果 (InvalidSocket() != handle_) {
728 handle_ = InvalidSocket();
756 #if defined(xgboost_IS_MINGW)
760 自动 fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
761 如果 (fd == InvalidSocket()) {
766 #if defined(__APPLE__)
767 socket.domain_ = domain;
774 #if defined(xgboost_IS_MINGW)
778 自动 fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
779 如果 (fd == InvalidSocket()) {
784 #if defined(__APPLE__)
785 socket->domain_ = domain;
805 std::chrono::seconds timeout,
818 std::string &ip = *p_out;
819 根据 (host->h_addrtype) {
821 自动 地址 =
reinterpret_cast<struct in_addr *
>(host->h_addr_list[0]);
822 char str[INET_ADDRSTRLEN];
823 inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
828 自动 地址 =
reinterpret_cast<struct in6_addr *
>(host->h_addr_list[0]);
829 char str[INET6_ADDRSTRLEN];
830 inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
843 #undef xgboost_CHECK_SYS_CALL
SockAddrV4(sockaddr_in addr)
定义: socket.h:221
static SockAddrV4 InaddrAny()
in_port_t Port() const
定义: socket.h:227
sockaddr_in const & Handle() const
定义: socket.h:238
std::string Addr() const
定义: socket.h:229
static SockAddrV4 Loopback()
SockAddrV4()
定义: socket.h:222
static SockAddrV6 InaddrAny()
SockAddrV6()
定义: socket.h:197
sockaddr_in6 const & Handle() const
定义: socket.h:213
in_port_t Port() const
定义: socket.h:202
SockAddrV6(sockaddr_in6 addr)
定义: socket.h:196
std::string Addr() const
定义: socket.h:204
static SockAddrV6 Loopback()
TCP套接字的地址,可以是IPv4或IPv6。
定义: socket.h:244
bool IsV6() const
定义: socket.h:258
auto const & V6() const
定义: socket.h:261
bool IsV4() const
定义: socket.h:257
auto Domain() const
定义: socket.h:255
SockAddress(SockAddrV4 const &addr)
定义: socket.h:253
auto const & V4() const
定义: socket.h:260
SockAddress(SockAddrV6 const &addr)
定义: socket.h:252
用于简单通信的TCP套接字。
定义: socket.h:267
Result GetSockError() const
获取上次错误代码(如果有)
定义: socket.h:338
Result Recv(std::string *p_str)
接收字符串,格式与RABIT中的Python套接字包装器匹配。
Result RecvTimeout(std::chrono::seconds timeout)
定义: socket.h:392
HandleT const & Handle() const
返回原生套接字文件描述符。
定义: socket.h:539
Result SetNoDelay(std::int32_t no_delay=1)
定义: socket.h:464
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
将套接字绑定到INADDR_ANY,返回操作系统选择的端口。
定义: socket.h:549
Result Listen(std::int32_t backlog=256)
监听传入请求。应在绑定后调用。
static TCPSocket * CreatePtr(SockDomain domain)
定义: socket.h:773
Result Shutdown()
关闭套接字。
定义: socket.h:735
system::SocketT HandleT
定义: socket.h:269
Result Bind(StringView ip, std::int32_t *port)
将套接字绑定到地址。
定义: socket.h:608
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
定义: socket.h:532
auto Port() const
定义: socket.h:583
Result SetKeepAlive()
定义: socket.h:454
Result Accept(TCPSocket *out, SockAddress *addr)
定义: socket.h:484
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
定义: socket.h:434
Result SendBufSize(std::int32_t *n_bytes)
定义: socket.h:425
std::size_t Send(StringView str)
发送字符串,格式与RABIT中的Python套接字包装器匹配。
auto Domain() const -> SockDomain
返回套接字域。
定义: socket.h:289
static TCPSocket Create(SockDomain domain)
在指定域上创建TCP套接字。
定义: socket.h:755
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
接收数据,无错误则应接收所有数据。
定义: socket.h:659
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
使用套接字接收数据
定义: socket.h:697
bool NonBlocking() const
定义: socket.h:391
bool IsClosed() const
定义: socket.h:335
Result NonBlocking(bool non_block)
定义: socket.h:366
TCPSocket Accept()
接受新连接,为新连接返回一个新的TCP套接字。
定义: socket.h:476
~TCPSocket()
定义: socket.h:520
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
使用套接字发送数据。
定义: socket.h:686
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
发送数据,无错误则应发送所有数据。
定义: socket.h:639
bool BadSocket() const
检查是否发生异常
定义: socket.h:354
TCPSocket(TCPSocket &&that) noexcept(true)
定义: socket.h:530
Result SetBufSize(std::int32_t n_bytes)
定义: socket.h:411
Result Close()
关闭套接字,如果套接字未关闭,则在析构函数中自动调用。
定义: socket.h:715
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
定义: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
连接到远程地址,如果失败则返回错误代码。
Result GetHostName(std::string *p_out)
获取本地主机名。
SockAddress MakeSockAddress(StringView host, in_port_t port)
解析主机地址并返回SockAddress实例。支持IPv4和IPv6主机。
SockDomain
定义: socket.h:184
void SafeColl(Result const &rc, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
返回失败。
定义: result.h:124
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
定义: socket.h:817
auto Success() noexcept(true)
返回成功。
定义: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
定义: socket.h:137
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
定义: socket.h:92
void SocketStartup()
定义: socket.h:150
std::int32_t CloseSocket(SocketT fd)
定义: socket.h:114
bool LastErrorWouldBlock()
定义: socket.h:145
std::int32_t LastError()
定义: socket.h:69
void SocketFinalize()
定义: socket.h:163
std::int32_t ShutdownSocket(SocketT fd)
定义: socket.h:122
collective::Result FailWithCode(std::string msg)
定义: socket.h:78
int SocketT
定义: socket.h:101
集成目标、gbm和评估的学习器接口。这是用户面临的XGB...
Definition: base.h:97
int SOCKET
定义: poll_utils.h:40
#define __builtin_LINE()
定义: result.h:57
#define __builtin_FILE()
定义: result.h:56
#define INVALID_SOCKET
定义: socket.h:102
#define xgboost_CHECK_SYS_CALL(exp, expected)
定义: socket.h:106
Definition: string_view.h:16
一种比抛出dmlc异常更容易处理的错误类型。我们可以记录并传播s...
定义: result.h:67