Repository: magicwindyyd / mindspore (forked from MindSpore / mindspore, in sync with the fork source)

Commit 9b7a426c
Authored Jul 03, 2020 by chenzomi
Parent: cf6dd99e

    bug fix in auto create quant graph in master
Showing 9 changed files with 98 additions and 34 deletions.

    mindspore/nn/layer/quant.py                  +2   -2
    mindspore/train/callback/_loss_monitor.py    +1   -1
    mindspore/train/quant/quant.py               +25  -18
    mindspore/train/serialization.py             +4   -1
    model_zoo/lenet_quant/eval.py                +1   -1
    model_zoo/lenet_quant/eval_quant.py          +1   -1
    model_zoo/lenet_quant/export.py              +56  -0
    model_zoo/lenet_quant/train.py               +2   -3
    model_zoo/lenet_quant/train_quant.py         +6   -7
mindspore/nn/layer/quant.py

@@ -1193,9 +1193,9 @@ class QuantBlock(Cell):
         self.dequant = dequant_op
         self.dequant_scale = dequant_scale
         self.bias = bias
-        self.has_bias = bias is None
+        self.has_bias = bias is not None
         self.activation = activation
-        self.has_act = activation is None
+        self.has_act = activation is not None
         self.bias_add = P.BiasAdd()

     def construct(self, x):
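The flipped conditions above are the substance of this hunk: `has_bias` and `has_act` gate optional steps later in the block, so testing `bias is None` set both flags backwards. A minimal illustrative sketch of that gating pattern follows; the real `QuantBlock.construct` body is not shown in this diff, so the class below is an assumption meant only to show why `is not None` is the correct test.

    class Block:
        """Illustrative only: mirrors how has_bias/has_act are expected to be used."""
        def __init__(self, bias=None, activation=None):
            self.bias = bias
            self.activation = activation
            self.has_bias = bias is not None        # old code used `bias is None`, inverting the flag
            self.has_act = activation is not None   # old code used `activation is None`

        def construct(self, x):
            if self.has_bias:                       # with the inverted flag, bias was applied exactly when absent
                x = x + self.bias
            if self.has_act:
                x = self.activation(x)
            return x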
mindspore/train/callback/_loss_monitor.py

@@ -86,7 +86,7 @@ class LossMonitor(Callback):
         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
-                  "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
+                  "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}ms]".format(
                       cb_params.cur_epoch_num, cb_params.epoch_num,
                       cur_step_in_epoch, int(cb_params.batch_num),
                       step_loss, np.mean(self.losses),
mindspore/train/quant/quant.py

@@ -33,7 +33,6 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils

 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.ReLU6: quant.ReLU6Quant,
                    nn.HSigmoid: quant.HSigmoidQuant,

@@ -178,7 +177,6 @@ class ConvertToQuantNetwork:
                                dilation=conv_inner.dilation,
                                group=conv_inner.group,
                                eps=bn_inner.eps,
-                               momentum=1 - bn_inner.momentum,
                                quant_delay=self.weight_qdelay,
                                freeze_bn=self.freeze_bn,
                                per_channel=self.weight_channel,

@@ -268,16 +266,16 @@ class ConvertToQuantNetwork:
                                       narrow_range=self.act_range)


-class ExportQuantNetworkDeploy:
+class ExportToQuantInferNetwork:
     """
-    Convert quantization aware network to deploy network.
+    Convert quantization aware network to infer network.

     Args:
-        network (Cell): MindSpore network produced by `convert_quant_network`.
-        inputs (Tensor): Inputs of the `network`.
+        network (Cell): MindSpore network API `convert_quant_network`.
+        inputs (Tensor): Input tensors of the `quantization aware training network`.

     Returns:
-        Cell, converted network.
+        Cell, GEIR backend Infer network.
     """
     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

@@ -287,7 +285,7 @@ class ExportQuantNetworkDeploy:
         network = validator.check_isinstance('network', network, (nn.Cell,))
         self.data_type = mstype.int8
         self.network = copy.deepcopy(network)
-        self.all_paramters = {p.name: p for p in self.network.get_parameters()}
+        self.all_parameters = {p.name: p for p in self.network.get_parameters()}
         self.get_inputs_table(inputs)

     def get_inputs_table(self, inputs):

@@ -315,8 +313,8 @@ class ExportQuantNetworkDeploy:
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
             fack_quant_a_in_op, minq_name = info
-            maxq = self.all_paramters[minq_name[:-4] + "maxq"]
-            minq = self.all_paramters[minq_name]
+            maxq = self.all_parameters[minq_name[:-4] + "maxq"]
+            minq = self.all_parameters[minq_name]
             scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
         else:
             logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")

@@ -357,7 +355,7 @@ class ExportQuantNetworkDeploy:
         return block

     def _convert_quant2deploy(self, network):
-        """Convet network's all quant subcell to deploy subcell."""
+        """Convert network's all quant subcell to deploy subcell."""
         cells = network.name_cells()
         change = False
         for name in cells:
@@ -395,18 +393,26 @@ class ExportQuantNetworkDeploy:
         return network


-def export_geir(network, *inputs, file_name):
+def export(network, *inputs, file_name, file_format='GEIR'):
     """
-    Exports MindSpore quant predict model to deploy with GEIR.
+    Exports MindSpore quantization predict model to deploy with GEIR.

     Args:
         network (Cell): MindSpore network produced by `convert_quant_network`.
-        inputs (Tensor): Inputs of the `network`.
+        inputs (Tensor): Inputs of the `quantization aware training network`.
         file_name (str): File name of model to export.
+        file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
+            - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
     """
-    exporter = ExportQuantNetworkDeploy(network, *inputs)
-    deploy_net = exporter.run()
-    serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")
+    supported_formats = ['GEIR']
+
+    if file_format not in supported_formats:
+        raise ValueError('Illegal file format {}.'.format(file_format))
+
+    if file_format == 'GEIR':
+        exporter = ExportToQuantInferNetwork(network, *inputs)
+        deploy_net = exporter.run()
+        serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)


 def convert_quant_network(network,
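The renamed `export` entry point is exercised end to end by `model_zoo/lenet_quant/export.py`, added later in this same commit. A condensed usage sketch based on that script is shown below; `LeNet5Fusion` and `cfg` come from the model_zoo sources, and the checkpoint path is illustrative rather than part of the commit.

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.train.quant import quant
    from mindspore.train.serialization import load_checkpoint, load_param_into_net
    from src.config import mnist_cfg as cfg
    from src.lenet_fusion import LeNet5 as LeNet5Fusion

    # Build and convert the network, load the quantization aware checkpoint, then export to GEIR.
    network = LeNet5Fusion(cfg.num_classes)
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
    load_param_into_net(network, load_checkpoint("lenet_quant.ckpt"))  # illustrative checkpoint path
    inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
    quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR')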
@@ -443,6 +449,7 @@ def convert_quant_network(network,
         Cell, Network which has change to quantization aware training network cell.
     """
     support_device = ["Ascend", "GPU"]

     def convert2list(name, value):
         if not isinstance(value, list) and not isinstance(value, tuple):
             value = [value]

@@ -457,7 +464,7 @@ def convert_quant_network(network,
     narrow_range = convert2list("narrow range", narrow_range)

     if context.get_context('device_target') not in support_device:
-        raise KeyError("Not support {} backend.".format(context.get_context('device_target')))
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

     net = ConvertToQuantNetwork(network=network,
                                 quant_delay=quant_delay,
mindspore/train/serialization.py

@@ -160,7 +160,10 @@ def load_checkpoint(ckpt_file_name, net=None):
     if not isinstance(ckpt_file_name, str):
         raise ValueError("The ckpt_file_name must be string.")

-    if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
+    if not os.path.exists(ckpt_file_name):
+        raise ValueError("The checkpoint file is not exist.")
+
+    if ckpt_file_name[-5:] != ".ckpt":
         raise ValueError("Please input the correct checkpoint file name.")

     if os.path.getsize(ckpt_file_name) == 0:
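Splitting the combined check means a missing file and a wrong suffix now fail with distinct messages. A small behavioral sketch follows; the file names are hypothetical and used only to show which ValueError each case raises.

    from mindspore.train.serialization import load_checkpoint

    try:
        load_checkpoint("does_not_exist.ckpt")        # hypothetical missing path
    except ValueError as err:
        print(err)   # The checkpoint file is not exist.

    try:
        load_checkpoint("plain_text_file.txt")        # hypothetical existing file without ".ckpt"
    except ValueError as err:
        print(err)   # Please input the correct checkpoint file name.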
model_zoo/lenet_quant/eval.py

@@ -57,7 +57,7 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     # load check point into network
-    param_dict = load_checkpoint(args.ckpt_path, network.type)
+    param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)

     print("============== Starting Testing ==============")
model_zoo/lenet_quant/eval_quant.py

@@ -49,7 +49,7 @@ if __name__ == "__main__":
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
-    # convert fusion netwrok to quantization aware network
+    # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # define loss
model_zoo/lenet_quant/export.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
export quantization aware training network to infer `GEIR` backend.
"""

import argparse
import numpy as np

import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion

parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
                    help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
                    help='dataset_sink_mode is False or True')
args = parser.parse_args()

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
    # load quantization aware network checkpoint
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)

    # export network
    inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
    quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR')
model_zoo/lenet_quant/train.py

@@ -22,7 +22,7 @@ import os
 import argparse
 import mindspore.nn as nn
 from mindspore import context
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from src.dataset import create_dataset

@@ -54,7 +54,6 @@ if __name__ == "__main__":
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

     # call back and monitor
-    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                    keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

@@ -63,6 +62,6 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     print("============== Starting Training ==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
+    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
                 dataset_sink_mode=args.dataset_sink_mode)
     print("============== End Training ==============")
model_zoo/lenet_quant/train_quant.py

@@ -23,7 +23,7 @@ import argparse
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.quant import quant

@@ -51,20 +51,19 @@ if __name__ == "__main__":
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
     # load quantization aware network checkpoint
-    param_dict = load_checkpoint(args.ckpt_path, network.type)
+    param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)
-    # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     # define network optimization
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

     # call back and monitor
-    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                    keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
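The reordering above appears to be the core of the fix for this script: the checkpoint being loaded is a quantization aware one, so the fusion network presumably has to be converted before `load_checkpoint`/`load_param_into_net` so that the fake-quant parameters exist under matching names. A compact sketch of the fixed flow follows; the reasoning in the comments is an inference, not text from the commit.

    # Fixed ordering in train_quant.py (inference: convert first so the quantization
    # aware checkpoint's fake-quant parameters have counterparts in the network).
    network = LeNet5Fusion(cfg.num_classes)
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
    param_dict = load_checkpoint(args.ckpt_path)   # the extra network.type argument was dropped
    load_param_into_net(network, param_dict)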
@@ -73,6 +72,6 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     print("============== Starting Training ==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
+    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
                 dataset_sink_mode=args.dataset_sink_mode)
     print("============== End Training ==============")