From e26e51ba87b343fd63c1bc2a0f8c158f1efd6162 Mon Sep 17 00:00:00 2001
From: xiayanming
Date: Thu, 14 Apr 2022 11:21:15 +0800
Subject: [PATCH] [fix bug] communication op support rccl (#41763)

---
 paddle/fluid/operators/collective/alltoall_op.cu.cc       | 6 +++---
 .../operators/collective/c_comm_init_multitrainer_op.cc   | 7 +++++--
 paddle/fluid/operators/collective/global_gather_op.cu.cc  | 6 +++---
 paddle/fluid/operators/collective/global_scatter_op.cu.cc | 6 +++---
 4 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index 26fdee200cd..0e0ea722084 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/alltoall_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
@@ -43,7 +43,7 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     int nranks = comm->nranks();
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc b/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc
index f69fe8f1e3f..86c966378cc 100644
--- a/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc
+++ b/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc
@@ -14,6 +14,9 @@ limitations under the License. */
 #if defined(PADDLE_WITH_NCCL)
 #include <nccl.h>
 #endif
+#if defined(PADDLE_WITH_RCCL)
+#include <rccl.h>
+#endif
 #include <stdint.h>
 #include <ostream>
 #include <string>
@@ -24,7 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 // #include "paddle/fluid/operators/distributed/distributed.h"
 // #include "paddle/fluid/operators/distributed/request_handler_impl.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -51,7 +54,7 @@ class CCommInitMultiTrainerOp : public framework::OperatorBase {
     auto var = scope.FindVar(Input("X"));
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::InvalidArgument("Input X must be provided."));
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
 
     int ntrainers = Attr<int>("ntrainers");
diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc
index 4f9725a2706..6684470e881 100644
--- a/paddle/fluid/operators/collective/global_gather_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/global_gather_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -79,7 +79,7 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
                                       ring_id));
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 3a7e6a0079a..cd3c3a3229c 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/global_scatter_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -78,7 +78,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
-- 
GitLab
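For context, here is a minimal self-contained sketch of the pattern this patch applies: one preprocessor guard covering both the CUDA (NCCL) and ROCm (RCCL) stacks, and the vendor-neutral gpuStream_t alias in place of cudaStream_t. This is an illustration only; the PADDLE_WITH_* macros and the alias mirror Paddle's conventions (where gpuStream_t is defined in the platform/device/gpu headers), but nothing below is the actual Paddle source.

// Illustrative only: the guard-and-alias pattern from the patch, assuming
// PADDLE_WITH_NCCL / PADDLE_WITH_RCCL are set by the build system the way
// Paddle's CMake configuration sets them.
#include <cstdio>

#if defined(PADDLE_WITH_NCCL)
#include <cuda_runtime.h>
#include <nccl.h>
using gpuStream_t = cudaStream_t;  // CUDA build: the alias is a CUDA stream
#elif defined(PADDLE_WITH_RCCL)
#include <hip/hip_runtime.h>
#include <rccl.h>
using gpuStream_t = hipStream_t;   // ROCm build: same alias, a HIP stream
#endif

int main() {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  // Code that declares streams through the alias, as the patched operators
  // now do, compiles unchanged on both vendor stacks.
  gpuStream_t stream = nullptr;
  (void)stream;
  std::printf("collective communication path enabled\n");
#else
  std::printf("built without NCCL/RCCL support\n");
#endif
  return 0;
}

The point of the pattern is that kernel call sites only ever name gpuStream_t and the shared guard, never cudaStream_t or a CUDA-only macro, which is exactly why the four operators above previously failed to build (or silently compiled out) on ROCm.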