magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit a71868f1

Authored on Aug 20, 2020 by chenfei

add ci for quant

Parent: 2a5d90dc
Showing 8 changed files with 386 additions and 11 deletions (+386, -11).
Changed files:
  model_zoo/official/cv/lenet_quant/eval_quant.py          +0    -1
  model_zoo/official/cv/resnet50_quant/eval.py             +7    -8
  model_zoo/official/cv/resnet50_quant/src/dataset.py      +2    -2
  tests/st/quantization/lenet_quant/config.py              +44   -0
  tests/st/quantization/lenet_quant/dataset.py             +60   -0
  tests/st/quantization/lenet_quant/lenet.py               +79   -0
  tests/st/quantization/lenet_quant/lenet_fusion.py        +58   -0
  tests/st/quantization/lenet_quant/test_lenet_quant.py    +136  -0
model_zoo/official/cv/lenet_quant/eval_quant.py

@@ -45,7 +45,6 @@ args = parser.parse_args()
 if __name__ == "__main__":
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
-    step_size = ds_eval.get_dataset_size()

     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
model_zoo/official/cv/resnet50_quant/eval.py

@@ -17,7 +17,7 @@
 import os
 import argparse

-from src.config import quant_set, config_quant, config_noquant
+from src.config import config_quant
 from src.dataset import create_dataset
 from src.crossentropy import CrossEntropy
 from models.resnet_quant import resnet50_quant
@@ -34,7 +34,7 @@ parser.add_argument('--device_target', type=str, default='Ascend', help='Device
 args_opt = parser.parse_args()

 context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
-config = config_quant if quant_set.quantization_aware else config_noquant
+config = config_quant

 if args_opt.device_target == "Ascend":
     device_id = int(os.getenv('DEVICE_ID'))
@@ -43,12 +43,11 @@ if args_opt.device_target == "Ascend":

 if __name__ == '__main__':
     # define fusion network
     net = resnet50_quant(class_num=config.class_num)
-    if quant_set.quantization_aware:
-        # convert fusion network to quantization aware network
-        net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    # convert fusion network to quantization aware network
+    net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])

     # define network loss
     if not config.use_label_smooth:
         config.label_smooth_factor = 0.0
model_zoo/official/cv/resnet50_quant/src/dataset.py

@@ -23,9 +23,9 @@ import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
 import mindspore.dataset.transforms.vision.py_transforms as P
 from mindspore.communication.management import init, get_rank, get_group_size
-from src.config import quant_set, config_quant, config_noquant
+from src.config import config_quant

-config = config_quant if quant_set.quantization_aware else config_noquant
+config = config_quant

 def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
tests/st/quantization/lenet_quant/config.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in test_lenet_quant.py
"""

from easydict import EasyDict as edict

nonquant_cfg = edict({
    'num_classes': 10,
    'lr': 0.01,
    'momentum': 0.9,
    'epoch_size': 10,
    'batch_size': 32,
    'buffer_size': 1000,
    'image_height': 32,
    'image_width': 32,
    'save_checkpoint_steps': 1875,
    'keep_checkpoint_max': 10,
})

quant_cfg = edict({
    'num_classes': 10,
    'lr': 0.01,
    'momentum': 0.9,
    'epoch_size': 10,
    'batch_size': 64,
    'buffer_size': 1000,
    'image_height': 32,
    'image_width': 32,
    'keep_checkpoint_max': 10,
})
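Not part of this commit: since both configs are plain EasyDict objects, here is a minimal usage sketch showing how a consumer such as test_lenet_quant.py can read them; it assumes config.py above is importable from the current directory and that easydict is installed.

# Illustration only, not part of the diff.
from config import nonquant_cfg, quant_cfg

# EasyDict allows attribute access and key access interchangeably.
assert quant_cfg.batch_size == quant_cfg['batch_size'] == 64
assert nonquant_cfg.batch_size == 32
print(quant_cfg.num_classes, quant_cfg.epoch_size)   # 10 10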
tests/st/quantization/lenet_quant/dataset.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
from mindspore.common import dtype as mstype


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width),
                          interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op,
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op,
                            num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds
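Not part of this commit: a usage sketch for create_dataset above. The path /path/to/mnist/train is a placeholder; it assumes the MNIST files are unpacked in the layout ds.MnistDataset expects.

# Illustration only, not part of the diff.
from dataset import create_dataset

ds_train = create_dataset("/path/to/mnist/train", batch_size=32, repeat_size=1)
print(ds_train.get_dataset_size())   # number of 32-sample batches per epoch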
tests/st/quantization/lenet_quant/lenet.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    """
    Lenet network

    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
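Not part of this commit: a quick forward-pass sketch for the float LeNet5 defined above, assuming a working MindSpore installation and its default execution backend. The dummy input matches the 1x1x32x32 shape produced by the resize in dataset.py.

# Illustration only, not part of the diff.
import numpy as np
from mindspore import Tensor
from lenet import LeNet5

net = LeNet5(num_class=10)
logits = net(Tensor(np.ones((1, 1, 32, 32), np.float32)))
print(logits)   # a (1, 10) tensor of class scores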
tests/st/quantization/lenet_quant/lenet_fusion.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""
import mindspore.nn as nn


class LeNet5(nn.Cell):
    """
    Lenet network

    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.type = "fusion"
        self.num_class = num_class

        # change `nn.Conv2d` to `nn.Conv2dBnAct`
        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        # change `nn.Dense` to `nn.DenseBnAct`
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        self.fc3 = nn.DenseBnAct(84, self.num_class)

        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
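Not part of this commit: the fusion variant only becomes a quantization-aware network after conversion. The sketch below mirrors the conversion call used in train_lenet_quant() in test_lenet_quant.py further down.

# Illustration only, not part of the diff; same arguments as the test uses.
from mindspore.train.quant import quant
from lenet_fusion import LeNet5 as LeNet5Fusion

network = LeNet5Fusion(num_class=10)
network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False,
                                      per_channel=[True, False], symmetric=[False, False])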
tests/st/quantization/lenet_quant/test_lenet_quant.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train and infer lenet quantization network
"""

import os
import pytest

from mindspore import context
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.train.quant import quant
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
from dataset import create_dataset
from config import nonquant_cfg, quant_cfg
from lenet import LeNet5
from lenet_fusion import LeNet5 as LeNet5Fusion

device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"


def train_lenet():
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = nonquant_cfg
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size)

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training Lenet==============")
    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=True)


def train_lenet_quant():
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = quant_cfg
    ckpt_path = './checkpoint_lenet-10_1875.ckpt'
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)

    # load quantization aware network checkpoint
    param_dict = load_checkpoint(ckpt_path)
    load_nonquant_param_into_quant_net(network, param_dict)

    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False,
                                          per_channel=[True, False], symmetric=[False, False])

    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")

    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # call back and monitor
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=True)
    print("============== End Training ==============")


def eval_quant():
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = quant_cfg
    ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
    ckpt_path = './checkpoint_lenet_1-10_937.ckpt'
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False,
                                          freeze_bn=10000, per_channel=[True, False])
    # define loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # call back and monitor
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # load quantization aware network checkpoint
    param_dict = load_checkpoint(ckpt_path)
    not_load_param = load_param_into_net(network, param_dict)
    if not_load_param:
        raise ValueError("Load param into net fail!")

    print("============== Starting Testing ==============")
    acc = model.eval(ds_eval, dataset_sink_mode=True)
    print("============== {} ==============".format(acc))
    assert acc['Accuracy'] > 0.98


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lenet_quant():
    train_lenet()
    train_lenet_quant()
    eval_quant()


if __name__ == "__main__":
    train_lenet_quant()
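Not part of this commit: the ST case above is collected by pytest through its level0/GPU marks. The three stages can also be driven directly, as sketched below, assuming a GPU device and MNIST data at the hard-coded data_path.

# Illustration only, not part of the diff.
from test_lenet_quant import train_lenet, train_lenet_quant, eval_quant

train_lenet()         # float baseline, writes ./checkpoint_lenet-10_1875.ckpt
train_lenet_quant()   # QAT fine-tuning initialized from that checkpoint
eval_quant()          # asserts Accuracy > 0.98 on the MNIST test split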