Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4936fe48
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4936fe48
编写于
7月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2878 Asynchronous saving checkpoint
Merge pull request !2878 from mindspore_ding/checkpoint_mindspore_new
上级
b64fca6e
d45abc5f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
58 addition
and
29 deletion
+58
-29
mindspore/train/callback/_checkpoint.py
mindspore/train/callback/_checkpoint.py
+11
-7
mindspore/train/serialization.py
mindspore/train/serialization.py
+47
-22
未找到文件。
mindspore/train/callback/_checkpoint.py
浏览文件 @
4936fe48
...
...
@@ -15,7 +15,6 @@
"""Checkpoint related classes and functions."""
import
os
import
shutil
import
stat
import
time
...
...
@@ -86,6 +85,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
Raises:
ValueError: If the input_param is None or 0.
...
...
@@ -100,7 +100,8 @@ class CheckpointConfig:
save_checkpoint_seconds
=
0
,
keep_checkpoint_max
=
5
,
keep_checkpoint_per_n_minutes
=
0
,
integrated_save
=
True
):
integrated_save
=
True
,
async_save
=
False
):
if
not
save_checkpoint_steps
and
not
save_checkpoint_seconds
and
\
not
keep_checkpoint_max
and
not
keep_checkpoint_per_n_minutes
:
...
...
@@ -129,6 +130,7 @@ class CheckpointConfig:
self
.
_keep_checkpoint_max
=
1
self
.
_integrated_save
=
check_bool
(
integrated_save
)
self
.
_async_save
=
check_bool
(
async_save
)
@
property
def
save_checkpoint_steps
(
self
):
...
...
@@ -155,6 +157,11 @@ class CheckpointConfig:
"""Get the value of _integrated_save."""
return
self
.
_integrated_save
@
property
def
async_save
(
self
):
"""Get the value of _async_save."""
return
self
.
_async_save
def
get_checkpoint_policy
(
self
):
"""Get the policy of checkpoint."""
checkpoint_policy
=
{
'save_checkpoint_steps'
:
self
.
_save_checkpoint_steps
,
...
...
@@ -282,8 +289,6 @@ class ModelCheckpoint(Callback):
global
_save_dir
_save_dir
=
self
.
_directory
cur_file
=
os
.
path
.
join
(
self
.
_directory
,
cur_ckpoint_file
)
tmp_ckpt_file_name_for_cur_process
=
str
(
os
.
getpid
())
+
"-"
+
'parameters.ckpt'
gen_file
=
os
.
path
.
join
(
_save_dir
,
tmp_ckpt_file_name_for_cur_process
)
self
.
_last_time_for_keep
=
time
.
time
()
self
.
_last_triggered_step
=
cb_params
.
cur_step_num
...
...
@@ -291,10 +296,9 @@ class ModelCheckpoint(Callback):
set_cur_net
(
cb_params
.
train_network
)
cb_params
.
train_network
.
exec_checkpoint_graph
()
_exec_save_checkpoint
(
cb_params
.
train_network
,
gen_file
,
self
.
_config
.
integrated_save
)
_exec_save_checkpoint
(
cb_params
.
train_network
,
cur_file
,
self
.
_config
.
integrated_save
,
self
.
_config
.
async_save
)
if
os
.
path
.
exists
(
gen_file
):
shutil
.
move
(
gen_file
,
cur_file
)
self
.
_latest_ckpt_file_name
=
cur_file
@
property
...
...
mindspore/train/serialization.py
浏览文件 @
4936fe48
...
...
@@ -15,6 +15,7 @@
"""Model and parameters serialization."""
import
os
import
stat
from
threading
import
Thread
,
Lock
import
numpy
as
np
import
mindspore.nn
as
nn
...
...
@@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32"
:
np
.
int32
,
"Uint32"
:
np
.
uint32
,
"Int64"
:
np
.
int64
,
"Uint64"
:
np
.
uint64
,
"Float16"
:
np
.
float16
,
"Float32"
:
np
.
float32
,
"Float64"
:
np
.
float64
,
"Bool"
:
np
.
bool_
}
_ckpt_mutex
=
Lock
()
def
_special_process_par
(
par
,
new_par
):
"""
...
...
@@ -101,7 +103,29 @@ def _update_param(param, new_param):
param
.
set_parameter_data
(
type
(
param
.
data
)(
new_param
.
data
))
def
save_checkpoint
(
parameter_list
,
ckpt_file_name
):
def
_exec_save
(
ckpt_file_name
,
data_list
):
"""Execute save checkpoint into file process."""
checkpoint_list
=
Checkpoint
()
try
:
with
_ckpt_mutex
:
for
name
,
value
in
data_list
.
items
():
param_value
=
checkpoint_list
.
value
.
add
()
param_value
.
tag
=
name
param_tensor
=
param_value
.
tensor
param_tensor
.
dims
.
extend
(
value
[
0
])
param_tensor
.
tensor_type
=
value
[
1
]
param_tensor
.
tensor_content
=
value
[
2
].
tostring
()
with
open
(
ckpt_file_name
,
"wb"
)
as
f
:
f
.
write
(
checkpoint_list
.
SerializeToString
())
os
.
chmod
(
ckpt_file_name
,
stat
.
S_IRUSR
)
except
BaseException
as
e
:
logger
.
error
(
"Failed to save the checkpoint file %s."
,
ckpt_file_name
)
raise
RuntimeError
(
e
.
__str__
())
def
save_checkpoint
(
parameter_list
,
ckpt_file_name
,
async_save
=
False
):
"""
Saves checkpoint info to a specified file.
...
...
@@ -109,37 +133,37 @@ def save_checkpoint(parameter_list, ckpt_file_name):
parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpt_file_name (str): Checkpoint file name.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
Raises:
RuntimeError: Failed to save the Checkpoint file.
"""
logger
.
info
(
"Execute save checkpoint process."
)
checkpoint_list
=
Checkpoint
()
try
:
data_list
=
{}
with
_ckpt_mutex
:
for
param
in
parameter_list
:
param_value
=
checkpoint_list
.
value
.
add
()
param_value
.
tag
=
param
[
"name"
]
param_tensor
=
param_value
.
tensor
key
=
param
[
"name"
]
data_list
[
key
]
=
[]
if
isinstance
(
param
[
"data"
],
Parameter
):
param
[
"data"
].
init_data
()
param_data
=
param
[
"data"
].
asnumpy
().
reshape
(
-
1
)
param_tensor
.
tensor_content
=
param_data
.
tostring
()
param_tensor
.
tensor_type
=
str
(
param
[
"data"
].
dtype
)
dims
=
[]
if
param
[
'data'
].
shape
==
():
param_tensor
.
dims
.
append
(
0
)
dims
.
append
(
0
)
else
:
for
dim
in
param
[
'data'
].
shape
:
param_tensor
.
dims
.
append
(
dim
)
with
open
(
ckpt_file_name
,
"wb"
)
as
f
:
f
.
write
(
checkpoint_list
.
SerializeToString
())
os
.
chmod
(
ckpt_file_name
,
stat
.
S_IRUSR
)
except
BaseException
as
e
:
logger
.
error
(
"Failed to save the checkpoint file %s."
,
ckpt_file_name
)
raise
RuntimeError
(
e
.
__str__
())
dims
.
append
(
dim
)
data_list
[
key
].
append
(
dims
)
tensor_type
=
str
(
param
[
"data"
].
dtype
)
data_list
[
key
].
append
(
tensor_type
)
data
=
param
[
"data"
].
asnumpy
().
reshape
(
-
1
)
data_list
[
key
].
append
(
data
)
if
async_save
:
thr
=
Thread
(
target
=
_exec_save
,
args
=
(
ckpt_file_name
,
data_list
))
thr
.
start
()
else
:
_exec_save
(
ckpt_file_name
,
data_list
)
logger
.
info
(
"Save checkpoint process finish."
)
...
...
@@ -305,7 +329,7 @@ def _save_graph(network, file_name):
os
.
chmod
(
file_name
,
stat
.
S_IRUSR
)
def
_exec_save_checkpoint
(
train_network
,
ckpt_file_name
,
integrated_save
=
True
):
def
_exec_save_checkpoint
(
train_network
,
ckpt_file_name
,
integrated_save
=
True
,
async_save
=
False
):
"""
Saves checkpoint for 'ms' backend.
...
...
@@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
train_network (Network): The train network for training.
ckpt_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
"""
param_dict
=
{}
...
...
@@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
each_param
[
"data"
]
=
param_data
param_list
.
append
(
each_param
)
save_checkpoint
(
param_list
,
ckpt_file_name
)
save_checkpoint
(
param_list
,
ckpt_file_name
,
async_save
)
def
_get_merged_param_data
(
net
,
param_name
,
param_data
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录