diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index 26fdee200cd84f0a455fdd8500f194c71c421928..0e0ea72208488c520b24680611a272bd378c9338 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/alltoall_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
@@ -43,7 +43,7 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     int nranks = comm->nranks();
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc b/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc
index f69fe8f1e3f1fa4d337760165e2bebad146b0925..86c966378ccb689aae156aa3e725360f77db275f 100644
--- a/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc
+++ b/paddle/fluid/operators/collective/c_comm_init_multitrainer_op.cc
@@ -14,6 +14,9 @@ limitations under the License. */
 #if defined(PADDLE_WITH_NCCL)
 #include <nccl.h>
 #endif
+#if defined(PADDLE_WITH_RCCL)
+#include <rccl.h>
+#endif
 #include <stdint.h>
 #include <ostream>
 #include <string>
@@ -24,7 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 // #include "paddle/fluid/operators/distributed/distributed.h"
 // #include "paddle/fluid/operators/distributed/request_handler_impl.h"
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -51,7 +54,7 @@ class CCommInitMultiTrainerOp : public framework::OperatorBase {
     auto var = scope.FindVar(Input("X"));
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::InvalidArgument("Input X must be provided."));
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
 
     int ntrainers = Attr<int>("ntrainers");
diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc
index 4f9725a27062b2b57d7325c4342b01e1204f0a4a..6684470e881cb0daf56d740f1f820e881c632df1 100644
--- a/paddle/fluid/operators/collective/global_gather_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/global_gather_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -79,7 +79,7 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
             ring_id));
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 3a7e6a0079ac5af011fc5aa84810b48143e84544..cd3c3a3229ca068b634e0111e85195e5c7935e34 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/global_scatter_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
@@ -78,7 +78,7 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();