未验证 提交 0b6dd535 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] distributed scatter/all_to_all support input 0D tensor (#53186)

上级 3650c4a8
......@@ -31,8 +31,13 @@ class CAllGatherOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"The value of nranks should be >=2."));
framework::DDim dim = ctx->GetInputDim("X");
dim[0] = dim[0] * nranks;
if (dim[0] < 0) dim[0] = -1;
// 0D use stack/unstack while others use concat/split
if (dim.size() == 0) {
dim = phi::make_ddim({nranks});
} else {
dim[0] = dim[0] * nranks;
if (dim[0] < 0) dim[0] = -1;
}
ctx->SetOutputDim("Out", dim);
}
};
......
......@@ -57,13 +57,9 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
T* recv_buff = out->mutable_data<T>(place);
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
......
......@@ -39,15 +39,12 @@ class CAllGatherOpCPUKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
framework::DDim out_dims = in->dims();
auto place = ctx.GetPlace();
auto gloo = paddle::framework::GlooWrapper::GetInstance();
auto nranks = gloo->Size();
out_dims[0] *= nranks;
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->mutable_data<T>(out_dims, place);
T* recv_buff = out->mutable_data<T>(place);
PADDLE_ENFORCE_EQ(
gloo->IsInitialized(),
......
......@@ -43,12 +43,9 @@ class CAllGatherOpXPUKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
size_t numel = in->numel();
const void* sendbuff = in->data<T>();
void* recvbuff = out->mutable_data<T>(out_dims, place);
void* recvbuff = out->mutable_data<T>(place);
XPUStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
......
......@@ -154,7 +154,7 @@ DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); }
DDim stride(const DDim& ddim) {
DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = 1;
if (ddim.size() > 0) strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1];
}
......@@ -164,7 +164,7 @@ DDim stride(const DDim& ddim) {
DDim stride_numel(const DDim& ddim) {
DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
if (ddim.size() > 0) strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i];
}
......
......@@ -122,8 +122,7 @@ def scatter_object_list(
in_obj_sizes.append(obj_size)
max_obj_size_tensor = max(in_obj_sizes)
else:
# NOTE: shape can be [] after 0D tensor support
max_obj_size_tensor = paddle.empty([1], dtype="int64")
max_obj_size_tensor = paddle.empty([], dtype="int64")
stream.broadcast(max_obj_size_tensor, src)
max_obj_size = int(max_obj_size_tensor.item())
......@@ -137,8 +136,7 @@ def scatter_object_list(
out_tensor = paddle.empty([max_obj_size], dtype="uint8")
scatter(out_tensor, in_tensor_list if rank == src else None, src, group)
# NOTE: shape can be [] after 0D tensor support
out_tensor_size = paddle.empty([1], dtype="int64")
out_tensor_size = paddle.empty([], dtype="int64")
scatter(out_tensor_size, in_obj_sizes if rank == src else None, src, group)
out_object_list.clear()
......
......@@ -108,7 +108,11 @@ def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
},
)
tensor_list.clear()
tensor_list.extend(paddle.split(out, nranks, 0))
# 0D use stack/unstack while others use concat/split
if len(tensor.shape) == 0:
tensor_list.extend(paddle.unstack(out, 0))
else:
tensor_list.extend(paddle.split(out, nranks, 0))
def all_gather(
......
......@@ -78,7 +78,12 @@ def _all_to_all_in_static_mode(
if isinstance(in_tensor_or_tensor_list, list):
if len(in_tensor_or_tensor_list) == 0:
raise RuntimeError("The input tensor_list should not be empty.")
in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
# 0D use stack/unstack while others use concat/split
if len(in_tensor_or_tensor_list[0].shape) == 0:
in_tensor = paddle.stack(in_tensor_or_tensor_list, axis=0)
else:
in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
out_tensor = out_tensor_or_tensor_list
if isinstance(out_tensor_or_tensor_list, list):
if len(out_tensor_or_tensor_list) != 0:
......@@ -110,7 +115,13 @@ def _all_to_all_in_static_mode(
if isinstance(out_tensor_or_tensor_list, list):
if not sync_op:
dist.wait(out_tensor, use_calc_stream=False)
out_tensor_or_tensor_list.extend(paddle.split(out_tensor, nranks, 0))
# 0D use stack/unstack while others use concat/split
if len(in_tensor_or_tensor_list[0].shape) == 0:
out_tensor_or_tensor_list.extend(paddle.unstack(out_tensor, 0))
else:
out_tensor_or_tensor_list.extend(
paddle.split(out_tensor, nranks, 0)
)
return None
......
......@@ -91,7 +91,11 @@ def _scatter_in_static_mode(
)
else:
tensor_list = [tensor for _ in range(nranks)]
input_tensor = paddle.concat(tensor_list, axis=0)
# 0D use stack/unstack while others use concat/split
if len(tensor_list[0].shape) == 0:
input_tensor = paddle.stack(tensor_list, axis=0)
else:
input_tensor = paddle.concat(tensor_list, axis=0)
ring_id = 0 if group is None else group.id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册