Unverified commit 22ec915c authored by Roc, committed by GitHub

[0D Tensor] support 0d tensor for dist.scatter and dist.broadcast (#48638)

Parent 35ebf2b4
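
A minimal usage sketch (not part of this commit) of the behavior the title describes: broadcasting and scattering 0-D (shape []) tensors through paddle.distributed. It assumes a two-GPU job launched with paddle.distributed.launch; the file name and values are illustrative only.

    # demo_0d_collectives.py (hypothetical)
    # Run with: python -m paddle.distributed.launch --gpus 0,1 demo_0d_collectives.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    rank = dist.get_rank()

    # broadcast: every rank ends up with rank 0's scalar, and the shape stays []
    x = paddle.to_tensor(float(rank))   # 0-D tensor
    dist.broadcast(x, src=0)
    assert x.shape == []

    # scatter: rank 0 provides one 0-D tensor per rank; each rank receives a scalar
    out = paddle.zeros([])
    if rank == 0:
        dist.scatter(out, [paddle.to_tensor(1.0), paddle.to_tensor(2.0)], src=0)
    else:
        dist.scatter(out, [], src=0)
    assert out.shape == []
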
@@ -911,13 +911,14 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
   // 1. calculate axis
   int rank = x.at(0)->dims().size();
   PADDLE_ENFORCE_EQ(
-      axis >= -rank && axis < rank,
+      !rank || (axis >= -rank && axis < rank),
       true,
       phi::errors::InvalidArgument(
           "The axis is expected to be in range of [%d, %d), but got %d",
           -rank,
           rank,
           axis));
+  axis = rank ? axis : 0;
   if (axis < 0) {
     axis = axis + rank;
   }
......
@@ -21,13 +21,14 @@ namespace funcs {
 static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
   PADDLE_ENFORCE_EQ(
-      axis >= -rank && axis < rank,
+      !rank || (axis >= -rank && axis < rank),
       true,
       phi::errors::InvalidArgument(
           "The axis is expected to be in range of [%d, %d), but got %d",
           -rank,
           rank,
           axis));
+  axis = rank ? axis : 0;
   if (axis < 0) {
     axis = axis + rank;
   }
......
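
Both hunks above make the same change: the range check `axis >= -rank && axis < rank` can never hold when rank is 0 (a 0-D input), so the check is skipped in that case and axis is forced to 0 before the usual negative-axis normalization. A Python rendering of the updated logic, for illustration only (not Paddle code):

    def compute_axis(axis: int, rank: int) -> int:
        # With rank == 0 (a 0-D tensor) the old check `-rank <= axis < rank`
        # can never pass, so the guard is skipped and axis is forced to 0.
        if rank and not (-rank <= axis < rank):
            raise ValueError(
                f"The axis is expected to be in range of [{-rank}, {rank}), but got {axis}"
            )
        axis = axis if rank else 0
        if axis < 0:
            axis += rank
        return axis

    assert compute_axis(3, 0) == 0    # 0-D inputs: any axis normalizes to 0
    assert compute_axis(-1, 2) == 1   # behavior unchanged for rank >= 1
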
@@ -34,6 +34,35 @@ void ConcatKernel(const Context& dev_ctx,
                   DenseTensor* out) {
   int64_t axis = axis_scalar.to<int64_t>();
+  if (UNLIKELY(x[0]->dims().size() == 0)) {
+    // for dims is 0 specially
+    phi::DDim tmp_1dim, out_dims;
+    out_dims[0] = x.size();
+    tmp_1dim[0] = 1;
+    out->Resize(out_dims);
+    dev_ctx.template Alloc<T>(out);
+    size_t output_offset = 0;
+    for (auto* in : x) {
+      if (in->numel() == 0UL) {
+        continue;
+      }
+      auto in_stride = phi::stride_numel(tmp_1dim);
+      auto out_stride = phi::stride_numel(out->dims());
+      paddle::operators::StridedNumelCopyWithAxis<T>(
+          dev_ctx,
+          axis,
+          out->data<T>() + output_offset,
+          out_stride,
+          in->data<T>(),
+          in_stride,
+          in_stride[axis]);
+      output_offset += in_stride[axis];
+    }
+    return;
+  }
   axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size());
   std::vector<phi::DDim> x_dims;
......
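
The new branch in ConcatKernel handles inputs whose dims size is 0: each 0-D input is copied as if it were a length-1 vector, so concatenating N scalars yields a 1-D output of length N. A rough NumPy sketch of what the branch computes (illustration only, not the actual strided-copy code):

    import numpy as np

    def concat_0d(inputs):
        # Each 0-D input is treated as a length-1 vector; the output is a
        # 1-D array whose length equals the number of inputs.
        out = np.empty(len(inputs), dtype=inputs[0].dtype)
        offset = 0
        for arr in inputs:        # each arr has shape ()
            out[offset] = arr     # copy the single element
            offset += 1           # in_stride[axis] == 1 for the fake 1-D shape
        return out

    scalars = [np.array(1.0), np.array(2.0), np.array(3.0)]
    print(concat_0d(scalars))     # -> [1. 2. 3.], shape (3,)
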
@@ -167,6 +167,29 @@ class TestProcessGroupFp32(unittest.TestCase):
             print("test broadcast api ok")
+            # test broadcast with shape=[]
+            # rank 0
+            x = np.random.random([]).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            # rank 1
+            y = np.random.random([]).astype(self.dtype)
+            tensor_y = paddle.to_tensor(y)
+            broadcast_result = paddle.assign(tensor_x)
+            if pg.rank() == 0:
+                task = dist.broadcast(tensor_x, 0, sync_op=False)
+                task.synchronize()
+                paddle.device.cuda.synchronize()
+                assert task.is_completed()
+                assert np.array_equal(broadcast_result, tensor_x)
+            else:
+                task = dist.broadcast(tensor_y, 0)
+                paddle.device.cuda.synchronize()
+                assert np.array_equal(broadcast_result, tensor_y)
+            assert tensor_y.shape == []
+            print("test broadcast api with shape=[] ok")
             # test barrier
             # rank 0
             if pg.rank() == 0:
@@ -417,6 +440,30 @@ class TestProcessGroupFp32(unittest.TestCase):
                 assert np.array_equal(tensor_y, out2)
             print("test scatter api ok\n")
+            # test Scatter with shape=[]
+            # rank 0
+            x = np.random.random([]).astype(self.dtype)
+            y = np.random.random([]).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            tensor_y = paddle.to_tensor(y)
+            if pg.rank() == 0:
+                in_1, in_2 = tensor_x, tensor_x + 1
+                task = dist.scatter(tensor_y, [in_1, in_2], 0, sync_op=True)
+                paddle.device.cuda.synchronize()
+            # rank 1
+            else:
+                task = dist.scatter(tensor_y, [], 0, sync_op=True)
+                task.wait()
+                paddle.device.cuda.synchronize()
+            out1 = paddle.assign(tensor_x)
+            out2 = paddle.assign(tensor_x + 1)
+            if pg.rank() == 0:
+                assert np.array_equal(tensor_y, out1)
+            else:
+                assert np.array_equal(tensor_y, out2), f"{tensor_y}, {out2}"
+            assert tensor_y.shape == []
+            print("test scatter api with shape=[] ok\n")
             # test send min
             # rank 0
             x = np.random.random(self.shape).astype(self.dtype)
......