提交 731fb550 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4627 add save checkpoint check during parallel

Merge pull request !4627 from changzherui/add_ckpt_check
......@@ -396,15 +396,17 @@ def _get_merged_param_data(net, param_name, param_data):
Returns:
Tensor, the combined tensor which with the whole data value.
"""
layout = []
layout = net.parameter_layout_dict[param_name]
if len(layout) < 2:
if len(layout) < 5:
logger.info("layout dict does not contain the key %s", param_name)
return param_data
dev_mat = layout[0]
tensor_map = layout[1]
field_size = layout[3]
uniform_split = layout[4]
if uniform_split[0] == 0:
raise RuntimeError("Save checkpoint only support uniform split tensor now.")
from mindspore.parallel._cell_wrapper import get_allgather_cell
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册