diff --git a/paddle/fluid/operators/collective/c_allgather_op.cc b/paddle/fluid/operators/collective/c_allgather_op.cc
index 04019756ffe3ede07fceb52d0671f060f857032f..b78d9504dafb76b6963433c55b7e9f58bd185753 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cc
@@ -31,8 +31,13 @@ class CAllGatherOp : public framework::OperatorWithKernel {
                       platform::errors::InvalidArgument(
                           "The value of nranks should be >=2."));
     framework::DDim dim = ctx->GetInputDim("X");
-    dim[0] = dim[0] * nranks;
-    if (dim[0] < 0) dim[0] = -1;
+    // 0D use stack/unstack while others use concat/split
+    if (dim.size() == 0) {
+      dim = phi::make_ddim({nranks});
+    } else {
+      dim[0] = dim[0] * nranks;
+      if (dim[0] < 0) dim[0] = -1;
+    }
     ctx->SetOutputDim("Out", dim);
   }
 };
diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index 93be43a1a324a6f857975002f0367b6ac8f65385..70b7d70dc93b31b032bf80e9e41121eeb57c4848 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -57,13 +57,9 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel {
         platform::errors::InvalidArgument(
             "nranks: %s should equal to %s", nranks, comm->nranks()));
 
-    framework::DDim out_dims = in->dims();
-    out_dims[0] *= nranks;
-    out->mutable_data(out_dims, place);
-
     int64_t send_numel = in->numel();
     const T* send_buff = in->data();
-    T* recv_buff = out->data();
+    T* recv_buff = out->mutable_data(place);
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr("use_calc_stream")) {
diff --git a/paddle/fluid/operators/collective/c_allgather_op.h b/paddle/fluid/operators/collective/c_allgather_op.h
index e896f96ead5329f06d19ad53ad1321737eacdb6c..c5373bf13043809cf85d259799ff5a39297b5336 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.h
+++ b/paddle/fluid/operators/collective/c_allgather_op.h
@@ -39,15 +39,12 @@ class CAllGatherOpCPUKernel : public framework::OpKernel {
 #if defined(PADDLE_WITH_GLOO)
     auto in = ctx.Input("X");
     auto out = ctx.Output("Out");
-    framework::DDim out_dims = in->dims();
     auto place = ctx.GetPlace();
-
     auto gloo = paddle::framework::GlooWrapper::GetInstance();
     auto nranks = gloo->Size();
-    out_dims[0] *= nranks;
     int64_t send_numel = in->numel();
     const T* send_buff = in->data();
-    T* recv_buff = out->mutable_data(out_dims, place);
+    T* recv_buff = out->mutable_data(place);
 
     PADDLE_ENFORCE_EQ(
         gloo->IsInitialized(),
diff --git a/paddle/fluid/operators/collective/c_allgather_op_xpu.cc b/paddle/fluid/operators/collective/c_allgather_op_xpu.cc
index 1e7d3f3a9fec17c9944a2f86b896d5035ff4a8d9..c4fdb0fdf290e31e75bb87d3f894a79beae4aa23 100644
--- a/paddle/fluid/operators/collective/c_allgather_op_xpu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op_xpu.cc
@@ -43,12 +43,9 @@ class CAllGatherOpXPUKernel : public framework::OpKernel {
         platform::errors::InvalidArgument(
             "nranks: %s should equal to %s", nranks, comm->nranks()));
 
-    framework::DDim out_dims = in->dims();
-    out_dims[0] *= nranks;
-
     size_t numel = in->numel();
     const void* sendbuff = in->data();
-    void* recvbuff = out->mutable_data(out_dims, place);
+    void* recvbuff = out->mutable_data(place);
 
     XPUStream stream = nullptr;
     if (ctx.Attr("use_calc_stream")) {
diff --git a/paddle/phi/core/ddim.cc b/paddle/phi/core/ddim.cc
index 3256458e02be9e9659cf3bbba52503bc0277c459..05ca29843b42f6dd71ad6f950509848f8927cda6 100644
--- a/paddle/phi/core/ddim.cc
+++ b/paddle/phi/core/ddim.cc
@@ -154,7 +154,7 @@ DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); }
 DDim stride(const DDim& ddim) {
   DDim strides;
   strides.rank_ = ddim.size();
-  strides[ddim.size() - 1] = 1;
+  if (ddim.size() > 0) strides[ddim.size() - 1] = 1;
   for (int i = ddim.size() - 2; i >= 0; --i) {
     strides[i] = strides[i + 1] * ddim[i + 1];
   }
@@ -164,7 +164,7 @@ DDim stride(const DDim& ddim) {
 DDim stride_numel(const DDim& ddim) {
   DDim strides;
   strides.rank_ = ddim.size();
-  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
+  if (ddim.size() > 0) strides[ddim.size() - 1] = ddim[ddim.size() - 1];
   for (int i = ddim.size() - 2; i >= 0; --i) {
     strides[i] = strides[i + 1] * ddim[i];
   }
diff --git a/python/paddle/distributed/communication/scatter.py b/python/paddle/distributed/communication/scatter.py
index a8ba94e49c8146ddcaa8166343785b7ac61fc2c6..2826779d55738d714a229984d0a09613256b9882 100644
--- a/python/paddle/distributed/communication/scatter.py
+++ b/python/paddle/distributed/communication/scatter.py
@@ -122,8 +122,7 @@ def scatter_object_list(
             in_obj_sizes.append(obj_size)
         max_obj_size_tensor = max(in_obj_sizes)
     else:
-        # NOTE: shape can be [] after 0D tensor support
-        max_obj_size_tensor = paddle.empty([1], dtype="int64")
+        max_obj_size_tensor = paddle.empty([], dtype="int64")
     stream.broadcast(max_obj_size_tensor, src)
     max_obj_size = int(max_obj_size_tensor.item())
 
@@ -137,8 +136,7 @@ def scatter_object_list(
     out_tensor = paddle.empty([max_obj_size], dtype="uint8")
     scatter(out_tensor, in_tensor_list if rank == src else None, src, group)
 
-    # NOTE: shape can be [] after 0D tensor support
-    out_tensor_size = paddle.empty([1], dtype="int64")
+    out_tensor_size = paddle.empty([], dtype="int64")
     scatter(out_tensor_size, in_obj_sizes if rank == src else None, src, group)
 
     out_object_list.clear()
diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py
index 4d02753a1a634b98b94770c0bff875e67b08d71d..69d9c5d52e080205776d5ecd1135dd0dc22b65f3 100644
--- a/python/paddle/distributed/communication/stream/all_gather.py
+++ b/python/paddle/distributed/communication/stream/all_gather.py
@@ -108,7 +108,11 @@ def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
         },
     )
     tensor_list.clear()
-    tensor_list.extend(paddle.split(out, nranks, 0))
+    # 0D use stack/unstack while others use concat/split
+    if len(tensor.shape) == 0:
+        tensor_list.extend(paddle.unstack(out, 0))
+    else:
+        tensor_list.extend(paddle.split(out, nranks, 0))
 
 
 def all_gather(
diff --git a/python/paddle/distributed/communication/stream/all_to_all.py b/python/paddle/distributed/communication/stream/all_to_all.py
index d8793601f729acf9bc81ea926df3a08ea5e5a442..38b1d2fcb3e8228f186919ca676c0ed8841b54e7 100644
--- a/python/paddle/distributed/communication/stream/all_to_all.py
+++ b/python/paddle/distributed/communication/stream/all_to_all.py
@@ -78,7 +78,12 @@ def _all_to_all_in_static_mode(
     if isinstance(in_tensor_or_tensor_list, list):
         if len(in_tensor_or_tensor_list) == 0:
             raise RuntimeError("The input tensor_list should not be empty.")
-        in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
+        # 0D use stack/unstack while others use concat/split
+        if len(in_tensor_or_tensor_list[0].shape) == 0:
+            in_tensor = paddle.stack(in_tensor_or_tensor_list, axis=0)
+        else:
+            in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
+
     out_tensor = out_tensor_or_tensor_list
     if isinstance(out_tensor_or_tensor_list, list):
         if len(out_tensor_or_tensor_list) != 0:
@@ -110,7 +115,13 @@ def _all_to_all_in_static_mode(
     if isinstance(out_tensor_or_tensor_list, list):
         if not sync_op:
             dist.wait(out_tensor, use_calc_stream=False)
-        out_tensor_or_tensor_list.extend(paddle.split(out_tensor, nranks, 0))
+        # 0D use stack/unstack while others use concat/split
+        if len(in_tensor_or_tensor_list[0].shape) == 0:
+            out_tensor_or_tensor_list.extend(paddle.unstack(out_tensor, 0))
+        else:
+            out_tensor_or_tensor_list.extend(
+                paddle.split(out_tensor, nranks, 0)
+            )
 
     return None
 
diff --git a/python/paddle/distributed/communication/stream/scatter.py b/python/paddle/distributed/communication/stream/scatter.py
index c4a6a66afbcd4dec5d52e73d618effe341803773..c112516a1fc106859e3d7e176395cbe1c1258403 100644
--- a/python/paddle/distributed/communication/stream/scatter.py
+++ b/python/paddle/distributed/communication/stream/scatter.py
@@ -91,7 +91,11 @@ def _scatter_in_static_mode(
             )
     else:
         tensor_list = [tensor for _ in range(nranks)]
-    input_tensor = paddle.concat(tensor_list, axis=0)
+    # 0D use stack/unstack while others use concat/split
+    if len(tensor_list[0].shape) == 0:
+        input_tensor = paddle.stack(tensor_list, axis=0)
+    else:
+        input_tensor = paddle.concat(tensor_list, axis=0)
 
     ring_id = 0 if group is None else group.id
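
For reviewers, a minimal usage sketch (not part of the patch) of the behavior this change targets: gathering a 0-D (scalar) tensor is meant to return a list of 0-D tensors instead of 1-D slices, while >=1-D inputs keep the existing concat/split semantics along axis 0. The script name, launch command, and the shapes noted in the comments are assumptions about a 2-rank run, not verified output.

# Sketch only: assumes two processes, e.g.
#   python -m paddle.distributed.launch --nproc_per_node=2 demo_0d_allgather.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

# 0-D input: c_allgather infers an output of shape [nranks], and the
# wrapper is expected to unstack it back into a list of 0-D tensors.
x = paddle.full([], float(rank))      # shape []
gathered = []
dist.all_gather(gathered, x)
print([t.shape for t in gathered])    # expected on 2 ranks: [[], []]

# >=1-D input keeps the old behavior: concat along axis 0, then split.
y = paddle.full([2, 3], float(rank))
gathered = []
dist.all_gather(gathered, y)
print([t.shape for t in gathered])    # expected: [[2, 3], [2, 3]]

The same 0-D rule (stack/unstack instead of concat/split) is applied consistently to all_to_all and scatter in this patch.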