diff --git a/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc index 593eaf923a978402cc7607bb7d2bc4a6419dd2cb..af1e576a8c74f509822a1f227976c6a2ad803d82 100644 --- a/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_hccl_id_op.cc @@ -23,15 +23,35 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -#ifdef PADDLE_WITH_ASCEND_CL -#include "paddle/fluid/operators/collective/gen_hccl_id_op_helper.h" -#endif +#include "paddle/fluid/platform/dynload/hccl.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" namespace paddle { namespace operators { #ifdef PADDLE_WITH_ASCEND_CL +static void GenHCCLID(std::vector* hccl_ids) { + for (size_t i = 0; i < hccl_ids->size(); ++i) { + PADDLE_ENFORCE_NPU_SUCCESS( + platform::dynload::HcclGetRootInfo(&(*hccl_ids)[i])); + } +} + +static void CopyHCCLIDToVar(const std::vector& hccl_ids, + std::function func, + const framework::Scope& scope) { + for (size_t i = 0; i < hccl_ids.size(); ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto hccl_id = var->GetMutable(); + memcpy(hccl_id, &hccl_ids[i], sizeof(HcclRootInfo)); + } +} + class CGenHCCLIdOp : public framework::OperatorBase { public: CGenHCCLIdOp(const std::string& type, @@ -49,14 +69,22 @@ class CGenHCCLIdOp : public framework::OperatorBase { return Output("Out"); }; + std::string endpoint = Attr("endpoint"); + int server_fd = platform::SocketServer::GetInstance(endpoint).socket(); + + std::vector hccl_ids; + hccl_ids.resize(1); + if (rank == 0) { + GenHCCLID(&hccl_ids); std::vector endpoint_list = Attr>("other_endpoints"); - SendBroadCastHCCLID(endpoint_list, 1, func, local_scope); + platform::SendBroadCastCommID(endpoint_list, &hccl_ids); } else { - std::string endpoint = Attr("endpoint"); - RecvBroadCastHCCLID(endpoint, 1, func, local_scope); + platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids); } + + CopyHCCLIDToVar(hccl_ids, func, scope); scope.DeleteScope(&local_scope); } }; diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index f38603e80fb115f3131173c36f0ee2962d06c0de..5f6dd5679a1a8eacc270a17e0f725e4311897dda 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) #include "paddle/fluid/platform/gen_comm_id_helper.h" #include @@ -33,6 +33,10 @@ limitations under the License. */ #include "xpu/bkcl.h" #endif +#if defined(PADDLE_WITH_ASCEND_CL) +#include "paddle/fluid/platform/collective_helper.h" +#endif + namespace paddle { namespace platform { @@ -262,10 +266,17 @@ static int ConnectAddr(const std::string& ep, const char* head) { return sock; } +// TODO(WANGXI): maybe need to unify this hard code +#ifdef PADDLE_WITH_ASCEND_CL +#define MAX_COMMUNIQUEID_LEN 4108 +#else +#define MAX_COMMUNIQUEID_LEN 1024 +#endif + template static void RecvCommID(int conn, CommUniqueId* nccl_id) { - char buffer[1024] = {0}; - static_assert(sizeof(CommUniqueId) <= 1024, + char buffer[MAX_COMMUNIQUEID_LEN] = {0}; + static_assert(sizeof(CommUniqueId) <= MAX_COMMUNIQUEID_LEN, "nccl id bytes must <= buffer size"); CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(CommUniqueId)), @@ -275,7 +286,7 @@ static void RecvCommID(int conn, CommUniqueId* nccl_id) { template static void SendCommID(int conn, CommUniqueId* nccl_id) { - char buffer[1024] = {0}; + char buffer[MAX_COMMUNIQUEID_LEN] = {0}; memcpy(buffer, nccl_id, sizeof(CommUniqueId)); CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(CommUniqueId)), @@ -361,6 +372,9 @@ INSTANT_TEMPLATE(ncclUniqueId) #ifdef PADDLE_WITH_XPU_BKCL INSTANT_TEMPLATE(BKCLUniqueId) #endif +#ifdef PADDLE_WITH_ASCEND_CL +INSTANT_TEMPLATE(HcclRootInfo) +#endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/gen_comm_id_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h index c51c5ac6c8ac7bc8a8887c39c0b08d8cd0af4540..fb5d8d8fcd94059cbef66de809bca295d205a73c 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.h +++ b/paddle/fluid/platform/gen_comm_id_helper.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) #include #include #include