From f4456ffea10cdb8a51ed67677e1e6eed5258bde7 Mon Sep 17 00:00:00 2001 From: sandyhouse Date: Thu, 3 Sep 2020 11:02:43 +0000 Subject: [PATCH] update code, test=develop --- paddle/fluid/operators/collective/c_recv_op.cc | 2 +- paddle/fluid/operators/collective/c_recv_op.cu.cc | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/collective/c_recv_op.cc b/paddle/fluid/operators/collective/c_recv_op.cc index af3599a7bda..ed786974432 100644 --- a/paddle/fluid/operators/collective/c_recv_op.cc +++ b/paddle/fluid/operators/collective/c_recv_op.cc @@ -31,7 +31,7 @@ class CRecvOp : public framework::OperatorWithKernel { } }; -class CSendOpMaker : public framework::OpProtoAndCheckerMaker { +class CRecvOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { AddInput("Out", "(Tensor) tensor to receive."); diff --git a/paddle/fluid/operators/collective/c_recv_op.cu.cc b/paddle/fluid/operators/collective/c_recv_op.cu.cc index ea6a612b053..69f5d5beb9d 100644 --- a/paddle/fluid/operators/collective/c_recv_op.cu.cc +++ b/paddle/fluid/operators/collective/c_recv_op.cu.cc @@ -23,11 +23,11 @@ namespace paddle { namespace operators { template -class CSendOpCUDAKernel : public framework::OpKernel { +class CRecvOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { #if defined(PADDLE_WITH_NCCL) - auto out = ctx.Output("Out"); + auto out = ctx.Input("Out"); int numel = out->numel(); ncclDataType_t dtype = platform::ToNCCLDataType(out->type()); @@ -44,8 +44,9 @@ class CSendOpCUDAKernel : public framework::OpKernel { } int peer = ctx.Attr("peer"); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend( - out->mutable_data(place), numel, dtype, peer, comm->comm(), stream)); + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclRecv(const_cast(out->data()), numel, + dtype, peer, comm->comm(), stream)); VLOG(3) << "rank " << comm->rank() << " recv " << framework::product(out->dims()) << " from " << peer; #else -- GitLab