From 03e99e26cc6b6590f4102da9c8c0258094b87f51 Mon Sep 17 00:00:00 2001 From: WangXi Date: Sun, 28 Feb 2021 12:23:41 +0000 Subject: [PATCH] update --- .../operators/collective/c_gen_nccl_id_op.cc | 6 ++-- .../collective/gen_nccl_id_op_helper.cc | 35 +++++++++++++++++-- .../collective/gen_nccl_id_op_helper.h | 30 +++++++++++++--- 3 files changed, 61 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index 18f2a40f3dd..9446c38dcba 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase { SendBroadCastNCCLID(endpoint_list, 1, func, local_scope); } else { std::string endpoint = Attr("endpoint"); - RecvBroadCastNCCLID(endpoint, 1, func, local_scope); + int server_fd = platform::SocketServer::GetInstance(endpoint).socket(); + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids); } scope.DeleteScope(&local_scope); } @@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase { : OperatorBase(type, inputs, outputs, attrs) {} void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - } + const platform::Place& dev_place) const override {} }; #endif diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc index a0df244000b..94f471e4456 100644 --- a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc @@ -31,7 +31,9 @@ limitations under the License. */ #include "paddle/fluid/string/split.h" namespace paddle { -namespace operators { +namespace platform { + +std::once_flag SocketServer::init_flag_; constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; @@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, CloseSocket(client); } -} // namespace operators +SocketServer& SocketServer::GetInstance(const std::string& end_point) { + static SocketServer instance; + std::call_once(init_flag_, [&]() { + instance.server_fd_ = CreateListenSocket(end_point); + instance.end_point_ = end_point; + }); + PADDLE_ENFORCE_NE(instance.server_fd_, -1, + platform::errors::Unavailable( + "listen socket failed with end_point=%s", end_point)); + PADDLE_ENFORCE_EQ(instance.end_point_, end_point, + platform::errors::InvalidArgument( + "old end_point=%s must equal with new end_point=%s", + instance.end_point_, end_point)); + return instance; +} + +/// template instantiation +#define INSTANT_TEMPLATE(Type) \ + template void SendBroadCastCommID(std::vector servers, \ + std::vector * nccl_ids); \ + template void RecvBroadCastCommID(std::string endpoint, \ + std::vector * nccl_ids); + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +INSTANT_TEMPLATE(ncclUniqueId) +#endif +#ifdef PADDLE_WITH_XPU_BKCL +INSTANT_TEMPLATE(BKCLUniqueId) +#endif +} // namespace platform } // namespace paddle diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h index 38751805191..8db9bcee4da 100644 --- a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include +#include +#include // NOLINT #include #include @@ -25,7 +27,7 @@ class Scope; } // namespace paddle namespace paddle { -namespace operators { +namespace platform { int CreateListenSocket(const std::string& ep); @@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, const framework::Scope& scope); // recv nccl id from socket -void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, - std::function func, - const framework::Scope& scope); -} // namespace operators +template +void RecvBroadCastCommID(int server_fd, std::string endpoint, + std::vector* nccl_ids); + +class SocketServer { + public: + SocketServer() = default; + + ~SocketServer() { CloseSocket(server_fd_); } + + int socket() const { return server_fd_; } + + static SocketServer& GetInstance(const std::string& end_point); + + private: + int server_fd_{-1}; + std::string end_point_; + + static std::once_flag init_flag_; +}; + +} // namespace platform } // namespace paddle -- GitLab