xgboost
类型定义 | 函数
集体

XGBoost 中内部通信器的实验性支持。更多...

类型定义

typedef void * TrackerHandle
 追踪器句柄。更多...
 

函数

int XGTrackerCreate (char const *config, TrackerHandle *handle)
 创建一个新的追踪器。更多...
 
int XGTrackerWorkerArgs (TrackerHandle handle, char const **args)
 获取运行工作器所需的参数。这应该在 XGTrackerRun() 之后调用。更多...
 
int XGTrackerRun (TrackerHandle handle, char const *config)
 启动追踪器。追踪器在后台运行,此函数在追踪器启动后返回。更多...
 
int XGTrackerWaitFor (TrackerHandle handle, char const *config)
 等待追踪器完成,应在 XGTrackerRun() 之后调用。此函数将阻塞,直到追踪器任务完成或达到超时。更多...
 
int XGTrackerFree (TrackerHandle handle)
 释放追踪器实例。这应该在 XGTrackerWaitFor() 之后调用。如果追踪器没有正确等待,此函数将关闭与追踪器的所有连接,可能导致未定义行为。更多...
 
int XGCommunicatorInit (char const *config)
 初始化集体通信器。更多...
 
int XGCommunicatorFinalize (void)
 终结集体通信器。更多...
 
int XGCommunicatorGetRank (void)
 获取当前进程的排名。更多...
 
int XGCommunicatorGetWorldSize (void)
 获取进程总数。更多...
 
int XGCommunicatorIsDistributed (void)
 获取通信器是否是分布式的。更多...
 
int XGCommunicatorPrint (char const *message)
 将消息打印到追踪器。更多...
 
int XGCommunicatorGetProcessorName (const char **name_str)
 获取处理器名称。更多...
 
int XGCommunicatorBroadcast (void *send_receive_buffer, size_t size, int root)
 从根节点向所有其他节点广播内存区域。此函数不是线程安全的。更多...
 
int XGCommunicatorAllreduce (void *send_receive_buffer, size_t count, int data_type, int op)
 执行原地 Allreduce。此函数不是线程安全的。更多...
 

详细描述

XGBoost 中内部通信器的实验性支持。

注意
这仍在开发中。

XGBoost 中的集体通信器从 dmlc 的 `rabit` 项目演变而来,但自采用以来已发生显著变化。它由一个追踪器和一组工作器组成。追踪器负责引导通信组并处理集中式任务,如日志记录。工作器是执行 Allreduce 等集体任务的实际通信器。

要使用集体实现,首先需要使用相应的参数创建一个追踪器,然后使用 XGTrackerWorkerArgs() 获取工作器的参数。然后可以将获得的参数传递给 XGCommunicatorInit() 函数。调用 XGCommunicatorInit() 必须伴随 XGCommunicatorFinalize() 调用进行清理。请注意,通信器在 C++ 中使用 `std::thread`,由于运行时关闭序列,在 C++ 析构函数中存在未定义行为。最好在运行时关闭之前调用 XGCommunicatorFinalize()。此要求类似于 Python 线程或套接字,不应在 `__del__` 函数中依赖。

由于它是 XGBoost 的一部分,当调用 XGBoost 函数时会返回错误,例如,训练一个 Booster 可能会返回连接错误。

类型定义文档

◆ TrackerHandle

typedef void* TrackerHandle

追踪器的句柄。

XGBoost 中目前有两种类型的追踪器,第一种是 `rabit`,另一种是 `federated`。`rabit` 用于普通集体通信,而 `federated` 用于联邦学习。

函数文档

◆ XGCommunicatorAllreduce()

int XGCommunicatorAllreduce ( void *  send_receive_buffer,
size_t  count,
int  data_type,
int  op 
)

执行就地 allreduce。此函数不是线程安全的。

示例用法:以下代码给出结果的和

enum class Op {
kMax = 0, kMin = 1, kSum = 2, kBitwiseAND = 3, kBitwiseOR = 4, kBitwiseXOR = 5
};
std::vector<int> data(10);
...
Allreduce(data.data(), data.size(), DataType:kInt32, Op::kSum);
...
DataType
xgboost接口接受的数据类型
定义: data.h:33
参数
send_receive_buffer用于发送和接收数据的缓冲区。
count要减少的元素数量。
data_type数据类型枚举,请参见 communicator.h 中的 xgboost::collective::DataType。
op操作类型枚举,请参见 communicator.h 中的 xgboost::collective::Operation。
返回
成功为 0,失败为 -1。

◆ XGCommunicatorBroadcast()

int XGCommunicatorBroadcast ( void *  send_receive_buffer,
size_t  size,
int  root 
)

将内存区域从根节点广播到所有其他节点。此函数不是线程安全的。

示例

int a = 1;
Broadcast(&a, sizeof(a), root);
参数
send_receive_buffer指向发送或接收缓冲区的指针。
size数据大小(以字节为单位)。
root要从中广播的进程排名。
返回
成功为 0,失败为 -1。

◆ XGCommunicatorFinalize()

int XGCommunicatorFinalize ( void  )

结束集体通信器。

完成所有任务后调用此函数。

返回
成功为 0,失败为 -1。

◆ XGCommunicatorGetProcessorName()

int XGCommunicatorGetProcessorName ( const char **  name_str)

