12 #include <system_error>  
 15 #if defined(__linux__) 
 16 #include <sys/ioctl.h>  
 26 using in_port_t = std::uint16_t;
 
 29 #pragma comment(lib, "Ws2_32.lib")
 
 32 #if !defined(xgboost_IS_MINGW) 
 38 #include <arpa/inet.h>  
 40 #include <netinet/in.h>  
 41 #include <netinet/in.h>  
 42 #include <netinet/tcp.h>  
 43 #include <sys/socket.h>  
 46 #if defined(__sun) || defined(sun) 
 47 #include <sys/sockio.h> 
 54 #include "xgboost/logging.h"  
 57 #if !defined(HOST_NAME_MAX) 
 58 #define HOST_NAME_MAX 256  
 63 #if defined(xgboost_IS_MINGW) 
 65 inline void MingWError() { LOG(FATAL) << 
"mingw不支持分布式训练。"; }
 
 71  return WSAGetLastError();
 
 82 #if defined(__GLIBC__) 
 86  auto err = std::error_code{errsv, std::system_category()};
 
 88  << file << 
"(" << line << 
"): 调用 `" << fn_name << 
"` 失败: " << err.message()
 
 93  auto err = std::error_code{errsv, std::system_category()};
 
 94  LOG(FATAL) << 
