未验证 提交 faa9b97b 编写于 作者: L lilong12 提交者: GitHub

fix cscatter, test=develop (#26554)

上级 ed102ea1
......
@@ -64,12 +64,19 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
framework::DDim x_dims = x->dims();
framework::DDim out_dims(x_dims);
framework::Tensor temp;
auto in_data_ptr = x->data<T>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
reinterpret_cast<const void*>(in_data_ptr),
temp.mutable_data<T>(out_dims, place), numel, dtype, root_id,
comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Scatter.";
auto out_ptr = temp.mutable_data<T>(out_dims, place);
if (root_id == comm->rank()) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
root_id, comm->comm(), stream));
framework::TensorCopy(*static_cast<const framework::Tensor*>(x), place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<framework::Tensor*>(&temp));
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
out_ptr, numel, dtype, root_id, comm->comm(), stream));
}
out_dims[0] = out_dims[0] / nranks;
auto start_index = out_dims[0] * comm->rank();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册