Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
087779b7
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
087779b7
编写于
6月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2517 checkpoint add model_type
Merge pull request !2517 from chenzhongming/quant
上级
0a368494
d3f9b800
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
132 addition
and
86 deletion
+132
-86
mindspore/_checkparam.py
mindspore/_checkparam.py
+11
-0
mindspore/ccsrc/utils/checkpoint.proto
mindspore/ccsrc/utils/checkpoint.proto
+1
-0
mindspore/train/callback/_checkpoint.py
mindspore/train/callback/_checkpoint.py
+19
-9
mindspore/train/callback/_loss_monitor.py
mindspore/train/callback/_loss_monitor.py
+2
-2
mindspore/train/serialization.py
mindspore/train/serialization.py
+31
-19
model_zoo/lenet/eval.py
model_zoo/lenet/eval.py
+3
-8
model_zoo/lenet_quant/README.md
model_zoo/lenet_quant/README.md
+7
-7
model_zoo/lenet_quant/eval.py
model_zoo/lenet_quant/eval.py
+7
-6
model_zoo/lenet_quant/eval_quant.py
model_zoo/lenet_quant/eval_quant.py
+7
-7
model_zoo/lenet_quant/src/lenet.py
model_zoo/lenet_quant/src/lenet.py
+2
-2
model_zoo/lenet_quant/src/lenet_fusion.py
model_zoo/lenet_quant/src/lenet_fusion.py
+3
-2
model_zoo/lenet_quant/train.py
model_zoo/lenet_quant/train.py
+12
-4
model_zoo/lenet_quant/train_quant.py
model_zoo/lenet_quant/train_quant.py
+14
-7
tests/ut/python/predict/test_predict_save_model.py
tests/ut/python/predict/test_predict_save_model.py
+1
-1
tests/ut/python/utils/test_serialize.py
tests/ut/python/utils/test_serialize.py
+12
-12
未找到文件。
mindspore/_checkparam.py
浏览文件 @
087779b7
...
...
@@ -593,6 +593,17 @@ def check_bool(input_param):
raise
TypeError
(
"Input type must be bool!"
)
def
check_string
(
input_param
,
valid_values
):
"""String type judgment."""
if
isinstance
(
input_param
,
str
)
and
input_param
in
valid_values
:
return
input_param
if
len
(
valid_values
)
==
1
:
raise
ValueError
(
f
'Input should be str and must be
{
valid_values
[
0
]
}
,'
f
' but got
{
input_param
}
.'
)
raise
ValueError
(
f
'Input should be str and must be one of
{
valid_values
}
,'
f
' but got
{
input_param
}
.'
)
def
check_input_format
(
input_param
):
"""Judge input format."""
if
input_param
==
"NCHW"
:
...
...
mindspore/ccsrc/utils/checkpoint.proto
浏览文件 @
087779b7
...
...
@@ -22,6 +22,7 @@ message Checkpoint {
required
TensorProto
tensor
=
2
;
}
repeated
Value
value
=
1
;
required
string
model_type
=
2
;
}
...
...
mindspore/train/callback/_checkpoint.py
浏览文件 @
087779b7
...
...
@@ -21,17 +21,16 @@ import time
import
mindspore.context
as
context
from
mindspore
import
log
as
logger
from
mindspore._checkparam
import
check_bool
,
check_int_non_negative
from
mindspore._checkparam
import
check_bool
,
check_
string
,
check_
int_non_negative
from
mindspore.train._utils
import
_make_directory
from
mindspore.train.serialization
import
_exec_save_checkpoint
,
_save_graph
from
._callback
import
Callback
,
set_cur_net
_cur_dir
=
os
.
getcwd
()
_save_dir
=
_cur_dir
def
_check_file_name_prefix
(
file_name_prefix
):
"""
Check file name valid or not.
...
...
@@ -87,6 +86,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
Raises:
ValueError: If the input_param is None or 0.
...
...
@@ -101,7 +101,8 @@ class CheckpointConfig:
save_checkpoint_seconds
=
0
,
keep_checkpoint_max
=
5
,
keep_checkpoint_per_n_minutes
=
0
,
integrated_save
=
True
):
integrated_save
=
True
,
model_type
=
"normal"
):
if
not
save_checkpoint_steps
and
not
save_checkpoint_seconds
and
\
not
keep_checkpoint_max
and
not
keep_checkpoint_per_n_minutes
:
...
...
@@ -115,6 +116,8 @@ class CheckpointConfig:
keep_checkpoint_max
=
check_int_non_negative
(
keep_checkpoint_max
)
if
keep_checkpoint_per_n_minutes
:
keep_checkpoint_per_n_minutes
=
check_int_non_negative
(
keep_checkpoint_per_n_minutes
)
if
model_type
:
model_type
=
check_string
(
model_type
,
[
"normal"
,
"fusion"
,
"quant"
])
self
.
_save_checkpoint_steps
=
save_checkpoint_steps
self
.
_save_checkpoint_seconds
=
save_checkpoint_seconds
...
...
@@ -129,6 +132,7 @@ class CheckpointConfig:
if
not
self
.
_keep_checkpoint_per_n_minutes
or
self
.
_keep_checkpoint_per_n_minutes
==
0
:
self
.
_keep_checkpoint_max
=
1
self
.
_model_type
=
model_type
self
.
_integrated_save
=
check_bool
(
integrated_save
)
@
property
...
...
@@ -156,12 +160,18 @@ class CheckpointConfig:
"""Get the value of _integrated_save."""
return
self
.
_integrated_save
@
property
def
model_type
(
self
):
"""Get the value of model_type."""
return
self
.
_model_type
def
get_checkpoint_policy
(
self
):
"""Get the policy of checkpoint."""
checkpoint_policy
=
{
'save_checkpoint_steps'
:
self
.
_save_checkpoint_steps
,
'save_checkpoint_seconds'
:
self
.
_save_checkpoint_seconds
,
'keep_checkpoint_max'
:
self
.
_keep_checkpoint_max
,
'keep_checkpoint_per_n_minutes'
:
self
.
_keep_checkpoint_per_n_minutes
}
'keep_checkpoint_per_n_minutes'
:
self
.
_keep_checkpoint_per_n_minutes
,
'model_type'
:
self
.
_model_type
}
return
checkpoint_policy
...
...
@@ -226,7 +236,7 @@ class ModelCheckpoint(Callback):
graph_file_name
=
os
.
path
.
join
(
self
.
_directory
,
self
.
_prefix
+
'-graph.meta'
)
_save_graph
(
cb_params
.
train_network
,
graph_file_name
)
self
.
_graph_saved
=
True
self
.
_save_ckpt
(
cb_params
)
self
.
_save_ckpt
(
cb_params
,
self
.
_config
.
model_type
)
def
end
(
self
,
run_context
):
"""
...
...
@@ -237,7 +247,7 @@ class ModelCheckpoint(Callback):
"""
cb_params
=
run_context
.
original_args
()
_to_save_last_ckpt
=
True
self
.
_save_ckpt
(
cb_params
,
_to_save_last_ckpt
)
self
.
_save_ckpt
(
cb_params
,
self
.
_config
.
model_type
,
_to_save_last_ckpt
)
from
mindspore.parallel._cell_wrapper
import
destroy_allgather_cell
destroy_allgather_cell
()
...
...
@@ -256,7 +266,7 @@ class ModelCheckpoint(Callback):
return
False
def
_save_ckpt
(
self
,
cb_params
,
force_to_save
=
False
):
def
_save_ckpt
(
self
,
cb_params
,
model_type
,
force_to_save
=
False
):
"""Save checkpoint files."""
if
cb_params
.
cur_step_num
==
self
.
_last_triggered_step
:
return
...
...
@@ -292,7 +302,7 @@ class ModelCheckpoint(Callback):
set_cur_net
(
cb_params
.
train_network
)
cb_params
.
train_network
.
exec_checkpoint_graph
()
_exec_save_checkpoint
(
cb_params
.
train_network
,
gen_file
,
self
.
_config
.
integrated_save
)
_exec_save_checkpoint
(
cb_params
.
train_network
,
gen_file
,
model_type
,
self
.
_config
.
integrated_save
)
if
os
.
path
.
exists
(
gen_file
):
shutil
.
move
(
gen_file
,
cur_file
)
...
...
mindspore/train/callback/_loss_monitor.py
浏览文件 @
087779b7
...
...
@@ -76,7 +76,7 @@ class LossMonitor(Callback):
step_loss
=
np
.
mean
(
step_loss
.
asnumpy
())
self
.
losses
.
append
(
step_loss
)
cur_step_in_epoch
=
int
((
cb_params
.
cur_step_num
-
1
)
%
cb_params
.
batch_num
)
cur_step_in_epoch
=
int
((
cb_params
.
cur_step_num
-
1
)
%
cb_params
.
batch_num
)
+
1
if
isinstance
(
step_loss
,
float
)
and
(
np
.
isnan
(
step_loss
)
or
np
.
isinf
(
step_loss
)):
raise
ValueError
(
"Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
...
...
@@ -87,7 +87,7 @@ class LossMonitor(Callback):
if
self
.
_per_print_times
!=
0
and
cb_params
.
cur_step_num
%
self
.
_per_print_times
==
0
:
print
(
"Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]"
.
format
(
cb_params
.
cur_epoch_num
-
1
,
cb_params
.
epoch_num
,
cb_params
.
cur_epoch_num
,
cb_params
.
epoch_num
,
cur_step_in_epoch
,
int
(
cb_params
.
batch_num
),
step_loss
,
np
.
mean
(
self
.
losses
),
step_mseconds
),
flush
=
True
)
mindspore/train/serialization.py
浏览文件 @
087779b7
...
...
@@ -29,6 +29,7 @@ from mindspore.common.api import _executor
from
mindspore.common
import
dtype
as
mstype
from
mindspore._checkparam
import
check_input_data
__all__
=
[
"save_checkpoint"
,
"load_checkpoint"
,
"load_param_into_net"
,
"export"
]
tensor_to_ms_type
=
{
"Int8"
:
mstype
.
int8
,
"Uint8"
:
mstype
.
uint8
,
"Int16"
:
mstype
.
int16
,
"Uint16"
:
mstype
.
uint16
,
...
...
@@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32"
:
np
.
int32
,
"Uint32"
:
np
.
uint32
,
"Int64"
:
np
.
int64
,
"Uint64"
:
np
.
uint64
,
"Float16"
:
np
.
float16
,
"Float32"
:
np
.
float32
,
"Float64"
:
np
.
float64
,
"Bool"
:
np
.
bool_
}
ModelType
=
[
"normal"
,
"fusion"
,
"quant"
]
def
_special_process_par
(
par
,
new_par
):
"""
...
...
@@ -101,20 +104,22 @@ def _update_param(param, new_param):
param
.
set_parameter_data
(
type
(
param
.
data
)(
new_param
.
data
))
def
save_checkpoint
(
parameter_list
,
ckp
oint_file_name
):
def
save_checkpoint
(
parameter_list
,
ckp
t_file_name
,
model_type
=
"normal"
):
"""
Saves checkpoint info to a specified file.
Args:
parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type. Default: "normal".
Raises:
RuntimeError: Failed to save the Checkpoint file.
"""
logger
.
info
(
"Execute save checkpoint process."
)
checkpoint_list
=
Checkpoint
()
checkpoint_list
.
model_type
=
model_type
try
:
for
param
in
parameter_list
:
...
...
@@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
for
dim
in
param
[
'data'
].
shape
:
param_tensor
.
dims
.
append
(
dim
)
with
open
(
ckp
oin
t_file_name
,
"wb"
)
as
f
:
with
open
(
ckpt_file_name
,
"wb"
)
as
f
:
f
.
write
(
checkpoint_list
.
SerializeToString
())
os
.
chmod
(
ckp
oin
t_file_name
,
stat
.
S_IRUSR
)
os
.
chmod
(
ckpt_file_name
,
stat
.
S_IRUSR
)
except
BaseException
as
e
:
logger
.
error
(
"Failed to save the checkpoint file %s."
,
ckp
oin
t_file_name
)
logger
.
error
(
"Failed to save the checkpoint file %s."
,
ckpt_file_name
)
raise
RuntimeError
(
e
.
__str__
())
logger
.
info
(
"Save checkpoint process finish."
)
def
load_checkpoint
(
ckp
oint_file_name
,
net
=
None
):
def
load_checkpoint
(
ckp
t_file_name
,
model_type
=
"normal"
,
net
=
None
):
"""
Loads checkpoint info from a specified file.
Args:
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
net (Cell): Cell network. Default: None
Returns:
...
...
@@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
Raises:
ValueError: Checkpoint file is incorrect.
"""
if
not
isinstance
(
ckp
oin
t_file_name
,
str
):
raise
ValueError
(
"The ckp
oint_file_name must be S
tring."
)
if
not
isinstance
(
ckpt_file_name
,
str
):
raise
ValueError
(
"The ckp
t_file_name must be s
tring."
)
if
not
os
.
path
.
exists
(
ckpoint_file_name
)
or
ckpoint_file_name
[
-
5
:]
!=
".ckpt"
:
if
model_type
not
in
ModelType
:
raise
ValueError
(
f
"The model_type is not in
{
ModelType
}
."
)
if
not
os
.
path
.
exists
(
ckpt_file_name
)
or
ckpt_file_name
[
-
5
:]
!=
".ckpt"
:
raise
ValueError
(
"Please input the correct checkpoint file name."
)
if
os
.
path
.
getsize
(
ckp
oin
t_file_name
)
==
0
:
if
os
.
path
.
getsize
(
ckpt_file_name
)
==
0
:
raise
ValueError
(
"The checkpoint file may be empty, please make sure enter the correct file name."
)
logger
.
info
(
"Execute load checkpoint process."
)
checkpoint_list
=
Checkpoint
()
try
:
with
open
(
ckp
oin
t_file_name
,
"rb"
)
as
f
:
with
open
(
ckpt_file_name
,
"rb"
)
as
f
:
pb_content
=
f
.
read
()
checkpoint_list
.
ParseFromString
(
pb_content
)
except
BaseException
as
e
:
logger
.
error
(
"Failed to read the checkpoint file
%s, please check the correct of the file."
,
ckpoin
t_file_name
)
logger
.
error
(
"Failed to read the checkpoint file
`%s`, please check the correct of the file."
,
ckp
t_file_name
)
raise
ValueError
(
e
.
__str__
())
parameter_dict
=
{}
if
model_type
!=
checkpoint_list
.
model_type
:
raise
KeyError
(
"Checkpoint file model type({}) is not equal to input model type({})."
.
format
(
checkpoint_list
.
model_type
,
model_type
))
try
:
for
element
in
checkpoint_list
.
value
:
data
=
element
.
tensor
.
tensor_content
...
...
@@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
logger
.
info
(
"Load checkpoint process finish."
)
except
BaseException
as
e
:
logger
.
error
(
"Failed to load the checkpoint file
%s."
,
ckpoin
t_file_name
)
logger
.
error
(
"Failed to load the checkpoint file
`%s`."
,
ckp
t_file_name
)
raise
RuntimeError
(
e
.
__str__
())
if
net
:
...
...
@@ -303,14 +314,15 @@ def _save_graph(network, file_name):
os
.
chmod
(
file_name
,
stat
.
S_IWUSR
|
stat
.
S_IRUSR
)
def
_exec_save_checkpoint
(
train_network
,
ckp
oint_file_name
,
integrated_save
=
True
):
def
_exec_save_checkpoint
(
train_network
,
ckp
t_file_name
,
model_type
=
"normal"
,
integrated_save
=
True
):
"""
Saves checkpoint for 'ms' backend.
Args:
train_network (Network): The train network for training.
ckpoint_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
ckpt_file_name (str): The name of checkpoint file.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
"""
param_dict
=
{}
...
...
@@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
each_param
[
"data"
]
=
param_data
param_list
.
append
(
each_param
)
save_checkpoint
(
param_list
,
ckp
oint_file_nam
e
)
save_checkpoint
(
param_list
,
ckp
t_file_name
,
model_typ
e
)
def
_get_merged_param_data
(
net
,
param_name
,
param_data
):
...
...
model_zoo/lenet/eval.py
浏览文件 @
087779b7
...
...
@@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
import
os
import
argparse
from
src.dataset
import
create_dataset
from
src.config
import
mnist_cfg
as
cfg
from
src.lenet
import
LeNet5
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
from
mindspore.train
import
Model
from
mindspore.nn.metrics
import
Accuracy
from
src.dataset
import
create_dataset
from
src.config
import
mnist_cfg
as
cfg
from
src.lenet
import
LeNet5
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
'MindSpore Lenet Example'
)
...
...
@@ -49,9 +47,6 @@ if __name__ == "__main__":
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
repeat_size
=
cfg
.
epoch_size
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
cfg
.
lr
,
cfg
.
momentum
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
save_checkpoint_steps
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
config
=
config_ck
)
model
=
Model
(
network
,
net_loss
,
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
print
(
"============== Starting Testing =============="
)
...
...
model_zoo/lenet_quant/README.md
浏览文件 @
087779b7
...
...
@@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following:
```
bash
>>>
Epoch:
[
1/ 10] step:
[
1/ 900], loss:
[
2.3040/2.5234],
time
:
[
1.300234]
>>>
...
>>>
Epoch:
[
10
/ 10] step:
[
887/ 900], loss:
[
0.0113/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
10
/ 10] step:
[
888/ 900], loss:
[
0.0334/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
10
/ 10] step:
[
889/ 900], loss:
[
0.0233/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
9
/ 10] step:
[
887/ 900], loss:
[
0.0113/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
9
/ 10] step:
[
888/ 900], loss:
[
0.0334/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
9
/ 10] step:
[
889/ 900], loss:
[
0.0233/0.0223],
time
:
[
1.300234]
```
Also, you can just run this command instead.
...
...
@@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following:
```
bash
>>>
Epoch:
[
1/ 10] step:
[
1/ 900], loss:
[
2.3040/2.5234],
time
:
[
1.300234]
>>>
...
>>>
Epoch:
[
10
/ 10] step:
[
887/ 900], loss:
[
0.0113/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
10
/ 10] step:
[
888/ 900], loss:
[
0.0334/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
10
/ 10] step:
[
889/ 900], loss:
[
0.0233/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
9
/ 10] step:
[
887/ 900], loss:
[
0.0113/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
9
/ 10] step:
[
888/ 900], loss:
[
0.0334/0.0223],
time
:
[
1.300234]
>>>
Epoch:
[
9
/ 10] step:
[
889/ 900], loss:
[
0.0233/0.0223],
time
:
[
1.300234]
```
### Evaluate quantization aware model
...
...
@@ -215,7 +215,7 @@ param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net
(
network
,
param_dict
)
# convert funsion netwrok to quantization aware network
network
=
quant
.
convert_quant_network
(
network
network
=
quant
.
convert_quant_network
(
network
)
```
Also, you can just run this command insread.
...
...
model_zoo/lenet_quant/eval.py
浏览文件 @
087779b7
...
...
@@ -23,7 +23,6 @@ import argparse
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
from
mindspore.train
import
Model
from
mindspore.nn.metrics
import
Accuracy
from
src.dataset
import
create_dataset
...
...
@@ -47,16 +46,18 @@ if __name__ == "__main__":
ds_eval
=
create_dataset
(
os
.
path
.
join
(
args
.
data_path
,
"test"
),
cfg
.
batch_size
,
1
)
step_size
=
ds_eval
.
get_dataset_size
()
# define fusion network
network
=
LeNet5Fusion
(
cfg
.
num_classes
)
# define loss
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
repeat_size
=
cfg
.
epoch_size
# define network optimization
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
cfg
.
lr
,
cfg
.
momentum
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
epoch_size
*
step_size
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
config
=
config_ck
)
# call back and monitor
model
=
Model
(
network
,
net_loss
,
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
param_dict
=
load_checkpoint
(
args
.
ckpt_path
)
# load check point into network
param_dict
=
load_checkpoint
(
args
.
ckpt_path
,
network
.
type
)
load_param_into_net
(
network
,
param_dict
)
print
(
"============== Starting Testing =============="
)
...
...
model_zoo/lenet_quant/eval_quant.py
浏览文件 @
087779b7
...
...
@@ -23,7 +23,6 @@ import argparse
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
from
mindspore.train
import
Model
from
mindspore.nn.metrics
import
Accuracy
from
mindspore.train.quant
import
quant
...
...
@@ -48,20 +47,21 @@ if __name__ == "__main__":
ds_eval
=
create_dataset
(
os
.
path
.
join
(
args
.
data_path
,
"test"
),
cfg
.
batch_size
,
1
)
step_size
=
ds_eval
.
get_dataset_size
()
# define fu
n
sion network
# define fusion network
network
=
LeNet5Fusion
(
cfg
.
num_classes
)
# convert fu
n
sion netwrok to quantization aware network
# convert fusion netwrok to quantization aware network
network
=
quant
.
convert_quant_network
(
network
,
quant_delay
=
0
,
bn_fold
=
False
,
freeze_bn
=
10000
)
# define loss
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
# define network optimization
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
cfg
.
lr
,
cfg
.
momentum
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
epoch_size
*
step_size
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
config
=
config_ck
)
# call back and monitor
model
=
Model
(
network
,
net_loss
,
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
# load quantization aware network checkpoint
param_dict
=
load_checkpoint
(
args
.
ckpt_path
)
param_dict
=
load_checkpoint
(
args
.
ckpt_path
,
model_type
=
"quant"
)
load_param_into_net
(
network
,
param_dict
)
print
(
"============== Starting Testing =============="
)
...
...
model_zoo/lenet_quant/src/lenet.py
浏览文件 @
087779b7
...
...
@@ -34,8 +34,8 @@ class LeNet5(nn.Cell):
super
(
LeNet5
,
self
).
__init__
()
self
.
num_class
=
num_class
self
.
conv1
=
nn
.
Conv2d
(
channel
,
6
,
5
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
)
self
.
conv1
=
nn
.
Conv2d
(
channel
,
6
,
5
,
pad_mode
=
'valid'
)
self
.
conv2
=
nn
.
Conv2d
(
6
,
16
,
5
,
pad_mode
=
'valid'
)
self
.
fc1
=
nn
.
Dense
(
16
*
5
*
5
,
120
)
self
.
fc2
=
nn
.
Dense
(
120
,
84
)
self
.
fc3
=
nn
.
Dense
(
84
,
self
.
num_class
)
...
...
model_zoo/lenet_quant/src/lenet_fusion.py
浏览文件 @
087779b7
...
...
@@ -32,11 +32,12 @@ class LeNet5(nn.Cell):
def
__init__
(
self
,
num_class
=
10
,
channel
=
1
):
super
(
LeNet5
,
self
).
__init__
()
self
.
type
=
"fusion"
self
.
num_class
=
num_class
# change `nn.Conv2d` to `nn.Conv2dBnAct`
self
.
conv1
=
nn
.
Conv2dBnAct
(
channel
,
6
,
5
,
activation
=
'relu'
)
self
.
conv2
=
nn
.
Conv2dBnAct
(
6
,
16
,
5
,
activation
=
'relu'
)
self
.
conv1
=
nn
.
Conv2dBnAct
(
channel
,
6
,
5
,
pad_mode
=
'valid'
,
activation
=
'relu'
)
self
.
conv2
=
nn
.
Conv2dBnAct
(
6
,
16
,
5
,
pad_mode
=
'valid'
,
activation
=
'relu'
)
# change `nn.Dense` to `nn.DenseBnAct`
self
.
fc1
=
nn
.
DenseBnAct
(
16
*
5
*
5
,
120
,
activation
=
'relu'
)
self
.
fc2
=
nn
.
DenseBnAct
(
120
,
84
,
activation
=
'relu'
)
...
...
model_zoo/lenet_quant/train.py
浏览文件 @
087779b7
...
...
@@ -46,16 +46,24 @@ if __name__ == "__main__":
ds_train
=
create_dataset
(
os
.
path
.
join
(
args
.
data_path
,
"train"
),
cfg
.
batch_size
,
cfg
.
epoch_size
)
step_size
=
ds_train
.
get_dataset_size
()
# define fusion network
network
=
LeNet5Fusion
(
cfg
.
num_classes
)
# define network loss
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
# define network optimization
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
cfg
.
lr
,
cfg
.
momentum
)
# call back and monitor
time_cb
=
TimeMonitor
(
data_size
=
ds_train
.
get_dataset_size
())
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
epoch_size
*
step_size
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
config
=
config_ck
)
config_ckpt
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
epoch_size
*
step_size
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
,
model_type
=
network
.
type
)
ckpt_callback
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
config
=
config_ckpt
)
# define model
model
=
Model
(
network
,
net_loss
,
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
print
(
"============== Starting Training =============="
)
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
time_cb
,
ckp
oint_cb
,
LossMonitor
()],
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
time_cb
,
ckp
t_callback
,
LossMonitor
()],
dataset_sink_mode
=
args
.
dataset_sink_mode
)
print
(
"============== End Training =============="
)
model_zoo/lenet_quant/train_quant.py
浏览文件 @
087779b7
...
...
@@ -48,23 +48,30 @@ if __name__ == "__main__":
ds_train
=
create_dataset
(
os
.
path
.
join
(
args
.
data_path
,
"train"
),
cfg
.
batch_size
,
cfg
.
epoch_size
)
step_size
=
ds_train
.
get_dataset_size
()
# define fu
n
sion network
# define fusion network
network
=
LeNet5Fusion
(
cfg
.
num_classes
)
# load quantization aware network checkpoint
param_dict
=
load_checkpoint
(
args
.
ckpt_path
)
param_dict
=
load_checkpoint
(
args
.
ckpt_path
,
network
.
type
)
load_param_into_net
(
network
,
param_dict
)
# convert fu
nsion netwro
k to quantization aware network
# convert fu
sion networ
k to quantization aware network
network
=
quant
.
convert_quant_network
(
network
,
quant_delay
=
0
,
bn_fold
=
False
,
freeze_bn
=
10000
)
# define network loss
net_loss
=
nn
.
SoftmaxCrossEntropyWithLogits
(
is_grad
=
False
,
sparse
=
True
,
reduction
=
"mean"
)
# define network optimization
net_opt
=
nn
.
Momentum
(
network
.
trainable_params
(),
cfg
.
lr
,
cfg
.
momentum
)
# call back and monitor
time_cb
=
TimeMonitor
(
data_size
=
ds_train
.
get_dataset_size
())
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
epoch_size
*
step_size
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
config
=
config_ck
)
config_ckpt
=
CheckpointConfig
(
save_checkpoint_steps
=
cfg
.
epoch_size
*
step_size
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
,
model_type
=
"quant"
)
ckpt_callback
=
ModelCheckpoint
(
prefix
=
"checkpoint_lenet"
,
config
=
config_ckpt
)
# define model
model
=
Model
(
network
,
net_loss
,
net_opt
,
metrics
=
{
"Accuracy"
:
Accuracy
()})
print
(
"============== Starting Training =============="
)
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
time_cb
,
ckp
oint_cb
,
LossMonitor
()],
model
.
train
(
cfg
[
'epoch_size'
],
ds_train
,
callbacks
=
[
time_cb
,
ckp
t_callback
,
LossMonitor
()],
dataset_sink_mode
=
args
.
dataset_sink_mode
)
print
(
"============== End Training =============="
)
tests/ut/python/predict/test_predict_save_model.py
浏览文件 @
087779b7
...
...
@@ -85,7 +85,7 @@ if __name__ == '__main__':
is_ckpt_exist
=
os
.
path
.
exists
(
ckpt_file_path
)
if
is_ckpt_exist
:
param_dict
=
load_checkpoint
(
ckp
oin
t_file_name
=
ckpt_file_path
)
param_dict
=
load_checkpoint
(
ckpt_file_name
=
ckpt_file_path
)
load_param_into_net
(
net
,
param_dict
)
export
(
net
,
input_data
,
file_name
=
model_path_name
,
file_format
=
'LITE'
)
print
(
"test lenet predict success."
)
...
...
tests/ut/python/utils/test_serialize.py
浏览文件 @
087779b7
...
...
@@ -111,19 +111,19 @@ def test_save_checkpoint():
os
.
chmod
(
'./parameters.ckpt'
,
stat
.
S_IWRITE
)
os
.
remove
(
'./parameters.ckpt'
)
ckp
oin
t_file_name
=
os
.
path
.
join
(
_cur_dir
,
'./parameters.ckpt'
)
save_checkpoint
(
parameter_list
,
ckp
oin
t_file_name
)
ckpt_file_name
=
os
.
path
.
join
(
_cur_dir
,
'./parameters.ckpt'
)
save_checkpoint
(
parameter_list
,
ckpt_file_name
)
def
test_load_checkpoint_error_filename
():
ckp
oin
t_file_name
=
1
ckpt_file_name
=
1
with
pytest
.
raises
(
ValueError
):
load_checkpoint
(
ckp
oin
t_file_name
)
load_checkpoint
(
ckpt_file_name
)
def
test_load_checkpoint
():
ckp
oin
t_file_name
=
os
.
path
.
join
(
_cur_dir
,
'./parameters.ckpt'
)
par_dict
=
load_checkpoint
(
ckp
oin
t_file_name
)
ckpt_file_name
=
os
.
path
.
join
(
_cur_dir
,
'./parameters.ckpt'
)
par_dict
=
load_checkpoint
(
ckpt_file_name
)
assert
len
(
par_dict
)
==
3
assert
par_dict
[
'param_test'
].
name
==
'param_test'
...
...
@@ -136,17 +136,17 @@ def test_checkpoint_manager():
""" test_checkpoint_manager """
ckp_mgr
=
_CheckpointManager
()
ckp
oin
t_file_name
=
os
.
path
.
join
(
_cur_dir
,
'./test1.ckpt'
)
with
open
(
ckp
oin
t_file_name
,
'w'
):
os
.
chmod
(
ckp
oin
t_file_name
,
stat
.
S_IWUSR
|
stat
.
S_IRUSR
)
ckpt_file_name
=
os
.
path
.
join
(
_cur_dir
,
'./test1.ckpt'
)
with
open
(
ckpt_file_name
,
'w'
):
os
.
chmod
(
ckpt_file_name
,
stat
.
S_IWUSR
|
stat
.
S_IRUSR
)
ckp_mgr
.
update_ckpoint_filelist
(
_cur_dir
,
"test"
)
assert
ckp_mgr
.
ckpoint_num
==
1
ckp_mgr
.
remove_ckpoint_file
(
ckp
oin
t_file_name
)
ckp_mgr
.
remove_ckpoint_file
(
ckpt_file_name
)
ckp_mgr
.
update_ckpoint_filelist
(
_cur_dir
,
"test"
)
assert
ckp_mgr
.
ckpoint_num
==
0
assert
not
os
.
path
.
exists
(
ckp
oin
t_file_name
)
assert
not
os
.
path
.
exists
(
ckpt_file_name
)
another_file_name
=
os
.
path
.
join
(
_cur_dir
,
'./test2.ckpt'
)
another_file_name
=
os
.
path
.
realpath
(
another_file_name
)
...
...
@@ -283,7 +283,7 @@ def test_exec_save_checkpoint():
loss_net
=
WithLossCell
(
net
,
loss
)
train_network
=
TrainOneStepCell
(
loss_net
,
opt
)
_exec_save_checkpoint
(
train_network
,
ckp
oin
t_file_name
=
"./new_ckpt.ckpt"
)
_exec_save_checkpoint
(
train_network
,
ckpt_file_name
=
"./new_ckpt.ckpt"
)
load_checkpoint
(
"new_ckpt.ckpt"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录