"调用 `" << fn_name << 
"` 失败: " << err.message() << std::endl;
 
 102 #define INVALID_SOCKET -1 
 105 #if !defined(xgboost_CHECK_SYS_CALL) 
 106 #define xgboost_CHECK_SYS_CALL(exp, expected) \ 
 108  if (XGBOOST_EXPECT((exp) != (expected), false)) { \ 
 109  ::xgboost::system::ThrowAtError(#exp); \ 
 116  return closesocket(fd);
 
 124  auto rc = shutdown(fd, SD_BOTH);
 
 125  if (rc != 0 && 
LastError() == WSANOTINITIALISED) {
 
 129  auto rc = shutdown(fd, SHUT_RDWR);
 
 130  if (rc != 0 && 
LastError() == ENOTCONN) {
 
 139  return errsv == WSAEWOULDBLOCK;
 
 141  return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
 
 153  if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
 
 156  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
 
 158  LOG(FATAL) << 
"找不到可用的 Winsock.dll 版本";
 
 169 #if defined(_WIN32) && defined(xgboost_IS_MINGW) 
 171 inline const char *inet_ntop(
int, 
const void *, 
char *, socklen_t) { 
 
 181 namespace collective {
 
 202  in_port_t 
Port()
 const { 
return ntohs(addr_.sin6_port); }
 
 205  char buf[INET6_ADDRSTRLEN];
 
 206  auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV6), &addr_.sin6_addr,
 
 207  buf, INET6_ADDRSTRLEN);
 
 213  sockaddr_in6 
const &
Handle()
 const { 
return addr_; }
 
 227  [[nodiscard]] in_port_t 
Port()
 const { 
return ntohs(addr_.sin_port); }
 
 229  [[nodiscard]] std::string 
Addr()
 const {
 
 230  char buf[INET_ADDRSTRLEN];
 
 231  auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV4), &addr_.sin_addr,
 
 232  buf, INET_ADDRSTRLEN);
 
 238  [[nodiscard]] sockaddr_in 
const &
Handle()
 const { 
return addr_; }
 
 255  [[nodiscard]] 
auto Domain()
 const { 
return domain_; }
 
 258  [[nodiscard]] 
bool IsV6()
 const { 
return !
IsV4(); }
 
 260  [[nodiscard]] 
auto const &
V4()
 const { 
return v4_; }
 
 261  [[nodiscard]] 
auto const &
V6()
 const { 
return v6_; }
 
 272  HandleT handle_{InvalidSocket()};
 
 273  bool non_blocking_{
false};
 
 276 #if defined(__APPLE__) 
 290  auto ret_iafamily = [](std::int32_t domain) {
 
 297  LOG(FATAL) << 
"未知 IA 家族。";
 
 304  WSAPROTOCOL_INFOW info;
 
 305  socklen_t len = 
sizeof(info);
 
 307  getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, 
reinterpret_cast<char *
>(&info), &len),
 
 309  return ret_iafamily(info.iAddressFamily);
 
 310 #elif defined(__APPLE__) 
 312 #elif defined(__unix__) 
 315  socklen_t len = 
sizeof(domain);
 
 317  getsockopt(this->
Handle(), SOL_SOCKET, SO_DOMAIN, 
reinterpret_cast<char *
>(&domain), &len),
 
 319  return ret_iafamily(domain);
 
 322  socklen_t sizeofsa = 
sizeof(sa);
 
 324  if (sizeofsa < 
sizeof(uchar_t) * 2) {
 
 325  return ret_iafamily(AF_INET);
 
 327  return ret_iafamily(sa.sa_family);
 
 330  LOG(FATAL) << 
"未知平台。";
 
 331  return ret_iafamily(AF_INET);
 
 335  [[nodiscard]] 
bool IsClosed()
 const { 
return handle_ == InvalidSocket(); }
 
 339  std::int32_t optval = 0;
 
 340  socklen_t len = 
sizeof(optval);
 
 341  auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, 
reinterpret_cast<char *
>(&optval), &len);
 
 344  return Fail(
"获取套接字错误失败。", std::move(errc));
 
 347  auto errc = std::error_code{optval, std::system_category()};
 
 348  return Fail(
"套接字错误。", std::move(errc));
 
 359  if (err.Code() == std::error_code{EBADF, std::system_category()} || 
 
 360  err.Code() == std::error_code{EINTR, std::system_category()}) { 
 
 368  u_long mode = non_block ? 1 : 0;
 
 369  if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
 
 373  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
 
 383  rc = fcntl(handle_, F_SETFL, flag);
 
 388  non_blocking_ = non_block;
 
 395  DWORD tv = timeout.count() * 1000;
 
 397  setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO, 
reinterpret_cast<char *
>(&tv), 
sizeof(tv));
 
 400  tv.tv_sec = timeout.count();
 
 402  auto rc = setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO, 
reinterpret_cast<char const *
>(&tv),
 
 412  auto rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF, 
reinterpret_cast<char *
>(&n_bytes),
 
 417  rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF, 
reinterpret_cast<char *
>(&n_bytes),
 
 427  auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF, 
reinterpret_cast<char *
>(n_bytes),
 
 429  if (rc != 0 || optlen != 
sizeof(std::int32_t)) {
 
 436  auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF, 
reinterpret_cast<char *
>(n_bytes),
 
 438  if (rc != 0 || optlen != 
sizeof(std::int32_t)) {
 
 443 #if defined(__linux__) 
 444  [[nodiscard]] 
Result PendingSendSize(std::int32_t *n_bytes)
 const {
 
 445  return ioctl(this->
Handle(), TIOCOUTQ, n_bytes) == 0 ? 
Success()
 
 448  [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes)
 const {
 
 449  return ioctl(this->
Handle(), FIONREAD, n_bytes) == 0 ? 
Success()
 
 455  std::int32_t keepalive = 1;
 
 456  auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, 
reinterpret_cast<char *
>(&keepalive),
 
 465  auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, 
reinterpret_cast<char *
>(&no_delay),
 
 479  auto rc = this->
Accept(&newsock, &addr);
 
 486  auto interrupt = WSAEINTR;
 
 488  auto interrupt = EINTR;
 
 491  struct sockaddr_in caddr;
 
 492  socklen_t caddr_len = 
sizeof(caddr);
 
 493  HandleT newfd = accept(
Handle(), 
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
 
 500  struct sockaddr_in6 caddr;
 
 501  socklen_t caddr_len = 
sizeof(caddr);
 
 502  HandleT newfd = accept(
Handle(), 
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
 
 522  auto rc = this->
Close();
 
 524  LOG(WARNING) << rc.Report();
 
 554  自动 句柄 = 
reinterpret_cast<sockaddr 
const *
>(&地址.句柄());
 
 555  如果 (bind(handle_, handle, 
sizeof(std::remove_reference_t<decltype(地址.句柄())>)) != 0) {
 
 559  sockaddr_in6 res_addr;
 
 560  socklen_t addrlen = 
sizeof(res_addr);
 
 561  如果 (getsockname(handle_, 
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
 
 564  *p_out = ntohs(res_addr.sin6_port);
 
 567  自动 句柄 = 
reinterpret_cast<sockaddr 
const *
>(&地址.句柄());
 
 568  如果 (bind(handle_, handle, 
sizeof(std::remove_reference_t<decltype(地址.句柄())>)) != 0) {
 
 572  sockaddr_in res_addr;
 
 573  socklen_t addrlen = 
sizeof(res_addr);
 
 574  如果 (getsockname(handle_, 
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
 
 577  *p_out = ntohs(res_addr.sin_port);
 
 583  [[nodiscard]] 
auto 端口()
 const {
 
 585  sockaddr_in res_addr;
 
 586  socklen_t addrlen = 
sizeof(res_addr);
 
 587  自动 代码 = getsockname(handle_, 
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
 
 591  返回 std::make_pair(
成功(), std::int32_t{ntohs(res_addr.sin_port)});
 
 593  sockaddr_in6 res_addr;
 
 594  socklen_t addrlen = 
sizeof(res_addr);
 
 595  自动 代码 = getsockname(handle_, 
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
 
 599  返回 std::make_pair(
成功(), std::int32_t{ntohs(res_addr.sin6_port)});
 
 611  std::int32_t errc{0};
 
 613  自动 句柄 = 
reinterpret_cast<sockaddr 
const *
>(&地址.V4().句柄());
 
 614  errc = bind(handle_, handle, 
sizeof(std::remove_reference_t<decltype(地址.V4().句柄())>));
 
 616  自动 句柄 = 
reinterpret_cast<sockaddr 
const *
>(&地址.V6().句柄());
 
 617  errc = bind(handle_, handle, 
sizeof(std::remove_reference_t<decltype(地址.V6().句柄())>));
 
 622  自动 [rc, new_port] = this->
端口();
 
 630  如果 (*端口 != new_port) {
 
 631  返回 失败(
"从绑定获取的端口无效。");
 
 639  [[nodiscard]] 
结果 SendAll(
void const *buf, std::size_t len, std::size_t *n_sent) {
 
 640  char const *_buf = 
reinterpret_cast<const char *
>(buf);
 
 641  std::size_t &ndone = *n_sent;
 
 644  ssize_t ret = send(handle_, _buf, len - ndone, 0);
 
 659  [[nodiscard]] 
结果 RecvAll(
void *buf, std::size_t len, std::size_t *n_recv) {
 
 660  char *_buf = 
reinterpret_cast<char *
>(buf);
 
 661  std::size_t &ndone = *n_recv;
 
 664  ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
 
 686  自动 发送(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
 
 687  const char *buf = 
reinterpret_cast<const char *
>(buf_);
 
 688  返回 send(handle_, buf, len, flags);
 
 697  自动 接收(
void *buf, std::size_t len, std::int32_t flags = 0) {
 
 698  char *_buf = 
static_cast<char *
>(buf);
 
 701  返回 recv(handle_, _buf, len, flags);
 
 711  [[nodiscard]] 
结果 接收(std::string *p_str);
 
 716  如果 (InvalidSocket() != handle_) {
 
 728  handle_ = InvalidSocket();
 
 756 #if defined(xgboost_IS_MINGW) 
 760  自动 fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
 
 761  如果 (fd == InvalidSocket()) {
 
 766 #if defined(__APPLE__) 
 767  socket.domain_ = domain;
 
 774 #if defined(xgboost_IS_MINGW) 
 778  自动 fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
 
 779  如果 (fd == InvalidSocket()) {
 
 784 #if defined(__APPLE__) 
 785  socket->domain_ = domain;
 
 805  std::chrono::seconds timeout,
 
 818  std::string &ip = *p_out;
 
 819  根据 (host->h_addrtype) {
 
 821  自动 地址 = 
reinterpret_cast<struct in_addr *
>(host->h_addr_list[0]);
 
 822  char str[INET_ADDRSTRLEN];
 
 823  inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
 
 828  自动 地址 = 
reinterpret_cast<struct in6_addr *
>(host->h_addr_list[0]);
 
 829  char str[INET6_ADDRSTRLEN];
 
 830  inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
 
 843 #undef xgboost_CHECK_SYS_CALL 
SockAddrV4(sockaddr_in addr)
定义: socket.h:221
static SockAddrV4 InaddrAny()
in_port_t Port() const
定义: socket.h:227
sockaddr_in const & Handle() const
定义: socket.h:238
std::string Addr() const
定义: socket.h:229
static SockAddrV4 Loopback()
SockAddrV4()
定义: socket.h:222
static SockAddrV6 InaddrAny()
SockAddrV6()
定义: socket.h:197
sockaddr_in6 const & Handle() const
定义: socket.h:213
in_port_t Port() const
定义: socket.h:202
SockAddrV6(sockaddr_in6 addr)
定义: socket.h:196
std::string Addr() const
定义: socket.h:204
static SockAddrV6 Loopback()
TCP套接字的地址,可以是IPv4或IPv6。
定义: socket.h:244
bool IsV6() const
定义: socket.h:258
auto const & V6() const
定义: socket.h:261
bool IsV4() const
定义: socket.h:257
auto Domain() const
定义: socket.h:255
SockAddress(SockAddrV4 const &addr)
定义: socket.h:253
auto const & V4() const
定义: socket.h:260
SockAddress(SockAddrV6 const &addr)
定义: socket.h:252
用于简单通信的TCP套接字。
定义: socket.h:267
Result GetSockError() const
获取上次错误代码(如果有)
定义: socket.h:338
Result Recv(std::string *p_str)
接收字符串,格式与RABIT中的Python套接字包装器匹配。
Result RecvTimeout(std::chrono::seconds timeout)
定义: socket.h:392
HandleT const & Handle() const
返回原生套接字文件描述符。
定义: socket.h:539
Result SetNoDelay(std::int32_t no_delay=1)
定义: socket.h:464
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
将套接字绑定到INADDR_ANY,返回操作系统选择的端口。
定义: socket.h:549
Result Listen(std::int32_t backlog=256)
监听传入请求。应在绑定后调用。
static TCPSocket * CreatePtr(SockDomain domain)
定义: socket.h:773
Result Shutdown()
关闭套接字。
定义: socket.h:735
system::SocketT HandleT
定义: socket.h:269
Result Bind(StringView ip, std::int32_t *port)
将套接字绑定到地址。
定义: socket.h:608
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
定义: socket.h:532
auto Port() const
定义: socket.h:583
Result SetKeepAlive()
定义: socket.h:454
Result Accept(TCPSocket *out, SockAddress *addr)
定义: socket.h:484
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
定义: socket.h:434
Result SendBufSize(std::int32_t *n_bytes)
定义: socket.h:425
std::size_t Send(StringView str)
发送字符串,格式与RABIT中的Python套接字包装器匹配。
auto Domain() const -> SockDomain
返回套接字域。
定义: socket.h:289
static TCPSocket Create(SockDomain domain)
在指定域上创建TCP套接字。
定义: socket.h:755
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
接收数据,无错误则应接收所有数据。
定义: socket.h:659
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
使用套接字接收数据
定义: socket.h:697
bool NonBlocking() const
定义: socket.h:391
bool IsClosed() const
定义: socket.h:335
Result NonBlocking(bool non_block)
定义: socket.h:366
TCPSocket Accept()
接受新连接,为新连接返回一个新的TCP套接字。
定义: socket.h:476
~TCPSocket()
定义: socket.h:520
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
使用套接字发送数据。
定义: socket.h:686
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
发送数据,无错误则应发送所有数据。
定义: socket.h:639
bool BadSocket() const
检查是否发生异常
定义: socket.h:354
TCPSocket(TCPSocket &&that) noexcept(true)
定义: socket.h:530
Result SetBufSize(std::int32_t n_bytes)
定义: socket.h:411
Result Close()
关闭套接字,如果套接字未关闭,则在析构函数中自动调用。
定义: socket.h:715
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
定义: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
连接到远程地址,如果失败则返回错误代码。
Result GetHostName(std::string *p_out)
获取本地主机名。
SockAddress MakeSockAddress(StringView host, in_port_t port)
解析主机地址并返回SockAddress实例。支持IPv4和IPv6主机。
SockDomain
定义: socket.h:184
void SafeColl(Result const &rc, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
返回失败。
定义: result.h:124
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
定义: socket.h:817
auto Success() noexcept(true)
返回成功。
定义: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
定义: socket.h:137
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
定义: socket.h:92
void SocketStartup()
定义: socket.h:150
std::int32_t CloseSocket(SocketT fd)
定义: socket.h:114
bool LastErrorWouldBlock()
定义: socket.h:145
std::int32_t LastError()
定义: socket.h:69
void SocketFinalize()
定义: socket.h:163
std::int32_t ShutdownSocket(SocketT fd)
定义: socket.h:122
collective::Result FailWithCode(std::string msg)
定义: socket.h:78
int SocketT
定义: socket.h:101
集成目标、gbm和评估的学习器接口。这是用户面临的XGB...
Definition: base.h:97
int SOCKET
定义: poll_utils.h:40
#define __builtin_LINE()
定义: result.h:57
#define __builtin_FILE()
定义: result.h:56
#define INVALID_SOCKET
定义: socket.h:102
#define xgboost_CHECK_SYS_CALL(exp, expected)
定义: socket.h:106
Definition: string_view.h:16
一种比抛出dmlc异常更容易处理的错误类型。我们可以记录并传播s...
定义: result.h:67