Commit 22c6baee

Authored Apr 02, 2020 by WeibiaoYu

Support configuring whether to save an integrated checkpoint in the auto model parallel scene

Parent: 60f7a95b
Showing 4 changed files with 19 additions and 43 deletions (+19 −43)
mindspore/common/api.py                 +0  −35
mindspore/train/callback.py             +13 −3
mindspore/train/serialization.py        +4  −3
tests/ut/python/utils/test_callback.py  +2  −2
mindspore/common/api.py  (+0 −35)

@@ -374,9 +374,6 @@ class _Executor:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
             obj.load_parameter_slice(params)
 
-        if _get_parallel_mode() in ["hybrid_parallel"]:
-            obj.parameter_layout_dict = self._build_parameter_layout(obj)
-
         # the following GE init process is not needed when use vm or ms backend
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
@@ -449,38 +446,6 @@ class _Executor:
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
 
-    def _build_parameter_layout(self, obj):
-        """
-        Build parameter layout, for layerwise_parallel parameter.
-
-        Args:
-            obj (Function or Cell): The function or cell instance need to be compiled.
-
-        Returns:
-            Dictionary, parameter layout info.
-        """
-        parameter_layout_dict = {}
-        layerwise_parallel_parameters = []
-        for key in obj.parameters_dict():
-            if obj.parameters_dict()[key].layerwise_parallel is True:
-                layerwise_parallel_parameters.append(key)
-
-        if not layerwise_parallel_parameters:
-            return parameter_layout_dict
-
-        from ..communication.management import get_group_size
-        group_size = [get_group_size()]
-        for key in layerwise_parallel_parameters:
-            tensor_map = [0]
-            shape = obj.parameters_dict()[key].data.shape()
-            for x in range(len(shape)):
-                # dim 0 set 0, others set -1
-                if x:
-                    tensor_map.append(-1)
-            layout = [group_size, tensor_map]
-            parameter_layout_dict[key] = layout
-
-        return parameter_layout_dict
-
     def del_net_res(self, net_id):
         self._executor.del_net_res(net_id)
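For context, the removed helper produced layout entries of the form [group_size, tensor_map]: dimension 0 was mapped to device axis 0 and every other dimension was left unsplit (-1). A minimal standalone sketch of that output follows; the device count and parameter shapes are hypothetical, not taken from this commit.

# Sketch of the layout entries the removed _build_parameter_layout produced.
# Device count (8) and shapes are hypothetical examples.
def build_layout(shape, group_size):
    tensor_map = [0] + [-1] * (len(shape) - 1)  # split dim 0 across devices, keep the rest whole
    return [[group_size], tensor_map]

print(build_layout((4, 5), 8))  # [[8], [0, -1]]
print(build_layout((16,), 8))   # [[8], [0]]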
mindspore/train/callback.py  (+13 −3)

@@ -24,7 +24,7 @@ import mindspore.context as context
 from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
 from mindspore.train._utils import _make_directory
 from mindspore import log as logger
-from mindspore._checkparam import check_int_non_negative
+from mindspore._checkparam import check_int_non_negative, check_bool
 from mindspore.common.tensor import Tensor
 from .summary.summary_record import _cache_summary_tensor_data
@@ -150,6 +150,8 @@ class CheckpointConfig:
         keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
         keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
             Can't be used with keep_checkpoint_max at the same time.
+        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene.
+            Default: True. Integrated save is only supported in automatic parallel, not in manual parallel.
 
     Raises:
         ValueError: If the input_param is None or 0.
@@ -163,7 +165,8 @@ class CheckpointConfig:
                  save_checkpoint_steps=1,
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
-                 keep_checkpoint_per_n_minutes=0):
+                 keep_checkpoint_per_n_minutes=0,
+                 integrated_save=True):
 
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -191,6 +194,8 @@ class CheckpointConfig:
         if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
             self._keep_checkpoint_max = 1
 
+        self._integrated_save = check_bool(integrated_save)
+
     @property
     def save_checkpoint_steps(self):
         """Get the value of _save_checkpoint_steps."""
@@ -211,6 +216,11 @@ class CheckpointConfig:
         """Get the value of _keep_checkpoint_per_n_minutes."""
         return self._keep_checkpoint_per_n_minutes
 
+    @property
+    def integrated_save(self):
+        """Get the value of _integrated_save."""
+        return self._integrated_save
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
@@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
                     _set_cur_net(cb_params.train_network)
                     cb_params.train_network.exec_checkpoint_graph()
 
-                _exec_save_checkpoint(cb_params.train_network, gen_file)
+                _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
 
                 if os.path.exists(gen_file):
                     shutil.move(gen_file, cur_file)
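Taken together, these changes expose the new flag through the normal callback path: CheckpointConfig validates and stores it, and ModelCheckpoint forwards it to _exec_save_checkpoint. A minimal usage sketch follows; the model, dataset, and file names are placeholders, not part of this commit.

# Hedged usage sketch: disable integrated saving so each device writes its own
# parameter slices in the auto-parallel scene. Placeholders: model, dataset names.
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config = CheckpointConfig(save_checkpoint_steps=100,
                          keep_checkpoint_max=5,
                          integrated_save=False)  # flag added by this commit
ckpt_cb = ModelCheckpoint(prefix="net", directory="./checkpoints", config=config)
# model.train(epoch_size, train_dataset, callbacks=[ckpt_cb])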
mindspore/train/serialization.py  (+4 −3)

@@ -279,13 +279,14 @@ def _save_graph(network, file_name):
         os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
 
 
-def _exec_save_checkpoint(train_network, ckpoint_file_name):
+def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
     """
     Saves checkpoint for 'ms' backend.
 
     Args:
         train_network (Network): The train network for training.
         ckpoint_file_name (str): The name of checkpoint file.
+        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene.
     """
 
     param_dict = {}
@@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name):
         else:
             param_data = Tensor(value.data)
 
-        # in model parallel scenario, some parameters were spliteds to all the devices,
+        # in automatic model parallel scenario, some parameters were split to all the devices,
         # which should be combined before saving
-        if key in train_network.parameter_layout_dict:
+        if integrated_save and key in train_network.parameter_layout_dict:
             param_data = _get_merged_param_data(train_network, key, param_data)
 
         each_param["data"] = param_data
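Because the new parameter defaults to True, existing callers of _exec_save_checkpoint keep the old merge-before-save behavior unchanged. A short illustration of both call forms; train_network stands for an already-compiled parallel network and the file name is illustrative.

# Illustration only: train_network is a placeholder; _exec_save_checkpoint is
# internal to mindspore.train, as shown in the diff above.
from mindspore.train.serialization import _exec_save_checkpoint

_exec_save_checkpoint(train_network, "net.ckpt")                          # default: merge slices
_exec_save_checkpoint(train_network, "net.ckpt", integrated_save=False)   # keep per-device slices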
tests/ut/python/utils/test_callback.py  (+2 −2)

@@ -308,10 +308,10 @@ def test_RunContext():
 def test_Checkpoint_Config():
     """Test CheckpointConfig all None or 0."""
     with pytest.raises(ValueError):
-        CheckpointConfig(0, 0, 0, 0)
+        CheckpointConfig(0, 0, 0, 0, True)
 
     with pytest.raises(ValueError):
-        CheckpointConfig(0, None, 0, 0)
+        CheckpointConfig(0, None, 0, 0, True)
 
 
 def test_step_end_save_graph():