diff --git a/python/paddle/distributed/auto_parallel/converter.py b/python/paddle/distributed/auto_parallel/converter.py index 1475c447042aded805137e1e6ea17b3a1a374fb0..2ea200c7d6f81d5380217e4a0fc03d8f5978fed1 100644 --- a/python/paddle/distributed/auto_parallel/converter.py +++ b/python/paddle/distributed/auto_parallel/converter.py @@ -447,8 +447,8 @@ class Converter(object): slice_shape = shape else: slice_shape = shape // process_shape[dims_mapping[i]] - if shape == 1: - index = 0 + if slice_shape == 1: + index = partition_index[i][0] else: index = (partition_index[i][0] + 1) // slice_shape sliced_index = sliced_index * (shape // slice_shape) + index diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index a7b5f3a2fd0d0f0cb07a912a6898da69693b7b8d..fc85cd04d4010ed826ea198f0c4b44a7c461ea86 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -943,8 +943,8 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, slice_shape = shape else: slice_shape = shape // process_shape[dims_mapping[i]] - if shape == 1: - index = 0 + if slice_shape == 1: + index = partition_index[i][0] else: index = (partition_index[i][0] + 1) // slice_shape sliced_param_index = sliced_param_index * (shape // slice_shape) + index diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/converter.py b/python/paddle/fluid/tests/unittests/auto_parallel/converter.py index e34f267b4237bf5ebe19adda1c90f1c147294333..d5d7caa7b77083f50066de02f8b33d8012b161c2 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/converter.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/converter.py @@ -78,6 +78,28 @@ def test_convert(): convert_tensor_dict = converter.convert(strict=False) assert np.equal(convert_tensor_dict[new_name], tensor_row[rank_id]).all() + # test sliced_shape is 1 + complete_tensor = np.arange(4).reshape([2, 2]) + tensor_row = np.split(complete_tensor, 2, axis=0) + complet_strategy = { + "tensor_2": { + "process_shape": [2], + "process_group": [0, 1], + "dims_mapping": [-1, -1] + } + } + row_strategy = { + "tensor_2": { + "process_shape": [2], + "process_group": [0, 1], + "dims_mapping": [0, -1] + } + } + tensor_dict = {"tensor_2": [complete_tensor]} + converter = Converter(tensor_dict, complet_strategy, row_strategy) + convert_tensor_dict = converter.convert() + assert np.equal(convert_tensor_dict["tensor_2"], tensor_row[rank_id]).all() + if __name__ == "__main__": test_convert()