From 59c4fdac6a314ad5e6afea6d4946036e5c0e4305 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Wed, 30 Mar 2022 19:50:05 +0800
Subject: [PATCH] [AutoParallel] fix converter when sliced_shape is 1 (#41103)

* fix converter when sliced_shape is 1

* update unittest
---
 .../distributed/auto_parallel/converter.py    |  4 ++--
 .../paddle/distributed/auto_parallel/utils.py |  4 ++--
 .../unittests/auto_parallel/converter.py      | 22 +++++++++++++++++++
 3 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/converter.py b/python/paddle/distributed/auto_parallel/converter.py
index 1475c447042..2ea200c7d6f 100644
--- a/python/paddle/distributed/auto_parallel/converter.py
+++ b/python/paddle/distributed/auto_parallel/converter.py
@@ -447,8 +447,8 @@ class Converter(object):
                 slice_shape = shape
             else:
                 slice_shape = shape // process_shape[dims_mapping[i]]
-            if shape == 1:
-                index = 0
+            if slice_shape == 1:
+                index = partition_index[i][0]
             else:
                 index = (partition_index[i][0] + 1) // slice_shape
             sliced_index = sliced_index * (shape // slice_shape) + index
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index a7b5f3a2fd0..fc85cd04d40 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -943,8 +943,8 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
             slice_shape = shape
         else:
             slice_shape = shape // process_shape[dims_mapping[i]]
-        if shape == 1:
-            index = 0
+        if slice_shape == 1:
+            index = partition_index[i][0]
         else:
             index = (partition_index[i][0] + 1) // slice_shape
         sliced_param_index = sliced_param_index * (shape // slice_shape) + index
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/converter.py b/python/paddle/fluid/tests/unittests/auto_parallel/converter.py
index e34f267b423..d5d7caa7b77 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/converter.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/converter.py
@@ -78,6 +78,28 @@ def test_convert():
     convert_tensor_dict = converter.convert(strict=False)
     assert np.equal(convert_tensor_dict[new_name], tensor_row[rank_id]).all()
 
+    # test sliced_shape is 1
+    complete_tensor = np.arange(4).reshape([2, 2])
+    tensor_row = np.split(complete_tensor, 2, axis=0)
+    complet_strategy = {
+        "tensor_2": {
+            "process_shape": [2],
+            "process_group": [0, 1],
+            "dims_mapping": [-1, -1]
+        }
+    }
+    row_strategy = {
+        "tensor_2": {
+            "process_shape": [2],
+            "process_group": [0, 1],
+            "dims_mapping": [0, -1]
+        }
+    }
+    tensor_dict = {"tensor_2": [complete_tensor]}
+    converter = Converter(tensor_dict, complet_strategy, row_strategy)
+    convert_tensor_dict = converter.convert()
+    assert np.equal(convert_tensor_dict["tensor_2"], tensor_row[rank_id]).all()
+
 
 if __name__ == "__main__":
     test_convert()
-- 
GitLab
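
The two library hunks above change how a rank's partition is mapped to a slice index when the per-rank slice length along a dimension is 1: the old code only special-cased a dimension whose full size was 1, so a dimension split down to length 1 produced an out-of-range index. Below is a minimal standalone sketch of that index computation for illustration only; the helper name compute_sliced_index and the demo values (mirroring the "tensor_2" case in the new unittest) are not part of the patch, and partition_index is assumed to hold a [start, end) pair per dimension for the current rank, as the surrounding code suggests.

# Illustrative sketch only -- not part of the patch. Assumes partition_index
# holds a [start, end) pair per dimension for the current rank.
def compute_sliced_index(complete_shape, dims_mapping, process_shape,
                         partition_index, fixed=True):
    sliced_index = 0
    for i, shape in enumerate(complete_shape):
        if dims_mapping[i] == -1:
            slice_shape = shape  # this dimension is not split
        else:
            slice_shape = shape // process_shape[dims_mapping[i]]
        if fixed:
            # patched logic: a length-1 slice starts exactly at its own index
            if slice_shape == 1:
                index = partition_index[i][0]
            else:
                index = (partition_index[i][0] + 1) // slice_shape
        else:
            # pre-patch logic: only a length-1 *dimension* was special-cased
            if shape == 1:
                index = 0
            else:
                index = (partition_index[i][0] + 1) // slice_shape
        sliced_index = sliced_index * (shape // slice_shape) + index
    return sliced_index


# Demo values mirror the "tensor_2" case added to the unittest: a [2, 2]
# tensor split along dim 0 across two ranks, so each slice has shape [1, 2].
complete_shape, dims_mapping, process_shape = [2, 2], [0, -1], [2]
for rank, part in enumerate([[[0, 1], [0, 2]], [[1, 2], [0, 2]]]):
    old = compute_sliced_index(complete_shape, dims_mapping, process_shape,
                               part, fixed=False)
    new = compute_sliced_index(complete_shape, dims_mapping, process_shape,
                               part, fixed=True)
    print(f"rank {rank}: old index {old}, fixed index {new}")
# The old formula yields 1 and 2 (2 is out of range for two slices);
# the fixed formula yields 0 and 1, matching each rank's own slice.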