Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ce8aec5a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ce8aec5a
编写于
4月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!428 support customize network checkpoint
Merge pull request !428 from changzherui/cus_net_ckpt
上级
76c700fb
93e90959
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
37 addition
and
19 deletion
+37
-19
mindspore/train/serialization.py
mindspore/train/serialization.py
+37
-19
未找到文件。
mindspore/train/serialization.py
浏览文件 @
ce8aec5a
...
...
@@ -224,42 +224,60 @@ def load_param_into_net(net, parameter_dict):
msg
=
(
"Argument parameter_dict should be a dict, but got {}."
.
format
(
type
(
parameter_dict
)))
raise
TypeError
(
msg
)
logger
.
info
(
"Execute parameter into net process."
)
param_name_net_not_have
=
[]
logger
.
info
(
"Execute load parameter into net process."
)
for
name
in
parameter_dict
:
b_par_dict_have_par_of_net
=
False
for
_
,
param
in
net
.
parameters_and_names
():
if
name
==
param
.
name
:
b_par_dict_have_par_of_net
=
True
if
name
==
param
.
name
and
param
.
layerwise_parallel
:
# layerwise parallel parameter data loaded from checkpoint file,
# was a complete(merged) data, need to be splited
if
param
.
layerwise_parallel
:
new_param
=
parameter_dict
[
param
.
name
]
_load_tensor_for_layerwise
(
new_param
,
param
)
new_param
=
parameter_dict
[
param
.
name
]
_load_tensor_for_layerwise
(
new_param
,
param
)
break
if
not
b_par_dict_have_par_of_net
:
param_name_net_not_have
.
append
(
name
)
param_n
ame_param_dict_not_have
=
[]
param_n
ot_load
=
[]
for
_
,
param
in
net
.
parameters_and_names
():
if
param
.
name
in
parameter_dict
:
new_param
=
parameter_dict
[
param
.
name
]
if
not
isinstance
(
new_param
,
Parameter
):
logger
.
error
(
"Failed to combine the net and the parameters."
)
msg
=
(
"Argument parameter_dict element should be a Parameter, but got {}."
.
format
(
type
(
new_param
)))
raise
TypeError
(
msg
)
_update_param
(
param
,
new_param
)
else
:
param_name_param_dict_not_have
.
append
(
param
.
name
)
param_not_load
.
append
(
param
.
name
)
if
param_not_load
:
_load_dismatch_prefix_params
(
net
,
parameter_dict
,
param_not_load
)
logger
.
debug
(
"Params not matched(in net but not in parameter_dict):"
)
for
paramname
in
param_name_param_dict_not_have
:
logger
.
debug
(
"%s"
,
paramname
)
logger
.
debug
(
"Params not matched(in parameter_dict but not in net):"
)
for
paramname
in
param_name_net_not_have
:
logger
.
debug
(
"%s"
,
paramname
)
logger
.
info
(
"Load parameter into net process finish."
)
for
param_name
in
param_not_load
:
logger
.
debug
(
"%s"
,
param_name
)
logger
.
info
(
"Load parameter into net finish, {} parameters has not been loaded."
.
format
(
len
(
param_not_load
)))
def
_load_dismatch_prefix_params
(
net
,
parameter_dict
,
param_not_load
):
"""When some net parameter did not load, try to continue load."""
prefix_name
=
""
longest_name
=
param_not_load
[
0
]
while
prefix_name
!=
longest_name
and
param_not_load
:
logger
.
debug
(
"Count: {} parameters has not been loaded, try to load continue."
.
format
(
len
(
param_not_load
)))
longest_name
=
sorted
(
param_not_load
,
key
=
len
,
reverse
=
True
)[
0
]
prefix_name
=
longest_name
for
net_param_name
in
param_not_load
:
for
dict_name
in
parameter_dict
:
if
dict_name
.
endswith
(
net_param_name
):
tmp_name
=
dict_name
[:
-
len
(
net_param_name
)]
prefix_name
=
prefix_name
if
len
(
prefix_name
)
<
len
(
tmp_name
)
else
tmp_name
if
prefix_name
!=
longest_name
:
logger
.
info
(
"Remove parameter prefix name: {}, continue to load."
.
format
(
prefix_name
))
for
_
,
param
in
net
.
parameters_and_names
():
new_param_name
=
prefix_name
+
param
.
name
if
param
.
name
in
param_not_load
and
new_param_name
in
parameter_dict
:
new_param
=
parameter_dict
[
new_param_name
]
_update_param
(
param
,
new_param
)
param_not_load
.
remove
(
param
.
name
)
def
_save_graph
(
network
,
file_name
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录