From 1bd7a143559763ab58d251711b0c0b5041c5a71c Mon Sep 17 00:00:00 2001 From: maxhuiy <1508399706@qq.com> Date: Wed, 9 Feb 2022 14:11:28 +0800 Subject: [PATCH] [MLU] add mlu kernel for c_comm_init op (#39364) --- paddle/fluid/framework/var_type_traits.h | 7 +++++++ .../operators/collective/c_comm_init_op.cc | 20 +++++++++++++------ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index ac55abaad8d..9c27fd2b246 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -47,6 +47,10 @@ #include "xpu/bkcl.h" #endif +#if defined(PADDLE_WITH_CNCL) +#include +#endif + namespace pten { class DenseTensor; class SelectedRows; @@ -181,6 +185,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< #endif #if defined(PADDLE_WITH_XPU_BKCL) BKCLUniqueId, platform::BKCLCommunicator, +#endif +#if defined(PADDLE_WITH_CNCL) + cnclCliqueId, #endif int, float, Vocab>; template diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc index 56b0017fefe..39acb50d4e8 100644 --- a/paddle/fluid/operators/collective/c_comm_init_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -20,12 +20,15 @@ limitations under the License. */ #if defined(PADDLE_WITH_XPU_BKCL) #include "xpu/bkcl.h" #endif +#if defined(PADDLE_WITH_CNCL) +#include +#endif #include #include "paddle/fluid/framework/op_registry.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL) #include "paddle/fluid/platform/collective_helper.h" #endif @@ -56,18 +59,23 @@ class CCommInitOp : public framework::OperatorBase { using UniqueId = BKCLUniqueId; using Place = platform::XPUPlace; using CommContext = platform::BKCLCommContext; +#elif defined(PADDLE_WITH_CNCL) + using UniqueId = cnclCliqueId; + using Place = platform::MLUPlace; + using CommContext = platform::CNCLCommContext; #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should be compiled with GPU or XPU.")); + "PaddlePaddle should be compiled with GPU or XPU or MLU.")); #endif PADDLE_ENFORCE_EQ( - platform::is_gpu_place(place) || platform::is_xpu_place(place), true, - platform::errors::PreconditionNotMet( - "CCommInitOp can run on gpu or xpu place only.")); + platform::is_gpu_place(place) || platform::is_xpu_place(place) || + platform::is_mlu_place(place), + true, platform::errors::PreconditionNotMet( + "CCommInitOp can run on gpu or xpu or mlu place only.")); #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(PADDLE_WITH_XPU_BKCL) + defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CNCL) auto var = scope.FindVar(Input("X")); PADDLE_ENFORCE_NOT_NULL( var, platform::errors::InvalidArgument("Input con not be empty.")); -- GitLab