Unverified commit 0b6dd535, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] distributed scatter/all_to_all support input 0D tensor (#53186)

Parent 3650c4a8
@@ -31,8 +31,13 @@ class CAllGatherOp : public framework::OperatorWithKernel {
         platform::errors::InvalidArgument(
             "The value of nranks should be >=2."));
     framework::DDim dim = ctx->GetInputDim("X");
-    dim[0] = dim[0] * nranks;
-    if (dim[0] < 0) dim[0] = -1;
+    // 0D use stack/unstack while others use concat/split
+    if (dim.size() == 0) {
+      dim = phi::make_ddim({nranks});
+    } else {
+      dim[0] = dim[0] * nranks;
+      if (dim[0] < 0) dim[0] = -1;
+    }
     ctx->SetOutputDim("Out", dim);
   }
 };
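The shape rule above is the InferShape side of the 0-D support: gathering a rank-0 tensor produces a 1-D tensor of length nranks, while higher-rank inputs still scale dim 0 by nranks and keep -1 for dynamic dims. A minimal Python sketch of that rule (the helper name is hypothetical, for illustration only):

def allgather_out_shape(in_shape, nranks):
    # 0-D input: the gathered result is a 1-D tensor of length nranks
    if len(in_shape) == 0:
        return [nranks]
    # N-D input: dim 0 scales by nranks; -1 (dynamic) stays -1
    out = list(in_shape)
    out[0] = out[0] * nranks if out[0] >= 0 else -1
    return out

assert allgather_out_shape([], 4) == [4]
assert allgather_out_shape([3, 2], 4) == [12, 2]
assert allgather_out_shape([-1, 2], 4) == [-1, 2]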
@@ -57,13 +57,9 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument(
             "nranks: %s should equal to %s", nranks, comm->nranks()));

-    framework::DDim out_dims = in->dims();
-    out_dims[0] *= nranks;
-    out->mutable_data<T>(out_dims, place);
-
     int64_t send_numel = in->numel();
     const T* send_buff = in->data<T>();
-    T* recv_buff = out->data<T>();
+    T* recv_buff = out->mutable_data<T>(place);

     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
@@ -39,15 +39,12 @@ class CAllGatherOpCPUKernel : public framework::OpKernel<T> {
 #if defined(PADDLE_WITH_GLOO)
     auto in = ctx.Input<phi::DenseTensor>("X");
     auto out = ctx.Output<phi::DenseTensor>("Out");
-    framework::DDim out_dims = in->dims();
     auto place = ctx.GetPlace();
     auto gloo = paddle::framework::GlooWrapper::GetInstance();
     auto nranks = gloo->Size();
-    out_dims[0] *= nranks;
     int64_t send_numel = in->numel();
     const T* send_buff = in->data<T>();
-    T* recv_buff = out->mutable_data<T>(out_dims, place);
+    T* recv_buff = out->mutable_data<T>(place);
     PADDLE_ENFORCE_EQ(
         gloo->IsInitialized(),
@@ -43,12 +43,9 @@ class CAllGatherOpXPUKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument(
             "nranks: %s should equal to %s", nranks, comm->nranks()));

-    framework::DDim out_dims = in->dims();
-    out_dims[0] *= nranks;
-
     size_t numel = in->numel();
     const void* sendbuff = in->data<T>();
-    void* recvbuff = out->mutable_data<T>(out_dims, place);
+    void* recvbuff = out->mutable_data<T>(place);
     XPUStream stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
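With InferShape now owning the output dims, the CUDA, CPU/Gloo, and XPU kernels above no longer rescale dim 0 themselves; they simply allocate the already-inferred output via mutable_data(place). A hedged dynamic-mode sketch of what this enables (script name and launch command are illustrative, and a 2-process setup is assumed):

# demo_all_gather_0d.py -- run with e.g.
#   python -m paddle.distributed.launch --nproc_per_node=2 demo_all_gather_0d.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
x = paddle.to_tensor(float(dist.get_rank()))  # 0-D tensor, shape []
gathered = []
dist.all_gather(gathered, x)
# each entry keeps the input's shape, so the list holds nranks 0-D tensors
print([t.shape for t in gathered])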
@@ -154,7 +154,7 @@ DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); }
 DDim stride(const DDim& ddim) {
   DDim strides;
   strides.rank_ = ddim.size();
-  strides[ddim.size() - 1] = 1;
+  if (ddim.size() > 0) strides[ddim.size() - 1] = 1;
   for (int i = ddim.size() - 2; i >= 0; --i) {
     strides[i] = strides[i + 1] * ddim[i + 1];
   }
@@ -164,7 +164,7 @@ DDim stride(const DDim& ddim) {
 DDim stride_numel(const DDim& ddim) {
   DDim strides;
   strides.rank_ = ddim.size();
-  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
+  if (ddim.size() > 0) strides[ddim.size() - 1] = ddim[ddim.size() - 1];
   for (int i = ddim.size() - 2; i >= 0; --i) {
     strides[i] = strides[i + 1] * ddim[i];
   }
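Both stride helpers previously wrote to strides[size - 1] unconditionally, which indexes out of range when a 0-D DDim (size() == 0) flows through. A small Python sketch of the guarded row-major stride computation (illustrative only, not the phi implementation):

def row_major_strides(dims):
    strides = [0] * len(dims)
    if len(dims) > 0:  # guard: a 0-D shape has no trailing dim to seed
        strides[-1] = 1
    for i in range(len(dims) - 2, -1, -1):
        strides[i] = strides[i + 1] * dims[i + 1]
    return strides

assert row_major_strides([2, 3, 4]) == [12, 4, 1]
assert row_major_strides([]) == []  # previously an out-of-range write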
@@ -122,8 +122,7 @@ def scatter_object_list(
             in_obj_sizes.append(obj_size)
         max_obj_size_tensor = max(in_obj_sizes)
     else:
-        # NOTE: shape can be [] after 0D tensor support
-        max_obj_size_tensor = paddle.empty([1], dtype="int64")
+        max_obj_size_tensor = paddle.empty([], dtype="int64")
     stream.broadcast(max_obj_size_tensor, src)
     max_obj_size = int(max_obj_size_tensor.item())
@@ -137,8 +136,7 @@ def scatter_object_list(
     out_tensor = paddle.empty([max_obj_size], dtype="uint8")
     scatter(out_tensor, in_tensor_list if rank == src else None, src, group)
-    # NOTE: shape can be [] after 0D tensor support
-    out_tensor_size = paddle.empty([1], dtype="int64")
+    out_tensor_size = paddle.empty([], dtype="int64")
     scatter(out_tensor_size, in_obj_sizes if rank == src else None, src, group)

     out_object_list.clear()
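The two scatter_object_list changes replace the shape-[1] size placeholders with true 0-D tensors, which paddle.empty now supports with an empty shape; the later .item() call reads a Python scalar either way, so the rest of the function is unchanged. A quick illustration:

import paddle

size_0d = paddle.empty([], dtype="int64")   # 0-D tensor, shape []
size_1d = paddle.empty([1], dtype="int64")  # 1-D tensor, shape [1]
print(size_0d.shape, size_1d.shape)         # [] [1]
# .item() returns a Python scalar for both, so downstream code is unchanged
print(type(size_0d.item()), type(size_1d.item()))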
@@ -108,7 +108,11 @@ def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
         },
     )
     tensor_list.clear()
-    tensor_list.extend(paddle.split(out, nranks, 0))
+    # 0D use stack/unstack while others use concat/split
+    if len(tensor.shape) == 0:
+        tensor_list.extend(paddle.unstack(out, 0))
+    else:
+        tensor_list.extend(paddle.split(out, nranks, 0))


 def all_gather(
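concat/split cannot round-trip 0-D tensors because they operate along an existing axis 0, which a rank-0 tensor does not have; stack adds a new leading axis and unstack removes it again, which is why the static-mode helpers branch on the input rank. A short illustration:

import paddle

parts = [paddle.to_tensor(1.0), paddle.to_tensor(2.0)]  # two 0-D tensors
packed = paddle.stack(parts, axis=0)                    # shape [2]
restored = paddle.unstack(packed, axis=0)               # two 0-D tensors again
print(packed.shape, [t.shape for t in restored])        # [2] [[], []]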
@@ -78,7 +78,12 @@ def _all_to_all_in_static_mode(
     if isinstance(in_tensor_or_tensor_list, list):
         if len(in_tensor_or_tensor_list) == 0:
             raise RuntimeError("The input tensor_list should not be empty.")
-        in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
+        # 0D use stack/unstack while others use concat/split
+        if len(in_tensor_or_tensor_list[0].shape) == 0:
+            in_tensor = paddle.stack(in_tensor_or_tensor_list, axis=0)
+        else:
+            in_tensor = paddle.concat(in_tensor_or_tensor_list, axis=0)
+
     out_tensor = out_tensor_or_tensor_list
     if isinstance(out_tensor_or_tensor_list, list):
         if len(out_tensor_or_tensor_list) != 0:
@@ -110,7 +115,13 @@ def _all_to_all_in_static_mode(
     if isinstance(out_tensor_or_tensor_list, list):
         if not sync_op:
             dist.wait(out_tensor, use_calc_stream=False)
-        out_tensor_or_tensor_list.extend(paddle.split(out_tensor, nranks, 0))
+        # 0D use stack/unstack while others use concat/split
+        if len(in_tensor_or_tensor_list[0].shape) == 0:
+            out_tensor_or_tensor_list.extend(paddle.unstack(out_tensor, 0))
+        else:
+            out_tensor_or_tensor_list.extend(
+                paddle.split(out_tensor, nranks, 0)
+            )

     return None
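Note that the output branch keys on in_tensor_or_tensor_list[0].shape: in all_to_all, rank r sends slice p to rank p and receives slice r from every peer, so per-slice shapes on the send and receive sides match, and the input rank also decides whether the gathered output is unstacked or split. A plain-Python sketch of that exchange (no Paddle calls, illustrative only):

nranks = 3
# sent[r][p] is what rank r puts in the slot destined for rank p
sent = [[f"r{r}->p{p}" for p in range(nranks)] for r in range(nranks)]
# after all_to_all, rank r holds slot r from every peer, in peer order
received = [[sent[p][r] for p in range(nranks)] for r in range(nranks)]
assert received[1] == ["r0->p1", "r1->p1", "r2->p1"]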
@@ -91,7 +91,11 @@ def _scatter_in_static_mode(
                 )
         else:
             tensor_list = [tensor for _ in range(nranks)]
-        input_tensor = paddle.concat(tensor_list, axis=0)
+        # 0D use stack/unstack while others use concat/split
+        if len(tensor_list[0].shape) == 0:
+            input_tensor = paddle.stack(tensor_list, axis=0)
+        else:
+            input_tensor = paddle.concat(tensor_list, axis=0)

     ring_id = 0 if group is None else group.id
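A hedged dynamic-mode counterpart of the static scatter change above (script name and launch command are illustrative; a 2-process setup and 0-D paddle.zeros([]) are assumed):

# demo_scatter_0d.py -- run with e.g.
#   python -m paddle.distributed.launch --nproc_per_node=2 demo_scatter_0d.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
buf = paddle.zeros([])  # 0-D receive buffer
src_data = [paddle.to_tensor(10.0), paddle.to_tensor(20.0)]
dist.scatter(buf, src_data if dist.get_rank() == 0 else None, src=0)
print(buf.shape, float(buf))  # [] with 10.0 on rank 0 and 20.0 on rank 1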