xgboost
|
暴露 XGBoost 内部通信器的实验性支持。 更多...
类型定义 | |
typedef void * | TrackerHandle |
Tracker 的句柄。 更多... | |
函数 | |
int | XGTrackerCreate (char const *config, TrackerHandle *handle) |
创建一个新的 tracker。 更多... | |
int | XGTrackerWorkerArgs (TrackerHandle handle, char const **args) |
获取运行 worker 所需的参数。这应在调用 XGTrackerRun() 后调用。 更多... | |
int | XGTrackerRun (TrackerHandle handle, char const *config) |
启动 tracker。tracker 在后台运行,一旦 tracker 启动,此函数立即返回。 更多... | |
int | XGTrackerWaitFor (TrackerHandle handle, char const *config) |
等待 tracker 完成。应在调用 XGTrackerRun() 后调用此函数。此函数将阻塞,直到 tracker 任务完成或达到超时。 更多... | |
int | XGTrackerFree (TrackerHandle handle) |
释放 tracker 实例。这应在调用 XGTrackerWaitFor() 后调用。如果未正确等待 tracker,此函数将关闭与 tracker 的所有连接,可能导致未定义的行为。 更多... | |
int | XGCommunicatorInit (char const *config) |
初始化集合通信器。 更多... | |
int | XGCommunicatorFinalize (void) |
终止集合通信器。 更多... | |
int | XGCommunicatorGetRank (void) |
获取当前进程的 rank。 更多... | |
int | XGCommunicatorGetWorldSize (void) |
获取进程总数(world size)。 更多... | |
int | XGCommunicatorIsDistributed (void) |
获取通信器是否是分布式的。 更多... | |
int | XGCommunicatorPrint (char const *message) |
向 tracker 打印消息。 更多... | |
int | XGCommunicatorGetProcessorName (const char **name_str) |
获取处理器的名称。 更多... | |
int | XGCommunicatorBroadcast (void *send_receive_buffer, size_t size, int root) |
从 root 进程向所有其他进程广播内存区域。此函数不是线程安全的。 更多... | |
int | XGCommunicatorAllreduce (void *send_receive_buffer, size_t count, int data_type, int op) |
执行 in-place allreduce 操作。此函数不是线程安全的。 更多... | |
暴露 XGBoost 内部通信器的实验性支持。
XGBoost 中的集合通信器从 dmlc 的 rabit
项目演变而来,但自从采用后已发生显著变化。它由一个 tracker 和一组 worker 组成。tracker 负责启动通信组并处理日志等集中式任务。worker 是执行 allreduce 等集合任务的实际通信器。
要使用集合实现,首先需要使用相应的参数创建一个 tracker,然后使用 XGTrackerWorkerArgs() 获取 worker 的参数。然后可以将获得的参数传递给 XGCommunicatorInit() 函数。调用 XGCommunicatorInit() 必须伴随对 XGCommunicatorFinalize() 的调用以进行清理。请注意,通信器使用 C++ 中的 std::thread
,由于运行时关闭序列,在 C++ 析构函数中可能存在未定义的行为。最好在运行时关闭之前调用 XGCommunicatorFinalize()。此要求类似于 Python 线程或 socket,不应依赖于 __del__
函数。
由于它是 XGBoost 的一部分,当调用 XGBoost 函数时会返回错误,例如,训练 booster 可能会返回连接错误。
typedef void* TrackerHandle |
Tracker 的句柄。
目前 XGBoost 中有两种类型的 tracker,第一种是 rabit
,另一种是 federated
。rabit
用于常规集合通信,而 federated
用于联邦学习。
int XGCommunicatorAllreduce | ( | void * | send_receive_buffer, |
size_t | count, | ||
int | data_type, | ||
int | op | ||
) |
执行 in-place allreduce 操作。此函数不是线程安全的。
用法示例:以下代码计算结果的和
send_receive_buffer | 用于发送和接收数据的缓冲区。 |
count | 待归约元素的数量。 |
data_type | 数据类型枚举,参见 communicator.h 中的 xgboost::collective::DataType。 |
op | 操作类型枚举,参见 communicator.h 中的 xgboost::collective::Operation。 |
int XGCommunicatorBroadcast | ( | void * | send_receive_buffer, |
size_t | size, | ||
int | root | ||
) |
从 root 进程向所有其他进程广播内存区域。此函数不是线程安全的。
示例
send_receive_buffer | 指向发送或接收缓冲区的指针。 |
size | 数据大小(字节)。 |
root | 进行广播的进程 rank。 |
int XGCommunicatorFinalize | ( | void | ) |
终止集合通信器。
完成所有任务后调用此函数。
int XGCommunicatorGetProcessorName | ( | const char ** | name_str | ) |
获取处理器的名称。
name_str | 指向接收返回的处理器名称的指针。 |
int XGCommunicatorGetRank | ( | void | ) |
获取当前进程的 rank。
int XGCommunicatorGetWorldSize | ( | void | ) |
获取进程总数。
int XGCommunicatorInit | ( | char const * | config | ) |
初始化集合通信器。
当前通信器 API 处于实验阶段,函数签名将来可能会更改,恕不另行通知。
在使用任何功能之前,在 worker 进程中调用此函数一次。请确保在使用后调用 XGCommunicatorFinalize()。初始化的通信器是全局线程局部变量。
config | JSON 编码的配置。可接受的 JSON 键包括:
|
仅适用于 rabit
通信器
libnccl.so
的路径。仅适用于 federated
通信器(环境变量使用大写,运行时配置使用小写)
int XGCommunicatorIsDistributed | ( | void | ) |
获取通信器是否是分布式的。
int XGCommunicatorPrint | ( | char const * | message | ) |
向 tracker 打印消息。
此函数可用于向监视 tracker 的用户传达进度信息。
message | 待打印的消息。 |
int XGTrackerCreate | ( | char const * | config, |
TrackerHandle * | handle | ||
) |
创建一个新的 tracker。
config | JSON 编码的参数。 |
rabit
和 federated
。更多信息请参见 TrackerHandle。一些配置项是 rabit
特有的
rabit
tracker 使用,用于指定主机地址。当通信器无法可靠地获取主机地址时,这会很有用。一些 federated
特有的配置项
handle | 创建的 tracker 的句柄。 |
int XGTrackerFree | ( | TrackerHandle | handle | ) |
释放 tracker 实例。这应在调用 XGTrackerWaitFor() 后调用。如果未正确等待 tracker,此函数将关闭与 tracker 的所有连接,可能导致未定义的行为。
handle | tracker 的句柄。 |
int XGTrackerRun | ( | TrackerHandle | handle, |
char const * | config | ||
) |
启动 tracker。tracker 在后台运行,一旦 tracker 启动,此函数立即返回。
handle | tracker 的句柄。 |
config | 目前未使用,保留供将来使用。 |
int XGTrackerWaitFor | ( | TrackerHandle | handle, |
char const * | config | ||
) |
等待 tracker 完成。应在调用 XGTrackerRun() 后调用此函数。此函数将阻塞,直到 tracker 任务完成或达到超时。
handle | tracker 的句柄。 |
config | JSON 编码的配置。目前不需要参数,保留供将来使用。 |
int XGTrackerWorkerArgs | ( | TrackerHandle | handle, |
char const ** | args | ||
) |
获取运行 worker 所需的参数。这应在调用 XGTrackerRun() 后调用。
handle | tracker 的句柄。 |
args | 作为 JSON 文档返回的参数。 |