From c6a084ef2bfcf5bd4f1cdeedbef13a933413ddab Mon Sep 17 00:00:00 2001
From: lilong12
Date: Wed, 20 Apr 2022 17:53:58 +0800
Subject: [PATCH] be compatible with the old version of alltoall (#42007)

---
 python/paddle/distributed/collective.py            | 16 ++++--
 .../tests/unittests/process_group_nccl.py          | 50 +++++++++++++++++++
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 35ab1193c2b..b92b2a3c15d 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -860,9 +860,12 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
 
     if in_dygraph_mode():
         group = _get_default_group() if group is None else group
-        tensor_shape = list(tensor.shape)
-        tensor_shape[0] *= group.nranks
-        out = paddle.empty(tensor_shape, tensor.dtype)
+        if len(tensor_list) == 0:
+            tensor_shape = list(tensor.shape)
+            tensor_shape[0] *= group.nranks
+            out = paddle.empty(tensor_shape, tensor.dtype)
+        else:
+            out = paddle.concat(tensor_list, axis=0)
         task = group.process_group.all_gather(tensor, out)
         task.wait()
         tensor_list.clear()
@@ -1783,7 +1786,12 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
     temp = paddle.concat(in_tensor_list, axis=0)
     nranks = len(in_tensor_list)
     if in_dygraph_mode():
-        out = paddle.concat(out_tensor_list, axis=0)
+        if len(out_tensor_list) == 0:
+            tensor_shape = list(in_tensor_list[0].shape)
+            tensor_shape[0] *= nranks
+            out = paddle.empty(tensor_shape, in_tensor_list[0].dtype)
+        else:
+            out = paddle.concat(out_tensor_list, axis=0)
         task = group.process_group.alltoall(temp, out)
         task.wait()
         out_tensor_list.clear()
diff --git a/python/paddle/fluid/tests/unittests/process_group_nccl.py b/python/paddle/fluid/tests/unittests/process_group_nccl.py
index 7ae38b3bbc4..7aa83ad9079 100644
--- a/python/paddle/fluid/tests/unittests/process_group_nccl.py
+++ b/python/paddle/fluid/tests/unittests/process_group_nccl.py
@@ -185,6 +185,24 @@ class TestProcessGroupFp32(unittest.TestCase):
             assert np.array_equal(tensor_y, out_2)
             print("test allgather api ok\n")
 
+            if pg.rank() == 0:
+                task = pg.all_gather(tensor_x, tensor_out)
+                task.wait()
+                paddle.device.cuda.synchronize()
+            # rank 1
+            else:
+                tensor_out_list = []
+                task = dist.all_gather(
+                    tensor_out_list, tensor_y, use_calc_stream=False)
+                paddle.device.cuda.synchronize()
+                tensor_out = paddle.concat(tensor_out_list)
+            out_1 = paddle.slice(tensor_out, [0], [0], [out_shape[0] // 2])
+            out_2 = paddle.slice(tensor_out, [0], [out_shape[0] // 2],
+                                 [out_shape[0]])
+            assert np.array_equal(tensor_x, out_1)
+            assert np.array_equal(tensor_y, out_2)
+            print("test allgather api2 ok\n")
+
             # test alltoall
             # rank 0
             x = np.random.random(self.shape).astype(self.dtype)
@@ -219,6 +237,38 @@ class TestProcessGroupFp32(unittest.TestCase):
                 assert np.array_equal(out2_1, raw_tensor_x_2)
             print("test alltoall api ok\n")
 
+            x = np.random.random(self.shape).astype(self.dtype)
+            y = np.random.random(self.shape).astype(self.dtype)
+            out1 = np.random.random(self.shape).astype(self.dtype)
+            out2 = np.random.random(self.shape).astype(self.dtype)
+            tensor_x = paddle.to_tensor(x)
+            tensor_y = paddle.to_tensor(y)
+            tensor_out1 = paddle.to_tensor(out1)
+            tensor_out2 = paddle.to_tensor(out2)
+            raw_tensor_x_2 = paddle.slice(tensor_x, [0], [self.shape[0] // 2],
+                                          [self.shape[0]])
+            raw_tensor_y_1 = paddle.slice(tensor_y, [0], [0],
+                                          [self.shape[0] // 2])
+            if pg.rank() == 0:
+                task = pg.alltoall(tensor_x, tensor_out1)
+                task.wait()
+            # rank 1
+            else:
+                in_1, in_2 = paddle.split(tensor_y, 2)
+                out_1, out_2 = paddle.split(tensor_out2, 2)
+                out_tensor_list = []
+                task = dist.alltoall([in_1, in_2], out_tensor_list)
+                paddle.device.cuda.synchronize()
+                tensor_out2 = paddle.concat(out_tensor_list)
+            out1_2 = paddle.slice(tensor_out1, [0], [self.shape[0] // 2],
+                                  [self.shape[0]])
+            out2_1 = paddle.slice(tensor_out2, [0], [0], [self.shape[0] // 2])
+            if pg.rank() == 0:
+                assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
+            else:
+                assert np.array_equal(out2_1, raw_tensor_x_2)
+            print("test alltoall api2 ok\n")
+
             # test Reduce
             # rank 0
             x = np.random.random(self.shape).astype(self.dtype)
--
GitLab
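
Context for the subject line: before this patch, the eager/dygraph branch of paddle.distributed.alltoall built its output with paddle.concat(out_tensor_list), so the caller had to pre-allocate the output tensors; the patch adds a branch so an empty out_tensor_list (the older calling convention) is also accepted and filled in, and all_gather gets the same treatment for tensor_list. A minimal sketch of both calling styles, assuming a two-GPU job launched with paddle.distributed.launch; shapes and variable names below are illustrative, not taken from the patch:

    # Sketch only: assumes 2 ranks launched via `python -m paddle.distributed.launch`.
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    rank = dist.get_rank()

    # Each rank contributes one slice per peer rank.
    data = paddle.full([4, 8], float(rank))
    in_tensor_list = paddle.split(data, 2)

    # Pre-allocated outputs: the caller provides one tensor per peer rank.
    out_tensor_list = [paddle.empty(t.shape, t.dtype) for t in in_tensor_list]
    dist.alltoall(in_tensor_list, out_tensor_list)

    # Older calling convention kept working by this patch in dygraph mode:
    # pass an empty list and let alltoall allocate and append the outputs.
    out_tensor_list = []
    dist.alltoall(in_tensor_list, out_tensor_list)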