未验证 提交 c6a084ef 编写于 作者: L lilong12 提交者: GitHub

be compatible with the old version of alltoall (#42007)

上级 a3c50c42
...@@ -860,9 +860,12 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): ...@@ -860,9 +860,12 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
if in_dygraph_mode(): if in_dygraph_mode():
group = _get_default_group() if group is None else group group = _get_default_group() if group is None else group
tensor_shape = list(tensor.shape) if len(tensor_list) == 0:
tensor_shape[0] *= group.nranks tensor_shape = list(tensor.shape)
out = paddle.empty(tensor_shape, tensor.dtype) tensor_shape[0] *= group.nranks
out = paddle.empty(tensor_shape, tensor.dtype)
else:
out = paddle.concat(tensor_list, axis=0)
task = group.process_group.all_gather(tensor, out) task = group.process_group.all_gather(tensor, out)
task.wait() task.wait()
tensor_list.clear() tensor_list.clear()
...@@ -1783,7 +1786,12 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): ...@@ -1783,7 +1786,12 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
temp = paddle.concat(in_tensor_list, axis=0) temp = paddle.concat(in_tensor_list, axis=0)
nranks = len(in_tensor_list) nranks = len(in_tensor_list)
if in_dygraph_mode(): if in_dygraph_mode():
out = paddle.concat(out_tensor_list, axis=0) if len(out_tensor_list) == 0:
tensor_shape = list(in_tensor_list[0].shape)
tensor_shape[0] *= nranks
out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
else:
out = paddle.concat(out_tensor_list, axis=0)
task = group.process_group.alltoall(temp, out) task = group.process_group.alltoall(temp, out)
task.wait() task.wait()
out_tensor_list.clear() out_tensor_list.clear()
......
...@@ -185,6 +185,24 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -185,6 +185,24 @@ class TestProcessGroupFp32(unittest.TestCase):
assert np.array_equal(tensor_y, out_2) assert np.array_equal(tensor_y, out_2)
print("test allgather api ok\n") print("test allgather api ok\n")
if pg.rank() == 0:
task = pg.all_gather(tensor_x, tensor_out)
task.wait()
paddle.device.cuda.synchronize()
# rank 1
else:
tensor_out_list = []
task = dist.all_gather(
tensor_out_list, tensor_y, use_calc_stream=False)
paddle.device.cuda.synchronize()
tensor_out = paddle.concat(tensor_out_list)
out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
[out_shape[0]])
assert np.array_equal(tensor_x, out_1)
assert np.array_equal(tensor_y, out_2)
print("test allgather api2 ok\n")
# test alltoall # test alltoall
# rank 0 # rank 0
x = np.random.random(self.shape).astype(self.dtype) x = np.random.random(self.shape).astype(self.dtype)
...@@ -219,6 +237,38 @@ class TestProcessGroupFp32(unittest.TestCase): ...@@ -219,6 +237,38 @@ class TestProcessGroupFp32(unittest.TestCase):
assert np.array_equal(out2_1, raw_tensor_x_2) assert np.array_equal(out2_1, raw_tensor_x_2)
print("test alltoall api ok\n") print("test alltoall api ok\n")
x = np.random.random(self.shape).astype(self.dtype)
y = np.random.random(self.shape).astype(self.dtype)
out1 = np.random.random(self.shape).astype(self.dtype)
out2 = np.random.random(self.shape).astype(self.dtype)
tensor_x = paddle.to_tensor(x)
tensor_y = paddle.to_tensor(y)
tensor_out1 = paddle.to_tensor(out1)
tensor_out2 = paddle.to_tensor(out2)
raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
[self.shape[0]])
raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
[self.shape[0] // 2])
if pg.rank() == 0:
task = pg.alltoall(tensor_x, tensor_out1)
task.wait()
# rank 1
else:
in_1, in_2 = paddle.split(tensor_y, 2)
out_1, out_2 = paddle.split(tensor_out2, 2)
out_tensor_list = []
task = dist.alltoall([in_1, in_2], out_tensor_list)
paddle.device.cuda.synchronize()
tensor_out2 = paddle.concat(out_tensor_list)
out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
[self.shape[0]])
out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
if pg.rank() == 0:
assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
else:
assert np.array_equal(out2_1, raw_tensor_x_2)
print("test alltoall api2 ok\n")
# test Reduce # test Reduce
# rank 0 # rank 0
x = np.random.random(self.shape).astype(self.dtype) x = np.random.random(self.shape).astype(self.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册