xgboost
类型定义 | 函数
集合操作

暴露 XGBoost 内部通信器的实验性支持。 更多...

类型定义

typedef void * TrackerHandle
 Tracker 的句柄。 更多...
 

函数

int XGTrackerCreate (char const *config, TrackerHandle *handle)
 创建一个新的 tracker。 更多...
 
int XGTrackerWorkerArgs (TrackerHandle handle, char const **args)
 获取运行 worker 所需的参数。这应在调用 XGTrackerRun() 后调用。 更多...
 
int XGTrackerRun (TrackerHandle handle, char const *config)
 启动 tracker。tracker 在后台运行,一旦 tracker 启动,此函数立即返回。 更多...
 
int XGTrackerWaitFor (TrackerHandle handle, char const *config)
 等待 tracker 完成。应在调用 XGTrackerRun() 后调用此函数。此函数将阻塞,直到 tracker 任务完成或达到超时。 更多...
 
int XGTrackerFree (TrackerHandle handle)
 释放 tracker 实例。这应在调用 XGTrackerWaitFor() 后调用。如果未正确等待 tracker,此函数将关闭与 tracker 的所有连接,可能导致未定义的行为。 更多...
 
int XGCommunicatorInit (char const *config)
 初始化集合通信器。 更多...
 
int XGCommunicatorFinalize (void)
 终止集合通信器。 更多...
 
int XGCommunicatorGetRank (void)
 获取当前进程的 rank。 更多...
 
int XGCommunicatorGetWorldSize (void)
 获取进程总数(world size)。 更多...
 
int XGCommunicatorIsDistributed (void)
 获取通信器是否是分布式的。 更多...
 
int XGCommunicatorPrint (char const *message)
 向 tracker 打印消息。 更多...
 
int XGCommunicatorGetProcessorName (const char **name_str)
 获取处理器的名称。 更多...
 
int XGCommunicatorBroadcast (void *send_receive_buffer, size_t size, int root)
 从 root 进程向所有其他进程广播内存区域。此函数不是线程安全的。 更多...
 
int XGCommunicatorAllreduce (void *send_receive_buffer, size_t count, int data_type, int op)
 执行 in-place allreduce 操作。此函数不是线程安全的。 更多...
 

详细描述

暴露 XGBoost 内部通信器的实验性支持。

注意
这仍在开发中。

XGBoost 中的集合通信器从 dmlc 的 rabit 项目演变而来,但自从采用后已发生显著变化。它由一个 tracker 和一组 worker 组成。tracker 负责启动通信组并处理日志等集中式任务。worker 是执行 allreduce 等集合任务的实际通信器。

要使用集合实现,首先需要使用相应的参数创建一个 tracker,然后使用 XGTrackerWorkerArgs() 获取 worker 的参数。然后可以将获得的参数传递给 XGCommunicatorInit() 函数。调用 XGCommunicatorInit() 必须伴随对 XGCommunicatorFinalize() 的调用以进行清理。请注意,通信器使用 C++ 中的 std::thread,由于运行时关闭序列,在 C++ 析构函数中可能存在未定义的行为。最好在运行时关闭之前调用 XGCommunicatorFinalize()。此要求类似于 Python 线程或 socket,不应依赖于 __del__ 函数。

由于它是 XGBoost 的一部分,当调用 XGBoost 函数时会返回错误,例如,训练 booster 可能会返回连接错误。

类型定义文档

◆ TrackerHandle

typedef void* TrackerHandle

Tracker 的句柄。

目前 XGBoost 中有两种类型的 tracker,第一种是 rabit,另一种是 federatedrabit 用于常规集合通信,而 federated 用于联邦学习。

函数文档

◆ XGCommunicatorAllreduce()

int XGCommunicatorAllreduce ( void *  send_receive_buffer,
size_t  count,
int  data_type,
int  op 
)

执行 in-place allreduce 操作。此函数不是线程安全的。

用法示例:以下代码计算结果的和

enum class Op {
kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
};
std::vector<int> data(10);
...
Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum);
...
DataType
xgboost 接口接受的数据类型
定义于: data.h:33
参数
send_receive_buffer用于发送和接收数据的缓冲区。
count待归约元素的数量。
data_type数据类型枚举,参见 communicator.h 中的 xgboost::collective::DataType。
op操作类型枚举,参见 communicator.h 中的 xgboost::collective::Operation。
返回值
成功返回 0,失败返回 -1。

◆ XGCommunicatorBroadcast()

int XGCommunicatorBroadcast ( void *  send_receive_buffer,
size_t  size,
int  root 
)

从 root 进程向所有其他进程广播内存区域。此函数不是线程安全的。

示例

int a = 1;
Broadcast(&a, sizeof(a), root);
参数
send_receive_buffer指向发送或接收缓冲区的指针。
size数据大小(字节)。
root进行广播的进程 rank。
返回值
成功返回 0,失败返回 -1。

◆ XGCommunicatorFinalize()

int XGCommunicatorFinalize ( void  )

终止集合通信器。

完成所有任务后调用此函数。

返回值
成功返回 0,失败返回 -1。

◆ XGCommunicatorGetProcessorName()

int XGCommunicatorGetProcessorName ( const char **  name_str)

获取处理器的名称。

参数
name_str指向接收返回的处理器名称的指针。
返回值
成功返回 0,失败返回 -1。

