From 9066b74f58ad7163dfc0ad8ef912cc50264997d1 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 15 Mar 2021 10:54:51 +0800 Subject: [PATCH] c_gen_nccl_id add SocketServer to persit server (#31589) --- .../operators/collective/c_gen_nccl_id_op.cc | 3 ++- paddle/fluid/platform/gen_comm_id_helper.cc | 18 +++++++++++++++++ paddle/fluid/platform/gen_comm_id_helper.h | 20 +++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index 485a6d7ec4e..1592d809f91 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -75,7 +75,8 @@ class CGenNCCLIdOp : public framework::OperatorBase { platform::SendBroadCastCommID(endpoint_list, &nccl_ids); } else { std::string endpoint = Attr("endpoint"); - platform::RecvBroadCastCommID(endpoint, &nccl_ids); + int server_fd = platform::SocketServer::GetInstance(endpoint).socket(); + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids); } CopyNCCLIDToVar(nccl_ids, func, scope); diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index ffe82371b18..f38603e80fb 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -36,6 +36,8 @@ limitations under the License. */ namespace paddle { namespace platform { +std::once_flag SocketServer::init_flag_; + constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; // Check system calls, such as socket, bind. @@ -330,6 +332,22 @@ void RecvBroadCastCommID(int server_fd, std::string endpoint, CloseSocket(client); } +SocketServer& SocketServer::GetInstance(const std::string& end_point) { + static SocketServer instance; + std::call_once(init_flag_, [&]() { + instance.server_fd_ = CreateListenSocket(end_point); + instance.end_point_ = end_point; + }); + PADDLE_ENFORCE_NE(instance.server_fd_, -1, + platform::errors::Unavailable( + "listen socket failed with end_point=%s", end_point)); + PADDLE_ENFORCE_EQ(instance.end_point_, end_point, + platform::errors::InvalidArgument( + "old end_point=%s must equal with new end_point=%s", + instance.end_point_, end_point)); + return instance; +} + /// template instantiation #define INSTANT_TEMPLATE(Type) \ template void SendBroadCastCommID(std::vector servers, \ diff --git a/paddle/fluid/platform/gen_comm_id_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h index 6014a2b4ff9..c51c5ac6c8a 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.h +++ b/paddle/fluid/platform/gen_comm_id_helper.h @@ -17,6 +17,8 @@ limitations under the License. */ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ defined(PADDLE_WITH_XPU_BKCL) #include +#include +#include #include #include @@ -39,6 +41,24 @@ void RecvBroadCastCommID(std::string endpoint, template void RecvBroadCastCommID(int server_fd, std::string endpoint, std::vector* nccl_ids); + +class SocketServer { + public: + SocketServer() = default; + + ~SocketServer() { CloseSocket(server_fd_); } + + int socket() const { return server_fd_; } + + static SocketServer& GetInstance(const std::string& end_point); + + private: + int server_fd_{-1}; + std::string end_point_; + + static std::once_flag init_flag_; +}; + } // namespace platform } // namespace paddle -- GitLab