Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
pingzhuyan
mindspore
提交
4683de34
M
mindspore
项目概览
pingzhuyan
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4683de34
编写于
8月 29, 2020
作者:
L
liuyang_655
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify save_checkpoint
上级
b346f0b3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
49 additions
and
52 deletions
+49
-52
mindspore/train/callback/_checkpoint.py
mindspore/train/callback/_checkpoint.py
+3
-3
mindspore/train/serialization.py
mindspore/train/serialization.py
+33
-37
model_zoo/official/gnn/gat/train.py
model_zoo/official/gnn/gat/train.py
+2
-2
model_zoo/official/nlp/tinybert/src/utils.py
model_zoo/official/nlp/tinybert/src/utils.py
+5
-5
tests/ut/python/utils/test_serialize.py
tests/ut/python/utils/test_serialize.py
+6
-5
未找到文件。
mindspore/train/callback/_checkpoint.py
浏览文件 @
4683de34
...
...
@@ -23,7 +23,7 @@ import mindspore.context as context
from
mindspore
import
log
as
logger
from
mindspore._checkparam
import
check_bool
,
check_int_non_negative
from
mindspore.train._utils
import
_make_directory
from
mindspore.train.serialization
import
_exec_
save_checkpoint
,
_save_graph
from
mindspore.train.serialization
import
save_checkpoint
,
_save_graph
from
._callback
import
Callback
,
set_cur_net
...
...
@@ -306,8 +306,8 @@ class ModelCheckpoint(Callback):
set_cur_net
(
cb_params
.
train_network
)
cb_params
.
train_network
.
exec_checkpoint_graph
()
_exec_
save_checkpoint
(
cb_params
.
train_network
,
cur_file
,
self
.
_config
.
integrated_save
,
self
.
_config
.
async_save
)
save_checkpoint
(
cb_params
.
train_network
,
cur_file
,
self
.
_config
.
integrated_save
,
self
.
_config
.
async_save
)
self
.
_latest_ckpt_file_name
=
cur_file
...
...
mindspore/train/serialization.py
浏览文件 @
4683de34
...
...
@@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list):
raise
RuntimeError
(
e
.
__str__
())
def
save_checkpoint
(
parameter_list
,
ckpt_file_nam
e
,
async_save
=
False
):
def
save_checkpoint
(
save_obj
,
ckpt_file_name
,
integrated_save
=
Tru
e
,
async_save
=
False
):
"""
Saves checkpoint info to a specified file.
Args:
parameter_list (list): Parameters list, each element is a dictionary
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
save_obj (nn.Cell or list): The train network to save, or a parameter list (each element is a dictionary,
like {"name":xx, "type":xx, "shape":xx, "data":xx}).
ckpt_file_name (str): Checkpoint file name.
integrated_save (bool): Whether to merge the split parameters before saving in the automatic model parallel scenario. Default: True.
async_save (bool): Whether to save the checkpoint to a file asynchronously. Default: False.
Raises:
TypeError: If the parameter save_obj is not nn.Cell or list type.
RuntimeError: Failed to save the Checkpoint file.
"""
if
not
isinstance
(
save_obj
,
nn
.
Cell
)
and
not
isinstance
(
save_obj
,
list
):
raise
TypeError
(
"The parameter save_obj should be nn.Cell or list, but got {}"
.
format
(
type
(
save_obj
)))
logger
.
info
(
"Execute save checkpoint process."
)
if
isinstance
(
save_obj
,
nn
.
Cell
):
save_obj
.
init_parameters_data
()
param_dict
=
{}
for
_
,
param
in
save_obj
.
parameters_and_names
():
param_dict
[
param
.
name
]
=
param
param_list
=
[]
for
(
key
,
value
)
in
param_dict
.
items
():
each_param
=
{
"name"
:
key
}
if
isinstance
(
value
.
data
,
Tensor
):
param_data
=
value
.
data
else
:
param_data
=
Tensor
(
value
.
data
)
# in automatic model parallel scenario, some parameters were split across all the devices,
# which should be combined before saving
if
integrated_save
and
key
in
save_obj
.
parameter_layout_dict
:
param_data
=
_get_merged_param_data
(
save_obj
,
key
,
param_data
)
each_param
[
"data"
]
=
param_data
param_list
.
append
(
each_param
)
save_obj
=
param_list
data_list
=
{}
with
_ckpt_mutex
:
for
param
in
parameter_list
:
for
param
in
save_obj
:
key
=
param
[
"name"
]
data_list
[
key
]
=
[]
if
isinstance
(
param
[
"data"
],
Parameter
):
...
...
@@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
thr
.
start
()
else
:
_exec_save
(
ckpt_file_name
,
data_list
)
logger
.
info
(
"Save checkpoint process finish."
)
...
...
@@ -354,39 +383,6 @@ def _save_graph(network, file_name):
os
.
chmod
(
file_name
,
stat
.
S_IRUSR
)
def
_exec_save_checkpoint
(
train_network
,
ckpt_file_name
,
integrated_save
=
True
,
async_save
=
False
):
"""
Saves checkpoint for 'ms' backend.
Args:
train_network (Network): The train network for training.
ckpt_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
"""
train_network
.
init_parameters_data
()
param_dict
=
{}
for
_
,
param
in
train_network
.
parameters_and_names
():
param_dict
[
param
.
name
]
=
param
param_list
=
[]
for
(
key
,
value
)
in
param_dict
.
items
():
each_param
=
{
"name"
:
key
}
if
isinstance
(
value
.
data
,
Tensor
):
param_data
=
value
.
data
else
:
param_data
=
Tensor
(
value
.
data
)
# in automatic model parallel scenario, some parameters were split across all the devices,
# which should be combined before saving
if
integrated_save
and
key
in
train_network
.
parameter_layout_dict
:
param_data
=
_get_merged_param_data
(
train_network
,
key
,
param_data
)
each_param
[
"data"
]
=
param_data
param_list
.
append
(
each_param
)
save_checkpoint
(
param_list
,
ckpt_file_name
,
async_save
)
def
_get_merged_param_data
(
net
,
param_name
,
param_data
):
"""
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
...
...
model_zoo/official/gnn/gat/train.py
浏览文件 @
4683de34
...
...
@@ -18,7 +18,7 @@ import os
import
numpy
as
np
import
mindspore.context
as
context
from
mindspore.train.serialization
import
_exec_
save_checkpoint
,
load_checkpoint
from
mindspore.train.serialization
import
save_checkpoint
,
load_checkpoint
from
src.config
import
GatConfig
from
src.dataset
import
load_and_process
...
...
@@ -98,7 +98,7 @@ def train():
val_loss_model
=
eval_loss
if
os
.
path
.
exists
(
"ckpts/gat.ckpt"
):
os
.
remove
(
"ckpts/gat.ckpt"
)
_exec_
save_checkpoint
(
train_net
.
network
,
"ckpts/gat.ckpt"
)
save_checkpoint
(
train_net
.
network
,
"ckpts/gat.ckpt"
)
val_acc_max
=
np
.
max
((
val_acc_max
,
eval_acc
))
val_loss_min
=
np
.
min
((
val_loss_min
,
eval_loss
))
curr_step
=
0
...
...
model_zoo/official/nlp/tinybert/src/utils.py
浏览文件 @
4683de34
...
...
@@ -20,7 +20,7 @@ import numpy as np
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.train.callback
import
Callback
from
mindspore.train.serialization
import
_exec_
save_checkpoint
from
mindspore.train.serialization
import
save_checkpoint
from
mindspore.ops
import
operations
as
P
from
mindspore.nn.learning_rate_schedule
import
LearningRateSchedule
,
PolynomialDecayLR
,
WarmUpLR
from
.assessment_method
import
Accuracy
...
...
@@ -53,9 +53,9 @@ class ModelSaveCkpt(Callback):
self
.
save_ckpt_step
))
if
os
.
path
.
exists
(
path
):
os
.
remove
(
path
)
_exec_
save_checkpoint
(
self
.
network
,
os
.
path
.
join
(
self
.
output_dir
,
"tiny_bert_{}_{}.ckpt"
.
format
(
int
(
saved_ckpt_num
),
self
.
save_ckpt_step
)))
save_checkpoint
(
self
.
network
,
os
.
path
.
join
(
self
.
output_dir
,
"tiny_bert_{}_{}.ckpt"
.
format
(
int
(
saved_ckpt_num
),
self
.
save_ckpt_step
)))
class
LossCallBack
(
Callback
):
"""
...
...
@@ -113,7 +113,7 @@ class EvalCallBack(Callback):
eval_model_ckpt_file
=
"eval_model.ckpt"
if
os
.
path
.
exists
(
eval_model_ckpt_file
):
os
.
remove
(
eval_model_ckpt_file
)
_exec_
save_checkpoint
(
self
.
network
,
eval_model_ckpt_file
)
save_checkpoint
(
self
.
network
,
eval_model_ckpt_file
)
class
BertLearningRate
(
LearningRateSchedule
):
"""
...
...
tests/ut/python/utils/test_serialize.py
浏览文件 @
4683de34
...
...
@@ -31,7 +31,7 @@ from mindspore.nn.optim.momentum import Momentum
from
mindspore.ops
import
operations
as
P
from
mindspore.train.callback
import
_CheckpointManager
from
mindspore.train.serialization
import
save_checkpoint
,
load_checkpoint
,
load_param_into_net
,
\
_exec_save_checkpoint
,
export
,
_save_graph
export
,
_save_graph
from
..ut_filter
import
non_graph_engine
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
print_file_path
=
"print/print.pb"
)
...
...
@@ -95,8 +95,8 @@ def test_save_graph():
os
.
remove
(
output_file
)
def
test_save_checkpoint
():
""" test
_save_checkpoint
"""
def
test_save_checkpoint
_for_list
():
""" test
save_checkpoint for list
"""
parameter_list
=
[]
one_param
=
{}
param1
=
{}
...
...
@@ -280,14 +280,15 @@ def test_load_param_into_net():
assert
net
.
conv1
.
weight
.
default_input
.
asnumpy
()[
0
][
0
][
0
][
0
]
==
1
def
test_exec_save_checkpoint
():
def
test_save_checkpoint_for_network
():
""" test save_checkpoint for network"""
net
=
Net
()
loss
=
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
)
opt
=
Momentum
(
net
.
trainable_params
(),
0.0
,
0.9
,
0.0001
,
1024
)
loss_net
=
WithLossCell
(
net
,
loss
)
train_network
=
TrainOneStepCell
(
loss_net
,
opt
)
_exec_
save_checkpoint
(
train_network
,
ckpt_file_name
=
"./new_ckpt.ckpt"
)
save_checkpoint
(
train_network
,
ckpt_file_name
=
"./new_ckpt.ckpt"
)
load_checkpoint
(
"new_ckpt.ckpt"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录