From b2f7ab6636d5e7fcc3bfd655c071416f190ed619 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Mon, 16 Nov 2020 20:58:29 +0800 Subject: [PATCH] bug fix, test=develop (#28648) --- paddle/fluid/operators/collective/recv_v2_op.cu.cc | 2 +- paddle/fluid/operators/collective/send_v2_op.cu.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc index f0dd8aee235..892056f2135 100644 --- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc +++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc @@ -26,6 +26,7 @@ template class RecvOpV2CUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { +#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703 int rid = ctx.Attr("ring_id"); PADDLE_ENFORCE_GE( rid, 0, @@ -44,7 +45,6 @@ class RecvOpV2CUDAKernel : public framework::OpKernel { framework::proto::VarType::Type type = framework::proto::VarType::Type(data_type); -#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703 cudaStream_t stream = nullptr; auto place = ctx.GetPlace(); auto comm = platform::NCCLCommContext::Instance().Get(rid, place); diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc index 9f925b2eede..4de3f47ccc6 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cu.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc @@ -26,6 +26,7 @@ template class SendOpV2CUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703 auto x = ctx.Input("X"); int numel = x->numel(); @@ -42,7 +43,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel { "The peer (%d) for send_v2 op must be non-negative.", peer)); cudaStream_t stream = nullptr; auto place = ctx.GetPlace(); -#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703 auto comm = platform::NCCLCommContext::Instance().Get(rid, place); if (ctx.Attr("use_calc_stream")) { auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); -- GitLab