Unverified commit 41bfec8d, authored by WangXi, committed by GitHub

[HybridParallel] fix port reuse when create multi group (#31876)

Parent: 8fec3c6d
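The change, in short: the old RecvBroadCastCommID was given only the trainer endpoint, so each broadcast presumably bound its own listening socket on that endpoint, and, as the commit title says, this broke down when several communication groups are created against the same port. The diff now obtains one long-lived listening socket per endpoint, from the platform::SocketServer singleton in library code and from platform::CreateListenSocket in the unit test, and threads its file descriptor through BcastNCCLId into RecvBroadCastCommID. The sketch below illustrates that per-endpoint singleton pattern only; it is not Paddle's SocketServer, and the class name EndpointListener, the ip:port parsing, and the demo endpoint are assumptions made for the example.

// Minimal sketch of a per-endpoint singleton listener, assuming an "ip:port"
// endpoint string. Illustrative only; PaddlePaddle's platform::SocketServer
// is the real implementation.
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cstdio>
#include <cstdlib>
#include <map>
#include <mutex>
#include <string>

class EndpointListener {
 public:
  // Returns the same listening fd for repeated calls with the same endpoint,
  // so creating a second communication group does not re-bind the port.
  static int GetFd(const std::string& endpoint) {
    static std::mutex mu;
    static std::map<std::string, int> fds;
    std::lock_guard<std::mutex> lock(mu);
    auto it = fds.find(endpoint);
    if (it != fds.end()) return it->second;
    int fd = Listen(endpoint);
    fds[endpoint] = fd;
    return fd;
  }

 private:
  static int Listen(const std::string& endpoint) {
    size_t colon = endpoint.find(':');
    std::string ip = endpoint.substr(0, colon);
    int port = std::stoi(endpoint.substr(colon + 1));

    int fd = socket(AF_INET, SOCK_STREAM, 0);
    int opt = 1;
    setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

    sockaddr_in addr{};
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    inet_pton(AF_INET, ip.c_str(), &addr.sin_addr);

    if (bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) != 0 ||
        listen(fd, 128) != 0) {
      std::perror("bind/listen");
      std::exit(1);
    }
    return fd;
  }
};

int main() {
  // Both "groups" get the same fd; the port is bound exactly once.
  int fd1 = EndpointListener::GetFd("127.0.0.1:6175");  // arbitrary demo endpoint
  int fd2 = EndpointListener::GetFd("127.0.0.1:6175");
  std::printf("fd1=%d fd2=%d same=%d\n", fd1, fd2, fd1 == fd2);
  close(fd1);
  return 0;
}

Repeated GetFd calls with the same endpoint return the cached descriptor, which is exactly the property the commit relies on when a second ring or group is initialized.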
@@ -35,7 +35,7 @@ namespace imperative {
 void NCCLParallelContext::BcastNCCLId(
     std::vector<ncclUniqueId> &nccl_ids,  // NOLINT
-    int root) {
+    int root, int server_fd) {
   if (strategy_.local_rank_ == root) {
     std::vector<std::string> other_trainers;
     for (auto &ep : strategy_.trainer_endpoints_) {
@@ -45,11 +45,14 @@ void NCCLParallelContext::BcastNCCLId(
     }
     platform::SendBroadCastCommID(other_trainers, &nccl_ids);
   } else {
-    platform::RecvBroadCastCommID(strategy_.current_endpoint_, &nccl_ids);
+    platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
+                                  &nccl_ids);
   }
 }

 void NCCLParallelContext::Init() {
+  int server_fd = -1;
+
   std::vector<ncclUniqueId> nccl_ids;
   nccl_ids.resize(strategy_.nrings_);
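With the new signature, the receive path no longer owns the endpoint: BcastNCCLId is handed a server_fd that was bound and put into listening state elsewhere, and the receiving side only needs to accept on it. Below is a minimal POSIX sketch of that split; RecvBlob, ReadFull, and the 128-byte payload standing in for ncclUniqueId are hypothetical names used for illustration, not Paddle APIs.

// Sketch: receive a fixed-size blob over a listening socket that the caller
// created once up front (the new shape of the receive path), instead of
// binding the endpoint inside every call (the old shape).
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cassert>
#include <thread>

static void ReadFull(int fd, char* buf, size_t len) {
  size_t got = 0;
  while (got < len) {
    ssize_t n = read(fd, buf + got, len - got);
    assert(n > 0);
    got += static_cast<size_t>(n);
  }
}

// New-style receiver: only accept + read; no socket()/bind()/listen() here,
// so calling it once per communication group cannot re-bind the port.
static void RecvBlob(int server_fd, char* buf, size_t len) {
  int conn = accept(server_fd, nullptr, nullptr);
  ReadFull(conn, buf, len);
  close(conn);
}

int main() {
  // One listening socket for the whole process (stands in for SocketServer).
  int server_fd = socket(AF_INET, SOCK_STREAM, 0);
  int opt = 1;
  setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  addr.sin_port = 0;  // let the kernel pick a free port for the demo
  bind(server_fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
  listen(server_fd, 8);
  socklen_t alen = sizeof(addr);
  getsockname(server_fd, reinterpret_cast<sockaddr*>(&addr), &alen);

  // "Root" side: connect and send 128 bytes (think ncclUniqueId).
  std::thread sender([addr] {
    int c = socket(AF_INET, SOCK_STREAM, 0);
    connect(c, reinterpret_cast<const sockaddr*>(&addr), sizeof(addr));
    char id[128] = {42};
    write(c, id, sizeof(id));
    close(c);
  });

  char recv_id[128] = {0};
  RecvBlob(server_fd, recv_id, sizeof(recv_id));  // reuse the same server_fd
  sender.join();
  assert(recv_id[0] == 42);
  close(server_fd);
  return 0;
}

Because RecvBlob never calls bind(), invoking it once per group against the same server_fd cannot collide on the port.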
@@ -58,8 +61,13 @@ void NCCLParallelContext::Init() {
     for (size_t i = 0; i < nccl_ids.size(); ++i) {
       platform::dynload::ncclGetUniqueId(&nccl_ids[i]);
     }
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
   }
-  BcastNCCLId(nccl_ids, 0);
+  BcastNCCLId(nccl_ids, 0, server_fd);

   int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
   for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
@@ -80,14 +88,20 @@ void NCCLParallelContext::Init() {
 }

 void NCCLParallelContext::InitWithRingID(int ring_id) {
+  int server_fd = -1;
   std::vector<ncclUniqueId> nccl_ids;
   nccl_ids.resize(1);

   if (strategy_.local_rank_ == 0) {
     // generate the unique ncclid on the root worker
     platform::dynload::ncclGetUniqueId(&nccl_ids[0]);
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
   }
-  BcastNCCLId(nccl_ids, 0);
+  BcastNCCLId(nccl_ids, 0, server_fd);

   int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
   VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
......
@@ -49,7 +49,8 @@ class NCCLParallelContext : public ParallelContext {
   ~NCCLParallelContext() override = default;

-  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root);  // NOLINT
+  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root,  // NOLINT
+                   int server_fd);

   void Init() override;
......
@@ -15,6 +15,7 @@
 #include <thread>  // NOLINT

 #include "paddle/fluid/imperative/nccl_context.h"
+#include "paddle/fluid/platform/gen_comm_id_helper.h"

 #include "gtest/gtest.h"
@@ -36,9 +37,13 @@ imperative::ParallelStrategy GetStrategy(int local_rank) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids) {
   auto strategy = GetStrategy(local_rank);
+  int server_fd = platform::CreateListenSocket(strategy.current_endpoint_);
+
   platform::CUDAPlace gpu(local_rank);
   imperative::NCCLParallelContext ctx(strategy, gpu);
-  ctx.BcastNCCLId(*nccl_ids, 0);
+  ctx.BcastNCCLId(*nccl_ids, 0, server_fd);
+
+  platform::CloseSocket(server_fd);
 }

 TEST(BcastNCCLId, Run) {
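The updated test helper creates the listening socket itself with platform::CreateListenSocket, hands the fd to ctx.BcastNCCLId, and releases it with platform::CloseSocket. The sketch below guesses at how such a helper is driven from a single-process gtest, one thread per simulated rank; it assumes it is appended to the same test file shown above (so that file's includes and namespace aliases apply), and the test name and nrings value are invented. It is not the actual body of TEST(BcastNCCLId, Run).

// Hypothetical single-process driver for the helper above: rank 0 generates
// and broadcasts the ids, rank 1 receives them, then the two sets are compared.
#include <cstring>
#include <thread>  // NOLINT
#include <vector>

#include "gtest/gtest.h"

// Helper defined earlier in this file (declaration repeated for clarity).
void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids);

TEST(BcastNCCLIdSketch, TwoRanks) {
  const int nrings = 2;  // more than one ring, i.e. the multi-group case
  std::vector<ncclUniqueId> root_ids(nrings);
  std::vector<ncclUniqueId> recv_ids(nrings);
  for (auto& id : root_ids) {
    platform::dynload::ncclGetUniqueId(&id);
  }

  // One thread per rank; each rank listens on its own endpoint via the
  // CreateListenSocket call inside the helper.
  std::thread rank0(BcastNCCLId, 0, &root_ids);
  std::thread rank1(BcastNCCLId, 1, &recv_ids);
  rank0.join();
  rank1.join();

  for (int i = 0; i < nrings; ++i) {
    EXPECT_EQ(0, std::memcmp(&root_ids[i], &recv_ids[i], sizeof(ncclUniqueId)));
  }
}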
......
@@ -66,6 +66,9 @@ class CGenNCCLIdOp : public framework::OperatorBase {
       return Output("Out");
     };

+    std::string endpoint = Attr<std::string>("endpoint");
+    int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
+
     std::vector<ncclUniqueId> nccl_ids;
     nccl_ids.resize(1);
@@ -75,8 +78,6 @@ class CGenNCCLIdOp : public framework::OperatorBase {
           Attr<std::vector<std::string>>("other_endpoints");
       platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
     } else {
-      std::string endpoint = Attr<std::string>("endpoint");
-      int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
       platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
     }
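On the operator side, the endpoint attribute and the singleton's socket fd are now read before the rank branch, so every instantiation of c_gen_nccl_id (one per ring or group) appears to go through the same platform::SocketServer and the endpoint is bound at most once per process. The snippet below is a generic, self-contained POSIX illustration of the failure this avoids; it is not Paddle code, and the port number is arbitrary.

// Binding the same endpoint twice in one process fails with EADDRINUSE;
// reusing the first listening fd (what the singleton provides) does not.
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstring>

static int BindAndListen(uint16_t port) {
  int fd = socket(AF_INET, SOCK_STREAM, 0);
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  addr.sin_port = htons(port);
  if (bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) != 0) {
    std::printf("bind(%d) failed: %s\n", port, std::strerror(errno));
    close(fd);
    return -1;
  }
  listen(fd, 8);
  return fd;
}

int main() {
  const uint16_t port = 6175;         // arbitrary demo port
  int first = BindAndListen(port);    // "group 0" creates the server: ok
  int second = BindAndListen(port);   // "group 1" binds again: EADDRINUSE
  std::printf("first=%d second=%d\n", first, second);
  // The fix in this commit: hand "group 1" the first fd instead of re-binding.
  if (first >= 0) close(first);
  return 0;
}

The second BindAndListen fails with "Address already in use"; handing the second group the first descriptor, which is what the singleton does, sidesteps the problem.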
......