Unverified · Commit 693c3c14 authored by zhaoyingli, committed by GitHub

fix autoconvert (#37347)

* fix autoconvert

* fix merge parameter
Parent d2200e97
@@ -642,7 +642,7 @@ def _load_distributed_state_dict(checkpoint_path):
     """ Load parameters' state_dict from checkpoint_path """
     all_state_dict = {}
     for idx, ckpt_file in enumerate(checkpoint_path):
-        state_dict_info = paddle.load(ckpt_file)
+        state_dict_info = paddle.load(ckpt_file, return_numpy=True)
         pre_world_size = state_dict_info["world_size"]
         assert pre_world_size == len(checkpoint_path), \
             "The number of 'checkpoint_path' must be equal to the last training world size."
@@ -778,12 +778,16 @@ def _merge_parameter_with_dist_attr(param_list, dist_attr):
                                              dims_mapping)
     # merge the parameter with dist_attr
     partition_param_list = []
+    merged_partiton = []
     for process in process_group:
         partition_index = _compute_partition_index(
             process, complete_shape, dims_mapping, process_shape, process_group)
         index = process_group.index(process)
-        _merge_parameter(partition_param_list, param_list[index],
-                         partition_index)
+        if partition_index not in merged_partiton:
+            merged_partiton.append(partition_index)
+            _merge_parameter(partition_param_list, param_list[index],
+                             partition_index, complete_shape)
+
     assert len(partition_param_list) == 1 or not partition_param_list, \
         "Fail to merge parameter"
     complete_param = _to_LodTensor(partition_param_list[0][0])
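The `merged_partiton` list added here skips replicated shards: when a parameter is replicated across part of the process group, several ranks report the same partition index, and merging that slice more than once would corrupt the reassembled tensor. A toy sketch of the dedup idea, using illustrative shapes and `[start, end]`-per-axis partition indices rather than the library's real helpers:

```python
import numpy as np

# Ranks 0/1 and ranks 2/3 hold identical row-slices of a [4, 2] parameter, so
# only one copy per distinct partition index should enter the merge.
param_list = [np.zeros((2, 2)), np.zeros((2, 2)), np.ones((2, 2)), np.ones((2, 2))]
partition_indices = [
    [[0, 2], [0, 2]],  # rank 0
    [[0, 2], [0, 2]],  # rank 1: replica of rank 0's slice
    [[2, 4], [0, 2]],  # rank 2
    [[2, 4], [0, 2]],  # rank 3: replica of rank 2's slice
]

merged_partition = []
unique_slices = []
for index, partition_index in enumerate(partition_indices):
    if partition_index not in merged_partition:
        merged_partition.append(partition_index)
        unique_slices.append(param_list[index])

complete_param = np.concatenate(unique_slices, axis=0)
print(complete_param.shape)  # (4, 2)
```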
@@ -810,7 +814,8 @@ def _slice_parameter_with_dist_attr(param, dist_attr):
     return sliced_param


-def _merge_parameter(partition_param_list, param, partition_index):
+def _merge_parameter(partition_param_list, param, partition_index,
+                     complete_shape):
     """
     Merge partitial parameters to a complete one.

@@ -830,16 +835,23 @@ def _merge_parameter(partition_param_list, param, partition_index):
     """
     from .reshard import _compute_concat_info

+    if len(partition_param_list) == 1:
+        is_complete_data = True
+        for idx, item in enumerate(partition_param_list[0][1]):
+            if item[0] != 0 or item[1] != complete_shape[idx]:
+                is_complete_data = False
+                break
+        if is_complete_data:
+            return
+
     if not partition_param_list:
         partition_param_list.append((param, partition_index))
     else:
         i = 0
-        has_concat = False
         while i < len(partition_param_list):
             concat_axis, first_order, new_partition = _compute_concat_info(
                 partition_param_list[i][1], partition_index)
             if concat_axis != -1:
-                has_concat = True
                 if first_order == 0:
                     new_param = np.concatenate(
                         (partition_param_list[i][0], param), axis=concat_axis)
@@ -848,19 +860,11 @@ def _merge_parameter(partition_param_list, param, partition_index):
                         (param, partition_param_list[i][0]), axis=concat_axis)
                 partition_param_list.pop(i)
-                _merge_parameter(partition_param_list, new_param, new_partition)
+                _merge_parameter(partition_param_list, new_param, new_partition,
+                                 complete_shape)
                 break
             i += 1
-
-        if not has_concat:
-            need_append = True
-            for i in range(len(partition_param_list)):
-                if partition_index == partition_param_list[i][1]:
-                    need_append = False
-                    break
-            if need_append:
-                partition_param_list.append((param, partition_index))


 def _slice_parameter(complete_param, partition_index_list, length):
     """
......
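To summarize the `_merge_parameter` changes: partitions are concatenated pairwise along the axis where their index ranges touch, and the new `complete_shape` argument lets the recursion stop as soon as the single remaining block already spans the whole parameter, which also makes the old `has_concat`/`need_append` bookkeeping unnecessary. Below is a standalone toy version of one merge step and the stop check, with hypothetical helper and variable names (not the functions in `auto_parallel`):

```python
import numpy as np

# Each partition carries its [start, end] range per axis.
def covers_complete_shape(partition_index, complete_shape):
    # Mirrors the new early-return check: every axis must start at 0 and end
    # at that axis's full extent before merging can stop.
    return all(start == 0 and end == extent
               for (start, end), extent in zip(partition_index, complete_shape))

complete_shape = [4, 2]
left = (np.zeros((2, 2)), [[0, 2], [0, 2]])   # rows 0-2 of the parameter
right = (np.ones((2, 2)), [[2, 4], [0, 2]])   # rows 2-4 of the parameter

# The two ranges touch along axis 0, so concatenate there and widen the
# merged block's index to cover both halves.
merged = (np.concatenate((left[0], right[0]), axis=0), [[0, 4], [0, 2]])
assert covers_complete_shape(merged[1], complete_shape)
print(merged[0].shape)  # (4, 2)
```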