Commit 087779b7 (magicwindyyd/mindspore, forked from MindSpore/mindspore)
Authored June 24, 2020 by mindspore-ci-bot; committed via Gitee on June 24, 2020.
!2517 checkpoint add model_type
Merge pull request !2517 from chenzhongming/quant
Parents: 0a368494, d3f9b800
Showing 15 changed files with 132 additions and 86 deletions (+132, -86).
mindspore/_checkparam.py                             +11   -0
mindspore/ccsrc/utils/checkpoint.proto                +1   -0
mindspore/train/callback/_checkpoint.py              +19   -9
mindspore/train/callback/_loss_monitor.py             +2   -2
mindspore/train/serialization.py                     +31  -19
model_zoo/lenet/eval.py                               +3   -8
model_zoo/lenet_quant/README.md                       +7   -7
model_zoo/lenet_quant/eval.py                         +7   -6
model_zoo/lenet_quant/eval_quant.py                   +7   -7
model_zoo/lenet_quant/src/lenet.py                    +2   -2
model_zoo/lenet_quant/src/lenet_fusion.py             +3   -2
model_zoo/lenet_quant/train.py                       +12   -4
model_zoo/lenet_quant/train_quant.py                 +14   -7
tests/ut/python/predict/test_predict_save_model.py    +1   -1
tests/ut/python/utils/test_serialize.py              +12  -12

mindspore/_checkparam.py

@@ -593,6 +593,17 @@ def check_bool(input_param):
     raise TypeError("Input type must be bool!")
 
 
+def check_string(input_param, valid_values):
+    """String type judgment."""
+    if isinstance(input_param, str) and input_param in valid_values:
+        return input_param
+    if len(valid_values) == 1:
+        raise ValueError(f'Input should be str and must be {valid_values[0]},'
+                         f' but got {input_param}.')
+    raise ValueError(f'Input should be str and must be one of {valid_values},'
+                     f' but got {input_param}.')
+
+
 def check_input_format(input_param):
     """Judge input format."""
     if input_param == "NCHW":
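
For reference, a minimal sketch of how the new `check_string` helper behaves, based on the hunk above; the calling code and values here are illustrative, not part of the commit:

```python
from mindspore._checkparam import check_string

# A value that is a str and is in the whitelist is returned unchanged.
model_type = check_string("quant", ["normal", "fusion", "quant"])
print(model_type)  # quant

# Anything else raises ValueError; the message depends on how many valid values exist.
try:
    check_string("int8", ["normal", "fusion", "quant"])
except ValueError as err:
    # Input should be str and must be one of ['normal', 'fusion', 'quant'], but got int8.
    print(err)
```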

mindspore/ccsrc/utils/checkpoint.proto

@@ -22,6 +22,7 @@ message Checkpoint {
         required TensorProto tensor = 2;
     }
     repeated Value value = 1;
+    required string model_type = 2;
 }

mindspore/train/callback/_checkpoint.py

@@ -21,17 +21,16 @@ import time
 import mindspore.context as context
 from mindspore import log as logger
-from mindspore._checkparam import check_bool, check_int_non_negative
+from mindspore._checkparam import check_bool, check_string, check_int_non_negative
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
 from ._callback import Callback, set_cur_net
 
 _cur_dir = os.getcwd()
 _save_dir = _cur_dir
 
 
 def _check_file_name_prefix(file_name_prefix):
     """
     Check file name valid or not.
@@ -87,6 +86,7 @@ class CheckpointConfig:
             Can't be used with keep_checkpoint_max at the same time.
         integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
             Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
+        model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
 
     Raises:
         ValueError: If the input_param is None or 0.
@@ -101,7 +101,8 @@ class CheckpointConfig:
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
-                 integrated_save=True):
+                 integrated_save=True,
+                 model_type="normal"):
 
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -115,6 +116,8 @@ class CheckpointConfig:
             keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
         if keep_checkpoint_per_n_minutes:
             keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
+        if model_type:
+            model_type = check_string(model_type, ["normal", "fusion", "quant"])
 
         self._save_checkpoint_steps = save_checkpoint_steps
         self._save_checkpoint_seconds = save_checkpoint_seconds
@@ -129,6 +132,7 @@ class CheckpointConfig:
             if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
                 self._keep_checkpoint_max = 1
 
+        self._model_type = model_type
         self._integrated_save = check_bool(integrated_save)
 
     @property
@@ -156,12 +160,18 @@ class CheckpointConfig:
         """Get the value of _integrated_save."""
         return self._integrated_save
 
+    @property
+    def model_type(self):
+        """Get the value of model_type."""
+        return self._model_type
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
                              'save_checkpoint_seconds': self._save_checkpoint_seconds,
                              'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
+                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
+                             'model_type': self._model_type}
 
         return checkpoint_policy
@@ -226,7 +236,7 @@ class ModelCheckpoint(Callback):
             graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
             _save_graph(cb_params.train_network, graph_file_name)
             self._graph_saved = True
-        self._save_ckpt(cb_params)
+        self._save_ckpt(cb_params, self._config.model_type)
 
     def end(self, run_context):
         """
@@ -237,7 +247,7 @@ class ModelCheckpoint(Callback):
         """
         cb_params = run_context.original_args()
         _to_save_last_ckpt = True
-        self._save_ckpt(cb_params, _to_save_last_ckpt)
+        self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
 
         from mindspore.parallel._cell_wrapper import destroy_allgather_cell
         destroy_allgather_cell()
@@ -256,7 +266,7 @@ class ModelCheckpoint(Callback):
         return False
 
-    def _save_ckpt(self, cb_params, force_to_save=False):
+    def _save_ckpt(self, cb_params, model_type, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
             return
@@ -292,7 +302,7 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
+            _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
 
             if os.path.exists(gen_file):
                 shutil.move(gen_file, cur_file)
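
A minimal usage sketch of the updated CheckpointConfig/ModelCheckpoint API, pieced together from the hunks above; the step counts and prefix are illustrative:

```python
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

# model_type is validated against ["normal", "fusion", "quant"] via check_string,
# stored on the config, exposed as a property, and forwarded by
# ModelCheckpoint._save_ckpt to _exec_save_checkpoint.
config_ck = CheckpointConfig(save_checkpoint_steps=1000,
                             keep_checkpoint_max=5,
                             model_type="quant")
ckpt_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)

print(config_ck.model_type)               # quant
print(config_ck.get_checkpoint_policy())  # the policy dict now also carries 'model_type'
```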

mindspore/train/callback/_loss_monitor.py

@@ -76,7 +76,7 @@ class LossMonitor(Callback):
             step_loss = np.mean(step_loss.asnumpy())
 
         self.losses.append(step_loss)
-        cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num)
+        cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1
 
         if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
             raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
@@ -87,7 +87,7 @@ class LossMonitor(Callback):
         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
                   "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
-                      cb_params.cur_epoch_num - 1, cb_params.epoch_num,
+                      cb_params.cur_epoch_num, cb_params.epoch_num,
                       cur_step_in_epoch, int(cb_params.batch_num),
                       step_loss, np.mean(self.losses),
                       step_mseconds), flush=True)

mindspore/train/serialization.py

@@ -29,6 +29,7 @@ from mindspore.common.api import _executor
 from mindspore.common import dtype as mstype
 from mindspore._checkparam import check_input_data
 
 __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
 
 tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
                      "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
                      "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
 
+ModelType = ["normal", "fusion", "quant"]
+
 
 def _special_process_par(par, new_par):
     """
@@ -101,20 +104,22 @@ def _update_param(param, new_param):
         param.set_parameter_data(type(param.data)(new_param.data))
 
 
-def save_checkpoint(parameter_list, ckpoint_file_name):
+def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
     """
     Saves checkpoint info to a specified file.
 
     Args:
         parameter_list (list): Parameters list, each element is a dict
                                like {"name":xx, "type":xx, "shape":xx, "data":xx}.
-        ckpoint_file_name (str): Checkpoint file name.
+        ckpt_file_name (str): Checkpoint file name.
+        model_type (str): The name of model type. Default: "normal".
 
     Raises:
         RuntimeError: Failed to save the Checkpoint file.
     """
     logger.info("Execute save checkpoint process.")
     checkpoint_list = Checkpoint()
+    checkpoint_list.model_type = model_type
 
     try:
         for param in parameter_list:
@@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
             for dim in param['data'].shape:
                 param_tensor.dims.append(dim)
 
-        with open(ckpoint_file_name, "wb") as f:
+        with open(ckpt_file_name, "wb") as f:
             f.write(checkpoint_list.SerializeToString())
-        os.chmod(ckpoint_file_name, stat.S_IRUSR)
+        os.chmod(ckpt_file_name, stat.S_IRUSR)
 
     except BaseException as e:
-        logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
+        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
         raise RuntimeError(e.__str__())
     logger.info("Save checkpoint process finish.")
 
 
-def load_checkpoint(ckpoint_file_name, net=None):
+def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
     """
     Loads checkpoint info from a specified file.
 
     Args:
-        ckpoint_file_name (str): Checkpoint file name.
+        ckpt_file_name (str): Checkpoint file name.
+        model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
         net (Cell): Cell network. Default: None
 
     Returns:
@@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
     Raises:
         ValueError: Checkpoint file is incorrect.
     """
-    if not isinstance(ckpoint_file_name, str):
-        raise ValueError("The ckpoint_file_name must be String.")
+    if not isinstance(ckpt_file_name, str):
+        raise ValueError("The ckpt_file_name must be string.")
 
-    if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
+    if model_type not in ModelType:
+        raise ValueError(f"The model_type is not in {ModelType}.")
+
+    if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
         raise ValueError("Please input the correct checkpoint file name.")
 
-    if os.path.getsize(ckpoint_file_name) == 0:
+    if os.path.getsize(ckpt_file_name) == 0:
         raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
 
     logger.info("Execute load checkpoint process.")
     checkpoint_list = Checkpoint()
 
     try:
-        with open(ckpoint_file_name, "rb") as f:
+        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
     except BaseException as e:
-        logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
+        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
         raise ValueError(e.__str__())
 
     parameter_dict = {}
+    if model_type != checkpoint_list.model_type:
+        raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
+            checkpoint_list.model_type, model_type))
 
     try:
         for element in checkpoint_list.value:
             data = element.tensor.tensor_content
@@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
         logger.info("Load checkpoint process finish.")
 
     except BaseException as e:
-        logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
+        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
         raise RuntimeError(e.__str__())
 
     if net:
@@ -303,14 +314,15 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
 
 
-def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
+def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
     """
     Saves checkpoint for 'ms' backend.
 
     Args:
         train_network (Network): The train network for training.
-        ckpoint_file_name (str): The name of checkpoint file.
-        integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
+        ckpt_file_name (str): The name of checkpoint file.
+        model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
+        integrated_save (bool): Whether to integrated save in automatic model parallel scene.
     """
 
     param_dict = {}
@@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
         each_param["data"] = param_data
         param_list.append(each_param)
 
-    save_checkpoint(param_list, ckpoint_file_name)
+    save_checkpoint(param_list, ckpt_file_name, model_type)
 
 
 def _get_merged_param_data(net, param_name, param_data):
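
A minimal sketch of the updated load path, assuming a checkpoint that was written with model_type="quant" (for example by ModelCheckpoint as configured above); the file path is illustrative:

```python
from mindspore.train.serialization import load_checkpoint

ckpt_path = "./checkpoint_lenet-10_900.ckpt"  # illustrative path

# The model_type passed at load time must match the type stored in the file.
param_dict = load_checkpoint(ckpt_path, model_type="quant")

# load_checkpoint(ckpt_path, model_type="normal")  # would raise KeyError: stored and requested types differ
# load_checkpoint(ckpt_path, model_type="int8")    # would raise ValueError: not in ModelType
```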

model_zoo/lenet/eval.py

@@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
 import os
 import argparse
-from src.dataset import create_dataset
-from src.config import mnist_cfg as cfg
-from src.lenet import LeNet5
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
+from src.dataset import create_dataset
+from src.config import mnist_cfg as cfg
+from src.lenet import LeNet5
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
@@ -49,9 +47,6 @@ if __name__ == "__main__":
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     repeat_size = cfg.epoch_size
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
 
     print("============== Starting Testing ==============")

model_zoo/lenet_quant/README.md

@@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following:
 ```bash
 >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
 >>> ...
->>> Epoch: [10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
->>> Epoch: [10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
->>> Epoch: [10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
+>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
+>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
+>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
 ```
 
 Also, you can just run this command instead.
@@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following:
 ```bash
 >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
 >>> ...
->>> Epoch: [10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
->>> Epoch: [10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
->>> Epoch: [10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
+>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
+>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
+>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
 ```
 
 ### Evaluate quantization aware model
@@ -215,7 +215,7 @@ param_dict = load_checkpoint(args.ckpt_path)
 load_param_into_net(network, param_dict)
 
 # convert funsion netwrok to quantization aware network
-network = quant.convert_quant_network(network
+network = quant.convert_quant_network(network)
 ```
 
 Also, you can just run this command insread.

model_zoo/lenet_quant/eval.py

@@ -23,7 +23,6 @@ import argparse
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 
 from src.dataset import create_dataset
@@ -47,16 +46,18 @@ if __name__ == "__main__":
     ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
     step_size = ds_eval.get_dataset_size()
 
+    # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
+    # define loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-    repeat_size = cfg.epoch_size
+    # define network optimization
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
 
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    # call back and monitor
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
 
-    param_dict = load_checkpoint(args.ckpt_path)
+    # load check point into network
+    param_dict = load_checkpoint(args.ckpt_path, network.type)
     load_param_into_net(network, param_dict)
 
     print("============== Starting Testing ==============")

model_zoo/lenet_quant/eval_quant.py

@@ -23,7 +23,6 @@ import argparse
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.quant import quant
@@ -48,20 +47,21 @@ if __name__ == "__main__":
     ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
     step_size = ds_eval.get_dataset_size()
 
-    # define funsion network
+    # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
-    # convert funsion netwrok to quantization aware network
+    # convert fusion netwrok to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
 
+    # define loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    # define network optimization
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
 
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    # call back and monitor
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
 
     # load quantization aware network checkpoint
-    param_dict = load_checkpoint(args.ckpt_path)
+    param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
     load_param_into_net(network, param_dict)
 
     print("============== Starting Testing ==============")

model_zoo/lenet_quant/src/lenet.py

@@ -34,8 +34,8 @@ class LeNet5(nn.Cell):
         super(LeNet5, self).__init__()
         self.num_class = num_class
 
-        self.conv1 = nn.Conv2d(channel, 6, 5)
-        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
         self.fc1 = nn.Dense(16 * 5 * 5, 120)
         self.fc2 = nn.Dense(120, 84)
         self.fc3 = nn.Dense(84, self.num_class)

model_zoo/lenet_quant/src/lenet_fusion.py

@@ -32,11 +32,12 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10, channel=1):
         super(LeNet5, self).__init__()
+        self.type = "fusion"
         self.num_class = num_class
 
         # change `nn.Conv2d` to `nn.Conv2dBnAct`
-        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu')
-        self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu')
+        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
+        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
         # change `nn.Dense` to `nn.DenseBnAct`
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')

model_zoo/lenet_quant/train.py

@@ -46,16 +46,24 @@ if __name__ == "__main__":
     ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
     step_size = ds_train.get_dataset_size()
 
+    # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
+    # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    # define network optimization
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
 
+    # call back and monitor
     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
+    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
+                                   keep_checkpoint_max=cfg.keep_checkpoint_max,
+                                   model_type=network.type)
+    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
 
+    # define model
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
 
     print("============== Starting Training ==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
+    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
                 dataset_sink_mode=args.dataset_sink_mode)
     print("============== End Training ==============")

model_zoo/lenet_quant/train_quant.py

@@ -48,23 +48,30 @@ if __name__ == "__main__":
     ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
     step_size = ds_train.get_dataset_size()
 
-    # define funsion network
+    # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
     # load quantization aware network checkpoint
-    param_dict = load_checkpoint(args.ckpt_path)
+    param_dict = load_checkpoint(args.ckpt_path, network.type)
     load_param_into_net(network, param_dict)
 
-    # convert funsion netwrok to quantization aware network
+    # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
 
+    # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    # define network optimization
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
 
+    # call back and monitor
     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
+    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
+                                   keep_checkpoint_max=cfg.keep_checkpoint_max,
+                                   model_type="quant")
+    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
 
+    # define model
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
 
     print("============== Starting Training ==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
+    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
                 dataset_sink_mode=args.dataset_sink_mode)
     print("============== End Training ==============")
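
Taken together, the two training scripts now tag their checkpoints as sketched below; FusionNetStub stands in for the LeNet5Fusion instance built in the scripts above (only the `type` attribute added in src/lenet_fusion.py matters here), and the step counts are illustrative:

```python
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint


class FusionNetStub:
    """Stand-in for LeNet5Fusion; src/lenet_fusion.py now sets self.type to "fusion"."""
    type = "fusion"


network = FusionNetStub()

# train.py: fusion training tags its checkpoints with the network's own type.
config_ckpt = CheckpointConfig(save_checkpoint_steps=900,
                               keep_checkpoint_max=10,
                               model_type=network.type)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

# train_quant.py: the fusion checkpoint is read back with the matching type,
#   param_dict = load_checkpoint(args.ckpt_path, network.type)
# the network is converted with quant.convert_quant_network(...), and the
# quantization aware checkpoints are then written with model_type="quant":
config_quant = CheckpointConfig(save_checkpoint_steps=900,
                                keep_checkpoint_max=10,
                                model_type="quant")
```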

tests/ut/python/predict/test_predict_save_model.py

@@ -85,7 +85,7 @@ if __name__ == '__main__':
     is_ckpt_exist = os.path.exists(ckpt_file_path)
     if is_ckpt_exist:
-        param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path)
+        param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path)
         load_param_into_net(net, param_dict)
         export(net, input_data, file_name=model_path_name, file_format='LITE')
         print("test lenet predict success.")

tests/ut/python/utils/test_serialize.py

@@ -111,19 +111,19 @@ def test_save_checkpoint():
         os.chmod('./parameters.ckpt', stat.S_IWRITE)
         os.remove('./parameters.ckpt')
-    ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
-    save_checkpoint(parameter_list, ckpoint_file_name)
+    ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
+    save_checkpoint(parameter_list, ckpt_file_name)
 
 
 def test_load_checkpoint_error_filename():
-    ckpoint_file_name = 1
+    ckpt_file_name = 1
     with pytest.raises(ValueError):
-        load_checkpoint(ckpoint_file_name)
+        load_checkpoint(ckpt_file_name)
 
 
 def test_load_checkpoint():
-    ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
-    par_dict = load_checkpoint(ckpoint_file_name)
+    ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
+    par_dict = load_checkpoint(ckpt_file_name)
 
     assert len(par_dict) == 3
     assert par_dict['param_test'].name == 'param_test'
@@ -136,17 +136,17 @@ def test_checkpoint_manager():
     """ test_checkpoint_manager """
     ckp_mgr = _CheckpointManager()
 
-    ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
-    with open(ckpoint_file_name, 'w'):
-        os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
+    ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
+    with open(ckpt_file_name, 'w'):
+        os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
 
     ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
     assert ckp_mgr.ckpoint_num == 1
 
-    ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
+    ckp_mgr.remove_ckpoint_file(ckpt_file_name)
     ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
     assert ckp_mgr.ckpoint_num == 0
-    assert not os.path.exists(ckpoint_file_name)
+    assert not os.path.exists(ckpt_file_name)
 
     another_file_name = os.path.join(_cur_dir, './test2.ckpt')
     another_file_name = os.path.realpath(another_file_name)
@@ -283,7 +283,7 @@ def test_exec_save_checkpoint():
     loss_net = WithLossCell(net, loss)
     train_network = TrainOneStepCell(loss_net, opt)
-    _exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")
+    _exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
 
     load_checkpoint("new_ckpt.ckpt")