未验证 提交 59c4fdac 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix converter when sliced_shape is 1 (#41103)

* fix converter when sliced_shape is 1

* update unittest
上级 a0e961c0
......@@ -447,8 +447,8 @@ class Converter(object):
slice_shape = shape
else:
slice_shape = shape // process_shape[dims_mapping[i]]
if shape == 1:
index = 0
if slice_shape == 1:
index = partition_index[i][0]
else:
index = (partition_index[i][0] + 1) // slice_shape
sliced_index = sliced_index * (shape // slice_shape) + index
......
......@@ -943,8 +943,8 @@ def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
slice_shape = shape
else:
slice_shape = shape // process_shape[dims_mapping[i]]
if shape == 1:
index = 0
if slice_shape == 1:
index = partition_index[i][0]
else:
index = (partition_index[i][0] + 1) // slice_shape
sliced_param_index = sliced_param_index * (shape // slice_shape) + index
......
......@@ -78,6 +78,28 @@ def test_convert():
convert_tensor_dict = converter.convert(strict=False)
assert np.equal(convert_tensor_dict[new_name], tensor_row[rank_id]).all()
# test sliced_shape is 1
complete_tensor = np.arange(4).reshape([2, 2])
tensor_row = np.split(complete_tensor, 2, axis=0)
complet_strategy = {
"tensor_2": {
"process_shape": [2],
"process_group": [0, 1],
"dims_mapping": [-1, -1]
}
}
row_strategy = {
"tensor_2": {
"process_shape": [2],
"process_group": [0, 1],
"dims_mapping": [0, -1]
}
}
tensor_dict = {"tensor_2": [complete_tensor]}
converter = Converter(tensor_dict, complet_strategy, row_strategy)
convert_tensor_dict = converter.convert()
assert np.equal(convert_tensor_dict["tensor_2"], tensor_row[rank_id]).all()
if __name__ == "__main__":
test_convert()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册