获取处理器的名称。

参数
name_str指向接收到的处理器名称的指针。
返回
成功为 0,失败为 -1。

◆ XGCommunicatorGetRank()

int XGCommunicatorGetRank ( void  )

获取当前进程的排名。

返回
工作器的排名。

◆ XGCommunicatorGetWorldSize()

int XGCommunicatorGetWorldSize ( void  )

获取进程总数。

返回
总世界大小。

◆ XGCommunicatorInit()

int XGCommunicatorInit ( char const *  config)

初始化集体通信器。

目前通信器 API 处于实验阶段,函数签名将来可能在不通知的情况下更改。

在使用任何功能之前,在工作进程中调用此函数一次。请确保在使用后调用 XGCommunicatorFinalize()。初始化后的通信器是一个全局线程局部变量。

参数
configJSON 编码的配置。接受的 JSON 键是
  • dmlc_communicator:通信器的类型,这应该与追踪器类型匹配。
    • rabit:使用 Rabit。如果未指定类型,则为默认值。
    • federated:使用 gRPC 接口进行联邦学习。

仅适用于 `rabit` 通信器

  • dmlc_tracker_uri:追踪器的主机名或 IP 地址。
  • dmlc_tracker_port: 追踪器的端口号。
  • dmlc_task_id:当前任务的 ID,可用于获取确定性排名分配。
  • dmlc_retry:连接失败的重试次数。
  • dmlc_timeout: 超时时间(秒)。
  • dmlc_nccl_path:nccl 共享库 `libnccl.so` 的路径。

仅适用于 `federated` 通信器(环境变量使用大写,运行时配置使用小写)

  • federated_server_address: 联邦服务器地址。
  • federated_world_size: 联邦工作器数量。
  • federated_rank: 当前工作器的排名。
  • federated_server_cert_path:服务器证书文件路径。仅在 SSL 模式下需要。
  • federated_client_key_path:客户端密钥文件路径。仅在 SSL 模式下需要。
  • federated_client_cert_path:客户端证书文件路径。仅在 SSL 模式下需要。
返回
成功为 0,失败为 -1。

◆ XGCommunicatorIsDistributed()

int XGCommunicatorIsDistributed ( void  )

获取通信器是否是分布式的。

返回
如果通信器是分布式的,则为 True。

◆ XGCommunicatorPrint()

int XGCommunicatorPrint ( char const *  message)

将消息打印到追踪器。

此函数可用于向监控追踪器的用户传达进度信息。

参数
message要打印的消息。
返回
成功为 0,失败为 -1。

◆ XGTrackerCreate()

int XGTrackerCreate ( char const *  config,
TrackerHandle handle 
)

创建新的追踪器。

参数
configJSON 编码的参数。
  • dmlc_communicator:字符串,要创建的追踪器类型。可用选项为 `rabit` 和 `federated`。有关详细信息,请参见 TrackerHandle
  • n_workers:整数,工作器数量。
  • port:(可选)整数,此追踪器应监听的端口。
  • timeout:(可选)整数,各种网络操作的超时时间(秒)。默认值为 300 秒。

某些配置是 `rabit` 特定的

  • host:(可选)字符串,由 `rabit` 追踪器用于指定主机地址。当通信器无法可靠地获取主机地址时,这可能很有用。
  • sortby:(可选)整数。
    • 0:按主机名对工作器进行排序。
    • 1:按任务 ID 对工作器进行排序。

某些 `federated` 特定配置

  • federated_secure:布尔值,这是否是安全服务器。测试时为 False。
  • server_key_path:服务器密钥的路径。仅当这是安全服务器时使用。
  • server_cert_path:服务器证书的路径。仅当这是安全服务器时使用。
  • client_cert_path:客户端证书的路径。仅当这是安全服务器时使用。
参数
handle已创建追踪器的句柄。
返回
成功为 0,失败为 -1。

◆ XGTrackerFree()

int XGTrackerFree ( TrackerHandle  handle)

释放追踪器实例。这应该在 XGTrackerWaitFor() 之后调用。如果追踪器没有正确等待,此函数将关闭与追踪器的所有连接,可能导致未定义行为。

参数
handle追踪器的句柄。
返回
成功为 0,失败为 -1。

◆ XGTrackerRun()

int XGTrackerRun ( TrackerHandle  handle,
char const *  config 
)

启动追踪器。追踪器在后台运行,此函数在追踪器启动后返回。

参数
handle追踪器的句柄。
config目前未使用,保留以备将来。
返回
成功为 0,失败为 -1。

◆ XGTrackerWaitFor()

int XGTrackerWaitFor ( TrackerHandle  handle,
char const *  config 
)

等待追踪器完成,应在 XGTrackerRun() 之后调用。此函数将阻塞,直到追踪器任务完成或达到超时。

参数
handle追踪器的句柄。
configJSON 编码的配置。目前不需要参数,保留以备将来。
返回
成功为 0,失败为 -1。

◆ XGTrackerWorkerArgs()

int XGTrackerWorkerArgs ( TrackerHandle  handle,
char const **  args 
)

获取运行工作器所需的参数。这应该在 XGTrackerRun() 之后调用。

参数
handle追踪器的句柄。
args作为 JSON 文档返回的参数。
返回
成功为 0,失败为 -1。