未验证 提交 faa9b97b 作者: lilong12 提交者: GitHub

fix cscatter, test=develop (#26554)

上级 ed102ea1
...@@ -64,12 +64,19 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> { ...@@ -64,12 +64,19 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
framework::DDim x_dims = x->dims(); framework::DDim x_dims = x->dims();
framework::DDim out_dims(x_dims); framework::DDim out_dims(x_dims);
framework::Tensor temp; framework::Tensor temp;
auto in_data_ptr = x->data<T>(); auto out_ptr = temp.mutable_data<T>(out_dims, place);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast( if (root_id == comm->rank()) {
reinterpret_cast<const void*>(in_data_ptr), PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
temp.mutable_data<T>(out_dims, place), numel, dtype, root_id, reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
comm->comm(), stream)); root_id, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Scatter.";
framework::TensorCopy(*static_cast<const framework::Tensor*>(x), place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<framework::Tensor*>(&temp));
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
out_ptr, numel, dtype, root_id, comm->comm(), stream));
}
out_dims[0] = out_dims[0] / nranks; out_dims[0] = out_dims[0] / nranks;
auto start_index = out_dims[0] * comm->rank(); auto start_index = out_dims[0] * comm->rank();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册