Unverified commit 41bfec8d authored by WangXi, committed by GitHub

[HybridParallel] fix port reuse when create multi group (#31876)

Parent: 8fec3c6d
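
Summary of the change: before this commit, every broadcast of an ncclUniqueId created a fresh listening socket on the trainer's endpoint and tore it down afterwards, so initializing a second communication group on the same endpoint could fail or race while re-binding the port. The fix binds each endpoint once in a process-wide SocketServer singleton and threads the already-bound file descriptor (server_fd) through BcastNCCLId to RecvBroadCastCommID. Below is a minimal sketch of such an endpoint-keyed singleton, assuming plain POSIX sockets; the names SocketServer, GetInstance, and socket() mirror the diff, but the body is illustrative, not Paddle's actual implementation (error handling omitted).

```cpp
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>

// Endpoint-keyed listen-socket singleton (illustrative sketch): the first
// request for an endpoint binds and listens; later requests return the same
// already-bound fd instead of re-binding the port.
class SocketServer {
 public:
  static SocketServer &GetInstance(const std::string &endpoint) {
    static std::unordered_map<std::string, SocketServer> servers;
    static std::mutex mu;
    std::lock_guard<std::mutex> lock(mu);
    auto it = servers.find(endpoint);
    if (it == servers.end()) {
      it = servers.emplace(endpoint, SocketServer(endpoint)).first;
    }
    return it->second;
  }

  int socket() const { return fd_; }

 private:
  explicit SocketServer(const std::string &endpoint) {
    // endpoint is "ip:port"; only the port matters for bind().
    int port = std::stoi(endpoint.substr(endpoint.find(':') + 1));
    fd_ = ::socket(AF_INET, SOCK_STREAM, 0);
    int opt = 1;
    ::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
    sockaddr_in addr{};
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = INADDR_ANY;
    addr.sin_port = htons(static_cast<uint16_t>(port));
    ::bind(fd_, reinterpret_cast<sockaddr *>(&addr), sizeof(addr));
    ::listen(fd_, SOMAXCONN);
  }

  int fd_ = -1;
};
```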
@@ -35,7 +35,7 @@ namespace imperative {
 void NCCLParallelContext::BcastNCCLId(
     std::vector<ncclUniqueId> &nccl_ids,  // NOLINT
-    int root) {
+    int root, int server_fd) {
   if (strategy_.local_rank_ == root) {
     std::vector<std::string> other_trainers;
     for (auto &ep : strategy_.trainer_endpoints_) {
@@ -45,11 +45,14 @@ void NCCLParallelContext::BcastNCCLId(
     }
     platform::SendBroadCastCommID(other_trainers, &nccl_ids);
   } else {
-    platform::RecvBroadCastCommID(strategy_.current_endpoint_, &nccl_ids);
+    platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
+                                  &nccl_ids);
   }
 }
 void NCCLParallelContext::Init() {
+  int server_fd = -1;
+
   std::vector<ncclUniqueId> nccl_ids;
   nccl_ids.resize(strategy_.nrings_);
@@ -58,8 +61,13 @@ void NCCLParallelContext::Init() {
     for (size_t i = 0; i < nccl_ids.size(); ++i) {
       platform::dynload::ncclGetUniqueId(&nccl_ids[i]);
     }
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
   }
-  BcastNCCLId(nccl_ids, 0);
+  BcastNCCLId(nccl_ids, 0, server_fd);

   int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
   for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
@@ -80,14 +88,20 @@ void NCCLParallelContext::Init() {
 }

 void NCCLParallelContext::InitWithRingID(int ring_id) {
+  int server_fd = -1;
+
   std::vector<ncclUniqueId> nccl_ids;
   nccl_ids.resize(1);

   if (strategy_.local_rank_ == 0) {
     // generate the unique ncclid on the root worker
     platform::dynload::ncclGetUniqueId(&nccl_ids[0]);
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
   }
-  BcastNCCLId(nccl_ids, 0);
+  BcastNCCLId(nccl_ids, 0, server_fd);

   int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
   VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
......
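
The FIXME in both hunks flags the one asymmetry: rank 0 skips creating the socket server because gloo already occupies rank 0's endpoint, so server_fd stays -1 there (rank 0 only sends, never receives). For the other ranks, the practical effect is that repeated group initializations resolve to one bound descriptor instead of re-binding the port each time. A hypothetical usage sketch, reusing the SocketServer sketch above (the endpoint value is made up):

```cpp
#include <cassert>
#include <string>

// Two groups initialized on the same trainer endpoint share one listen fd,
// so the second initialization cannot fail with "address already in use".
void InitTwoGroupsOnOneEndpoint() {
  const std::string ep = "127.0.0.1:6175";          // hypothetical endpoint
  int fd1 = SocketServer::GetInstance(ep).socket();  // first call: bind+listen
  int fd2 = SocketServer::GetInstance(ep).socket();  // second call: cached fd
  assert(fd1 == fd2);  // same descriptor, no second bind()
}
```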
@@ -49,7 +49,8 @@ class NCCLParallelContext : public ParallelContext {
   ~NCCLParallelContext() override = default;

-  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root);  // NOLINT
+  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root,  // NOLINT
+                   int server_fd);

   void Init() override;
......
@@ -15,6 +15,7 @@
 #include <thread>  // NOLINT

 #include "paddle/fluid/imperative/nccl_context.h"
+#include "paddle/fluid/platform/gen_comm_id_helper.h"

 #include "gtest/gtest.h"
@@ -36,9 +37,13 @@ imperative::ParallelStrategy GetStrategy(int local_rank) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids) {
   auto strategy = GetStrategy(local_rank);
+  int server_fd = platform::CreateListenSocket(strategy.current_endpoint_);
+
   platform::CUDAPlace gpu(local_rank);
   imperative::NCCLParallelContext ctx(strategy, gpu);
-  ctx.BcastNCCLId(*nccl_ids, 0);
+  ctx.BcastNCCLId(*nccl_ids, 0, server_fd);
+
+  platform::CloseSocket(server_fd);
 }

 TEST(BcastNCCLId, Run) {
......
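
Note that the test manages the socket lifetime explicitly with CreateListenSocket/CloseSocket (declared in gen_comm_id_helper.h, hence the new include) rather than going through the singleton, so each test case releases the port when it finishes. A minimal sketch of what such a helper pair could look like over POSIX sockets; the real helpers live in paddle/fluid/platform/gen_comm_id_helper and may differ:

```cpp
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstdint>
#include <string>

// Bind and listen on the port of an "ip:port" endpoint; return the fd.
// Illustrative only; error handling omitted.
int CreateListenSocket(const std::string &endpoint) {
  int port = std::stoi(endpoint.substr(endpoint.find(':') + 1));
  int fd = ::socket(AF_INET, SOCK_STREAM, 0);
  int opt = 1;
  ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = INADDR_ANY;
  addr.sin_port = htons(static_cast<uint16_t>(port));
  ::bind(fd, reinterpret_cast<sockaddr *>(&addr), sizeof(addr));
  ::listen(fd, SOMAXCONN);
  return fd;
}

// Release the port so the next test case can bind it again.
void CloseSocket(int fd) { ::close(fd); }
```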
@@ -66,6 +66,9 @@ class CGenNCCLIdOp : public framework::OperatorBase {
     return Output("Out");
   };

+  std::string endpoint = Attr<std::string>("endpoint");
+  int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
+
   std::vector<ncclUniqueId> nccl_ids;
   nccl_ids.resize(1);
@@ -75,8 +78,6 @@ class CGenNCCLIdOp : public framework::OperatorBase {
       Attr<std::vector<std::string>>("other_endpoints");
     platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
   } else {
-    std::string endpoint = Attr<std::string>("endpoint");
-    int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
     platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
   }
......
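
In CGenNCCLIdOp the SocketServer::GetInstance call moves out of the else branch, so the bound descriptor is obtained unconditionally before the rank check, and the receive side then only accept()s on it. A minimal sketch of that receive contract, assuming POSIX sockets; RecvBroadCastCommID's real implementation is in gen_comm_id_helper and will differ:

```cpp
#include <sys/socket.h>
#include <unistd.h>

#include <cstddef>
#include <vector>

// Stand-in for ncclUniqueId so the sketch is self-contained.
struct UniqueIdSketch {
  char internal[128];
};

// Receive broadcast IDs over a connection accepted on a pre-bound listen fd.
// Illustrates the RecvBroadCastCommID(server_fd, ...) contract only.
void RecvIdsOnBoundFd(int server_fd, std::vector<UniqueIdSketch> *ids) {
  // No bind()/listen() here: server_fd is already bound, which is exactly
  // what lets back-to-back group creations reuse the port safely.
  int conn = ::accept(server_fd, nullptr, nullptr);
  if (conn < 0) return;  // a real implementation would report the error
  for (auto &id : *ids) {
    size_t got = 0;
    while (got < sizeof(id.internal)) {
      ssize_t n = ::recv(conn, id.internal + got, sizeof(id.internal) - got, 0);
      if (n <= 0) break;  // peer closed or error
      got += static_cast<size_t>(n);
    }
  }
  ::close(conn);  // close the accepted connection; keep server_fd open
}
```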