◆ XGCommunicatorGetRank()

int XGCommunicatorGetRank ( void  )

获取当前进程的 rank。

返回值
worker 的 rank。

◆ XGCommunicatorGetWorldSize()

int XGCommunicatorGetWorldSize ( void  )

获取进程总数。

返回值
总 world size。

◆ XGCommunicatorInit()

int XGCommunicatorInit ( char const *  config)

初始化集合通信器。

当前通信器 API 处于实验阶段,函数签名将来可能会更改,恕不另行通知。

在使用任何功能之前,在 worker 进程中调用此函数一次。请确保在使用后调用 XGCommunicatorFinalize()。初始化的通信器是全局线程局部变量。

参数
configJSON 编码的配置。可接受的 JSON 键包括:
  • dmlc_communicator:通信器类型,应与 tracker 类型匹配。
    • rabit:使用 Rabit。如果未指定类型,这是默认值。
    • federated:使用 gRPC 接口进行联邦学习。

仅适用于 rabit 通信器

  • dmlc_tracker_uri:tracker 的主机名或 IP 地址。
  • dmlc_tracker_port:tracker 的端口号。
  • dmlc_task_id:当前任务的 ID,可用于获得确定性的 rank 分配。
  • dmlc_retry:连接失败的重试次数。
  • dmlc_timeout:超时时间(秒)。
  • dmlc_nccl_path:nccl 共享库 libnccl.so 的路径。

仅适用于 federated 通信器(环境变量使用大写,运行时配置使用小写)

  • federated_server_address:联邦服务器的地址。
  • federated_world_size:联邦 worker 的数量。
  • federated_rank:当前 worker 的 rank。
  • federated_server_cert_path:服务器证书文件路径。仅在 SSL 模式下需要。
  • federated_client_key_path:客户端密钥文件路径。仅在 SSL 模式下需要。
  • federated_client_cert_path:客户端证书文件路径。仅在 SSL 模式下需要。
返回值
成功返回 0,失败返回 -1。

◆ XGCommunicatorIsDistributed()`

int XGCommunicatorIsDistributed ( void  )

获取通信器是否是分布式的。

返回值
如果通信器是分布式的,则返回 True。

◆ XGCommunicatorPrint()`

int XGCommunicatorPrint ( char const *  message)

向 tracker 打印消息。

此函数可用于向监视 tracker 的用户传达进度信息。

参数
message待打印的消息。
返回值
成功返回 0,失败返回 -1。

◆ XGTrackerCreate()`

int XGTrackerCreate ( char const *  config,
TrackerHandle handle 
)

创建一个新的 tracker。

参数
configJSON 编码的参数。
  • dmlc_communicator:字符串,要创建的 tracker 类型。可用选项有 rabitfederated。更多信息请参见 TrackerHandle
  • n_workers:整数,worker 的数量。
  • port:(可选)整数,此 tracker 应监听的端口。
  • timeout:(可选)整数,各种网络操作的超时时间(秒)。默认值为 300 秒。

一些配置项是 rabit 特有的

  • host:(可选)字符串,由 rabit tracker 使用,用于指定主机地址。当通信器无法可靠地获取主机地址时,这会很有用。
  • sortby:(可选)整数。
    • 0:按主机名对 worker 进行排序。
    • 1:按任务 ID 对 worker 进行排序。

一些 federated 特有的配置项

  • federated_secure:布尔值,表示是否为安全服务器。测试时设置为 False。
  • server_key_path:服务器密钥路径。仅在为安全服务器时使用。
  • server_cert_path:服务器证书路径。仅在为安全服务器时使用。
  • client_cert_path:客户端证书路径。仅在为安全服务器时使用。
参数
handle创建的 tracker 的句柄。
返回值
成功返回 0,失败返回 -1。

◆ XGTrackerFree()`

int XGTrackerFree ( TrackerHandle  handle)

释放 tracker 实例。这应在调用 XGTrackerWaitFor() 后调用。如果未正确等待 tracker,此函数将关闭与 tracker 的所有连接,可能导致未定义的行为。

参数
handletracker 的句柄。
返回值
成功返回 0,失败返回 -1。

◆ XGTrackerRun()`

int XGTrackerRun ( TrackerHandle  handle,
char const *  config 
)

启动 tracker。tracker 在后台运行,一旦 tracker 启动,此函数立即返回。

参数
handletracker 的句柄。
config目前未使用,保留供将来使用。
返回值
成功返回 0,失败返回 -1。

◆ XGTrackerWaitFor()`

int XGTrackerWaitFor ( TrackerHandle  handle,
char const *  config 
)

等待 tracker 完成。应在调用 XGTrackerRun() 后调用此函数。此函数将阻塞,直到 tracker 任务完成或达到超时。

参数
handletracker 的句柄。
configJSON 编码的配置。目前不需要参数,保留供将来使用。
返回值
成功返回 0,失败返回 -1。

◆ XGTrackerWorkerArgs()`

int XGTrackerWorkerArgs ( TrackerHandle  handle,
char const **  args 
)

获取运行 worker 所需的参数。这应在调用 XGTrackerRun() 后调用。

参数
handletracker 的句柄。
args作为 JSON 文档返回的参数。
返回值
成功返回 0,失败返回 -1。