From 41bfec8dbb03ccf5a040e2b64fd5564c00532632 Mon Apr 26 00:00:00 2001
From: WangXi
Date: Mon, 26 Apr 2021 11:50:39 +0800
Subject: [PATCH] [HybridParallel] fix port reuse when create multi group
 (#31876)

---
 paddle/fluid/imperative/nccl_context.cc       | 22 +++++++++++++++----
 paddle/fluid/imperative/nccl_context.h        |  3 ++-
 .../imperative/tests/nccl_context_test.cc     |  7 +++++-
 .../operators/collective/c_gen_nccl_id_op.cc  |  5 +++--
 4 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc
index b91fc460781..9f036742f0f 100644
--- a/paddle/fluid/imperative/nccl_context.cc
+++ b/paddle/fluid/imperative/nccl_context.cc
@@ -35,7 +35,7 @@ namespace imperative {
 
 void NCCLParallelContext::BcastNCCLId(
     std::vector<ncclUniqueId> &nccl_ids,  // NOLINT
-    int root) {
+    int root, int server_fd) {
   if (strategy_.local_rank_ == root) {
     std::vector<std::string> other_trainers;
     for (auto &ep : strategy_.trainer_endpoints_) {
@@ -45,11 +45,14 @@ void NCCLParallelContext::BcastNCCLId(
     }
     platform::SendBroadCastCommID(other_trainers, &nccl_ids);
   } else {
-    platform::RecvBroadCastCommID(strategy_.current_endpoint_, &nccl_ids);
+    platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
+                                  &nccl_ids);
   }
 }
 
 void NCCLParallelContext::Init() {
+  int server_fd = -1;
+
   std::vector<ncclUniqueId> nccl_ids;
   nccl_ids.resize(strategy_.nrings_);
 
@@ -58,8 +61,13 @@ void NCCLParallelContext::Init() {
     for (size_t i = 0; i < nccl_ids.size(); ++i) {
       platform::dynload::ncclGetUniqueId(&nccl_ids[i]);
     }
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
   }
-  BcastNCCLId(nccl_ids, 0);
+  BcastNCCLId(nccl_ids, 0, server_fd);
 
   int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
   for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
@@ -80,14 +88,20 @@ void NCCLParallelContext::Init() {
 }
 
 void NCCLParallelContext::InitWithRingID(int ring_id) {
+  int server_fd = -1;
   std::vector<ncclUniqueId> nccl_ids;
   nccl_ids.resize(1);
 
   if (strategy_.local_rank_ == 0) {
     // generate the unique ncclid on the root worker
     platform::dynload::ncclGetUniqueId(&nccl_ids[0]);
+  } else {
+    // FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
+    // on rank0.
+    server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
+                    .socket();
   }
-  BcastNCCLId(nccl_ids, 0);
+  BcastNCCLId(nccl_ids, 0, server_fd);
 
   int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
   VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h
index bcaeb811b10..1eee393aa71 100644
--- a/paddle/fluid/imperative/nccl_context.h
+++ b/paddle/fluid/imperative/nccl_context.h
@@ -49,7 +49,8 @@ class NCCLParallelContext : public ParallelContext {
 
   ~NCCLParallelContext() override = default;
 
-  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root);  // NOLINT
+  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root,  // NOLINT
+                   int server_fd);
 
   void Init() override;
 
diff --git a/paddle/fluid/imperative/tests/nccl_context_test.cc b/paddle/fluid/imperative/tests/nccl_context_test.cc
index 4967df5341d..2d8a08217b0 100644
--- a/paddle/fluid/imperative/tests/nccl_context_test.cc
+++ b/paddle/fluid/imperative/tests/nccl_context_test.cc
@@ -15,6 +15,7 @@
 #include <thread>  // NOLINT
 
 #include "paddle/fluid/imperative/nccl_context.h"
+#include "paddle/fluid/platform/gen_comm_id_helper.h"
 
 #include "gtest/gtest.h"
 
@@ -36,9 +37,13 @@ imperative::ParallelStrategy GetStrategy(int local_rank) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void BcastNCCLId(int local_rank, std::vector<ncclUniqueId>* nccl_ids) {
   auto strategy = GetStrategy(local_rank);
+  int server_fd = platform::CreateListenSocket(strategy.current_endpoint_);
+
   platform::CUDAPlace gpu(local_rank);
   imperative::NCCLParallelContext ctx(strategy, gpu);
-  ctx.BcastNCCLId(*nccl_ids, 0);
+  ctx.BcastNCCLId(*nccl_ids, 0, server_fd);
+
+  platform::CloseSocket(server_fd);
 }
 
 TEST(BcastNCCLId, Run) {
diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
index 7da30f64d1c..470537582e9 100644
--- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
+++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
@@ -66,6 +66,9 @@ class CGenNCCLIdOp : public framework::OperatorBase {
       return Output("Out");
     };
 
+    std::string endpoint = Attr<std::string>("endpoint");
+    int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
+
     std::vector<ncclUniqueId> nccl_ids;
     nccl_ids.resize(1);
 
@@ -75,8 +78,6 @@ class CGenNCCLIdOp : public framework::OperatorBase {
           Attr<std::vector<std::string>>("other_endpoints");
       platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
     } else {
-      std::string endpoint = Attr<std::string>("endpoint");
-      int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
      platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
     }
 
-- 
GitLab
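
Note on the pattern this patch relies on: platform::SocketServer::GetInstance(endpoint)
hands every communication group the same listening socket for a given endpoint, so
initializing a second ring/group no longer tries to bind() a port the first group
already owns (rank 0 skips the server entirely because, per the FIXME in the diff,
gloo already uses the rank-0 endpoint). The sketch below is an illustrative,
self-contained approximation of that one-socket-per-endpoint cache; it is NOT
Paddle's actual gen_comm_id_helper implementation, and the names here
(SocketServerSketch, GetFd, Listen) are hypothetical.

    // Sketch only: a minimal stand-in for the singleton that the patch uses.
    // Assumes POSIX sockets and an "ip:port" endpoint string.
    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <sys/socket.h>
    #include <unistd.h>

    #include <cstring>
    #include <map>
    #include <mutex>
    #include <stdexcept>
    #include <string>

    class SocketServerSketch {
     public:
      // Returns the cached listening fd for `endpoint`, creating and
      // binding it only on first use. Later groups on the same endpoint
      // reuse the fd instead of calling bind() a second time.
      static int GetFd(const std::string& endpoint) {
        static SocketServerSketch instance;  // C++11 thread-safe init
        std::lock_guard<std::mutex> lock(instance.mu_);
        auto it = instance.fds_.find(endpoint);
        if (it != instance.fds_.end()) return it->second;
        int fd = Listen(endpoint);
        instance.fds_[endpoint] = fd;
        return fd;
      }

     private:
      static int Listen(const std::string& endpoint) {
        size_t pos = endpoint.find(':');
        if (pos == std::string::npos)
          throw std::runtime_error("endpoint must be ip:port");
        std::string ip = endpoint.substr(0, pos);
        int port = std::stoi(endpoint.substr(pos + 1));

        int fd = socket(AF_INET, SOCK_STREAM, 0);
        if (fd < 0) throw std::runtime_error("socket() failed");

        // Allow quick rebinding after a previous run left the port in
        // TIME_WAIT; this does not help two live sockets share a port,
        // which is exactly why the fd itself must be cached and reused.
        int opt = 1;
        setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

        sockaddr_in addr;
        std::memset(&addr, 0, sizeof(addr));
        addr.sin_family = AF_INET;
        addr.sin_port = htons(port);
        inet_pton(AF_INET, ip.c_str(), &addr.sin_addr);

        if (bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0 ||
            listen(fd, 128) < 0) {
          close(fd);
          throw std::runtime_error("bind()/listen() failed for " + endpoint);
        }
        return fd;  // kept open for the process lifetime, never closed here
      }

      std::mutex mu_;
      std::map<std::string, int> fds_;  // endpoint -> listening fd
    };

In the spirit of the patch, both Init() and InitWithRingID() would obtain the fd
once per endpoint (GetFd above playing the role of GetInstance(...).socket()) and
pass it down to the receive side of the broadcast, so creating multiple groups or
rings shares one server socket instead of racing to rebind the same port.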