From faa9b97b7815b083b8afaaafd6efb1fb90e68936 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Sat, 22 Aug 2020 14:38:17 +0800 Subject: [PATCH] fix cscatter, test=develop (#26554) --- .../operators/collective/c_scatter_op.cu.cc | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc index c5cd32ef07..8d9e6b4b7d 100644 --- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc @@ -64,12 +64,19 @@ class CScatterOpCUDAKernel : public framework::OpKernel { framework::DDim x_dims = x->dims(); framework::DDim out_dims(x_dims); framework::Tensor temp; - auto in_data_ptr = x->data(); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast( - reinterpret_cast(in_data_ptr), - temp.mutable_data(out_dims, place), numel, dtype, root_id, - comm->comm(), stream)); - VLOG(3) << "rank " << comm->rank() << " invoke Scatter."; + auto out_ptr = temp.mutable_data(out_dims, place); + if (root_id == comm->rank()) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( + reinterpret_cast(const_cast(x->data())), numel, dtype, + root_id, comm->comm(), stream)); + + framework::TensorCopy(*static_cast(x), place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(&temp)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( + out_ptr, numel, dtype, root_id, comm->comm(), stream)); + } out_dims[0] = out_dims[0] / nranks; auto start_index = out_dims[0] * comm->rank(); -- GitLab