diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc index c5cd32ef07aa9f9a7126efa68bf0709dcb26feea..8d9e6b4b7d99044f584e9e21062a786252d60f76 100644 --- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc @@ -64,12 +64,19 @@ class CScatterOpCUDAKernel : public framework::OpKernel { framework::DDim x_dims = x->dims(); framework::DDim out_dims(x_dims); framework::Tensor temp; - auto in_data_ptr = x->data(); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast( - reinterpret_cast(in_data_ptr), - temp.mutable_data(out_dims, place), numel, dtype, root_id, - comm->comm(), stream)); - VLOG(3) << "rank " << comm->rank() << " invoke Scatter."; + auto out_ptr = temp.mutable_data(out_dims, place); + if (root_id == comm->rank()) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( + reinterpret_cast(const_cast(x->data())), numel, dtype, + root_id, comm->comm(), stream)); + + framework::TensorCopy(*static_cast(x), place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(&temp)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( + out_ptr, numel, dtype, root_id, comm->comm(), stream)); + } out_dims[0] = out_dims[0] / nranks; auto start_index = out_dims[0] * comm->rank();