diff --git a/paddle/fluid/operators/collective/c_recv_op.cc b/paddle/fluid/operators/collective/c_recv_op.cc
index af3599a7bda409e9d34871fcf30eb2c406b4287e..ed7869744326e4a6e8bff8317b07e31b0a7620f3 100644
--- a/paddle/fluid/operators/collective/c_recv_op.cc
+++ b/paddle/fluid/operators/collective/c_recv_op.cc
@@ -31,7 +31,7 @@ class CRecvOp : public framework::OperatorWithKernel {
   }
 };
 
-class CSendOpMaker : public framework::OpProtoAndCheckerMaker {
+class CRecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddInput("Out", "(Tensor) tensor to receive.");
diff --git a/paddle/fluid/operators/collective/c_recv_op.cu.cc b/paddle/fluid/operators/collective/c_recv_op.cu.cc
index ea6a612b053d0f5651a668d29ae12fcbc77be5ef..69f5d5beb9d3b033d33fc3c6d93b630d68acdda7 100644
--- a/paddle/fluid/operators/collective/c_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_recv_op.cu.cc
@@ -23,11 +23,11 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-class CSendOpCUDAKernel : public framework::OpKernel<T> {
+class CRecvOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_NCCL)
-    auto out = ctx.Output<framework::LoDTensor>("Out");
+    auto out = ctx.Input<framework::LoDTensor>("Out");
 
     int numel = out->numel();
     ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
@@ -44,8 +44,9 @@ class CSendOpCUDAKernel : public framework::OpKernel<T> {
     }
 
     int peer = ctx.Attr<int>("peer");
-    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
-        out->mutable_data<T>(place), numel, dtype, peer, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::ncclRecv(const_cast<T*>(out->data<T>()), numel,
+                                    dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " recv "
             << framework::product(out->dims()) << " from " << peer;
 #else