From 693c3c143c596154774ec4dba9651255ac854771 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Mon, 22 Nov 2021 19:50:41 +0800
Subject: [PATCH] fix autoconvert (#37347)

* fix autoconvert

* fix merge parameter
---
 .../paddle/distributed/auto_parallel/utils.py | 36 ++++++++++---------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index e8557a2931..9c1f9a8c7c 100755
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -642,7 +642,7 @@ def _load_distributed_state_dict(checkpoint_path):
     """ Load parameters' state_dict from checkpoint_path """
     all_state_dict = {}
     for idx, ckpt_file in enumerate(checkpoint_path):
-        state_dict_info = paddle.load(ckpt_file)
+        state_dict_info = paddle.load(ckpt_file, return_numpy=True)
         pre_world_size = state_dict_info["world_size"]
         assert pre_world_size == len(checkpoint_path), \
             "The number of 'checkpoint_path' must be equal to the last training world size."
@@ -778,12 +778,16 @@ def _merge_parameter_with_dist_attr(param_list, dist_attr):
                                              dims_mapping)
     # merge the parameter with dist_attr
     partition_param_list = []
+    merged_partiton = []
     for process in process_group:
         partition_index = _compute_partition_index(
             process, complete_shape, dims_mapping, process_shape, process_group)
         index = process_group.index(process)
-        _merge_parameter(partition_param_list, param_list[index],
-                         partition_index)
+        if partition_index not in merged_partiton:
+            merged_partiton.append(partition_index)
+            _merge_parameter(partition_param_list, param_list[index],
+                             partition_index, complete_shape)
+
     assert len(partition_param_list) == 1 or not partition_param_list, \
         "Fail to merge parameter"
     complete_param = _to_LodTensor(partition_param_list[0][0])
@@ -810,7 +814,8 @@ def _slice_parameter_with_dist_attr(param, dist_attr):
     return sliced_param
 
 
-def _merge_parameter(partition_param_list, param, partition_index):
+def _merge_parameter(partition_param_list, param, partition_index,
+                     complete_shape):
     """
     Merge partitial parameters to a complete one.
 
@@ -830,16 +835,23 @@ def _merge_parameter(partition_param_list, param, partition_index):
     """
     from .reshard import _compute_concat_info
 
+    if len(partition_param_list) == 1:
+        is_complete_data = True
+        for idx, item in enumerate(partition_param_list[0][1]):
+            if item[0] != 0 or item[1] != complete_shape[idx]:
+                is_complete_data = False
+                break
+        if is_complete_data:
+            return
+
     if not partition_param_list:
         partition_param_list.append((param, partition_index))
     else:
         i = 0
-        has_concat = False
         while i < len(partition_param_list):
             concat_axis, first_order, new_partition = _compute_concat_info(
                 partition_param_list[i][1], partition_index)
             if concat_axis != -1:
-                has_concat = True
                 if first_order == 0:
                     new_param = np.concatenate(
                         (partition_param_list[i][0], param), axis=concat_axis)
@@ -848,19 +860,11 @@ def _merge_parameter(partition_param_list, param, partition_index):
                         (param, partition_param_list[i][0]), axis=concat_axis)
 
                 partition_param_list.pop(i)
-                _merge_parameter(partition_param_list, new_param, new_partition)
+                _merge_parameter(partition_param_list, new_param, new_partition,
+                                 complete_shape)
                 break
            i += 1
 
-        if not has_concat:
-            need_append = True
-            for i in range(len(partition_param_list)):
-                if partition_index == partition_param_list[i][1]:
-                    need_append = False
-                    break
-            if need_append:
-                partition_param_list.append((param, partition_index))
-
 
 def _slice_parameter(complete_param, partition_index_list, length):
     """
--
GitLab
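Note on the merge change, as an illustration only (not part of the patch): the hunks above make _merge_parameter_with_dist_attr record each partition_index in merged_partiton and skip processes whose slice has already been merged, and _merge_parameter now returns early once the single remaining piece spans complete_shape. The minimal, self-contained sketch below mimics that deduplication idea; the shapes, ranks, and the simplified sort-and-concatenate step are illustrative assumptions, not the library's actual helpers.

# Illustrative sketch (not from the patch): a 1-D parameter of shape [4] is
# split across 2 ranks and replicated on 2 more, so ranks 0/2 and 1/3 hold
# identical slices; duplicates are skipped before merging, as in the patch.
import numpy as np

complete_shape = [4]
# Hypothetical per-rank partition indices: [[start, end]] into the full tensor.
partition_indices = [[[0, 2]], [[2, 4]], [[0, 2]], [[2, 4]]]
param_list = [np.array([0., 1.]), np.array([2., 3.]),
              np.array([0., 1.]), np.array([2., 3.])]

merged_partition = []      # partition indices already merged (the dedup list)
partition_param_list = []  # unique (slice, partition_index) pieces

for rank, partition_index in enumerate(partition_indices):
    if partition_index in merged_partition:
        continue           # replicated copy: skip instead of merging it again
    merged_partition.append(partition_index)
    partition_param_list.append((param_list[rank], partition_index))

# Concatenate the unique pieces along dim 0 to rebuild the complete parameter
# (the real code walks _compute_concat_info; plain sort+concat suffices here).
partition_param_list.sort(key=lambda item: item[1][0][0])
complete_param = np.concatenate([p for p, _ in partition_param_list], axis=0)
assert complete_param.shape[0] == complete_shape[0]
assert list(complete_param) == [0., 1., 2., 3.]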