From 1edf4374c4e933438b957d91c189c2e854a606d5 Mon Sep 17 00:00:00 2001 From: LiuWei Date: Tue, 13 Jul 2021 19:41:50 +0800 Subject: [PATCH] change hccl_helper as commid helper (#34118) --- .../operators/collective/c_gen_hccl_id_op.cc | 40 ++++++++++++++++--- paddle/fluid/platform/gen_comm_id_helper.cc | 22 ++++++++-- paddle/fluid/platform/gen_comm_id_helper.h | 2 +- 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc index 593eaf923a9..af1e576a8c7 100644 --- a/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc @@ -23,15 +23,35 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -#ifdef PADDLE_WITH_ASCEND_CL -#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" -#endif +#include "paddle/fluid/platform/dynload/hccl.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" namespace paddle { namespace operators { #ifdef PADDLE_WITH_ASCEND_CL +static void GenHCCLID(std::vector* hccl_ids) { + for (size_t i = 0; i < hccl_ids->size(); ++i) { + PADDLE_ENFORCE_NPU_SUCCESS( + platform::dynload::HcclGetRootInfo(&(*hccl_ids)[i])); + } +} + +static void CopyHCCLIDToVar(const std::vector& hccl_ids, + std::function func, + const framework::Scope& scope) { + for (size_t i = 0; i < hccl_ids.size(); ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto hccl_id = var->GetMutable(); + memcpy(hccl_id, &hccl_ids[i], sizeof(HcclRootInfo)); + } +} + class CGenHCCLIdOp : public framework::OperatorBase { public: CGenHCCLIdOp(const std::string& type, @@ -49,14 +69,22 @@ class CGenHCCLIdOp : public framework::OperatorBase { return Output("Out"); }; + std::string endpoint = Attr("endpoint"); + int server_fd = platform::SocketServer::GetInstance(endpoint).socket(); + + std::vector hccl_ids; + hccl_ids.resize(1); + if (rank == 0) { + GenHCCLID(&hccl_ids); std::vector endpoint_list = Attr>("other_endpoints"); - SendBroadCastHCCLID(endpoint_list, 1, func, local_scope); + platform::SendBroadCastCommID(endpoint_list, &hccl_ids); } else { - std::string endpoint = Attr("endpoint"); - RecvBroadCastHCCLID(endpoint, 1, func, local_scope); + platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids); } + + CopyHCCLIDToVar(hccl_ids, func, scope); scope.DeleteScope(&local_scope); } }; diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index f38603e80fb..5f6dd5679a1 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) #include "paddle/fluid/platform/gen_comm_id_helper.h" #include @@ -33,6 +33,10 @@ limitations under the License. */ #include "xpu/bkcl.h" #endif +#if defined(PADDLE_WITH_ASCEND_CL) +#include "paddle/fluid/platform/collective_helper.h" +#endif + namespace paddle { namespace platform { @@ -262,10 +266,17 @@ static int ConnectAddr(const std::string& ep, const char* head) { return sock; } +// TODO(WANGXI): maybe need to unify this hard code +#ifdef PADDLE_WITH_ASCEND_CL +#define MAX_COMMUNIQUEID_LEN 4108 +#else +#define MAX_COMMUNIQUEID_LEN 1024 +#endif + template static void RecvCommID(int conn, CommUniqueId* nccl_id) { - char buffer[1024] = {0}; - static_assert(sizeof(CommUniqueId) <= 1024, + char buffer[MAX_COMMUNIQUEID_LEN] = {0}; + static_assert(sizeof(CommUniqueId) <= MAX_COMMUNIQUEID_LEN, "nccl id bytes must <= buffer size"); CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(CommUniqueId)), @@ -275,7 +286,7 @@ static void RecvCommID(int conn, CommUniqueId* nccl_id) { template static void SendCommID(int conn, CommUniqueId* nccl_id) { - char buffer[1024] = {0}; + char buffer[MAX_COMMUNIQUEID_LEN] = {0}; memcpy(buffer, nccl_id, sizeof(CommUniqueId)); CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(CommUniqueId)), @@ -361,6 +372,9 @@ INSTANT_TEMPLATE(ncclUniqueId) #ifdef PADDLE_WITH_XPU_BKCL INSTANT_TEMPLATE(BKCLUniqueId) #endif +#ifdef PADDLE_WITH_ASCEND_CL +INSTANT_TEMPLATE(HcclRootInfo) +#endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/gen_comm_id_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h index c51c5ac6c8a..fb5d8d8fcd9 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.h +++ b/paddle/fluid/platform/gen_comm_id_helper.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) #include #include #include -- GitLab