未验证 提交 693c3c14 编写于 作者: Z zhaoyingli 提交者: GitHub

fix autoconvert (#37347)

* fix autoconvert

* fix merge parameter
上级 d2200e97
......@@ -642,7 +642,7 @@ def _load_distributed_state_dict(checkpoint_path):
""" Load parameters' state_dict from checkpoint_path """
all_state_dict = {}
for idx, ckpt_file in enumerate(checkpoint_path):
state_dict_info = paddle.load(ckpt_file)
state_dict_info = paddle.load(ckpt_file, return_numpy=True)
pre_world_size = state_dict_info["world_size"]
assert pre_world_size == len(checkpoint_path), \
"The number of 'checkpoint_path' must be equal to the last training world size."
......@@ -778,12 +778,16 @@ def _merge_parameter_with_dist_attr(param_list, dist_attr):
dims_mapping)
# merge the parameter with dist_attr
partition_param_list = []
merged_partiton = []
for process in process_group:
partition_index = _compute_partition_index(
process, complete_shape, dims_mapping, process_shape, process_group)
index = process_group.index(process)
_merge_parameter(partition_param_list, param_list[index],
partition_index)
if partition_index not in merged_partiton:
merged_partiton.append(partition_index)
_merge_parameter(partition_param_list, param_list[index],
partition_index, complete_shape)
assert len(partition_param_list) == 1 or not partition_param_list, \
"Fail to merge parameter"
complete_param = _to_LodTensor(partition_param_list[0][0])
......@@ -810,7 +814,8 @@ def _slice_parameter_with_dist_attr(param, dist_attr):
return sliced_param
def _merge_parameter(partition_param_list, param, partition_index):
def _merge_parameter(partition_param_list, param, partition_index,
complete_shape):
"""
Merge partitial parameters to a complete one.
......@@ -830,16 +835,23 @@ def _merge_parameter(partition_param_list, param, partition_index):
"""
from .reshard import _compute_concat_info
if len(partition_param_list) == 1:
is_complete_data = True
for idx, item in enumerate(partition_param_list[0][1]):
if item[0] != 0 or item[1] != complete_shape[idx]:
is_complete_data = False
break
if is_complete_data:
return
if not partition_param_list:
partition_param_list.append((param, partition_index))
else:
i = 0
has_concat = False
while i < len(partition_param_list):
concat_axis, first_order, new_partition = _compute_concat_info(
partition_param_list[i][1], partition_index)
if concat_axis != -1:
has_concat = True
if first_order == 0:
new_param = np.concatenate(
(partition_param_list[i][0], param), axis=concat_axis)
......@@ -848,19 +860,11 @@ def _merge_parameter(partition_param_list, param, partition_index):
(param, partition_param_list[i][0]), axis=concat_axis)
partition_param_list.pop(i)
_merge_parameter(partition_param_list, new_param, new_partition)
_merge_parameter(partition_param_list, new_param, new_partition,
complete_shape)
break
i += 1
if not has_concat:
need_append = True
for i in range(len(partition_param_list)):
if partition_index == partition_param_list[i][1]:
need_append = False
break
if need_append:
partition_param_list.append((param, partition_index))
def _slice_parameter(complete_param, partition_index_list, length):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册