magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 2a5d90dc
Authored Aug 20, 2020 by chenfei

[Quant][lenet]eval should set bn_fold as true

Parent: a27e6f57
Showing 15 changed files with 55 additions and 350 deletions (+55 -350)
mindspore/train/quant/quant_utils.py  +32  -0
mindspore/train/serialization.py  +1  -0
model_zoo/official/cv/lenet_quant/Readme.md  +0  -59
model_zoo/official/cv/lenet_quant/eval.py  +0  -65
model_zoo/official/cv/lenet_quant/eval_quant.py  +3  -1
model_zoo/official/cv/lenet_quant/src/lenet.py  +0  -64
model_zoo/official/cv/lenet_quant/src/lenet_fusion.py  +2  -2
model_zoo/official/cv/lenet_quant/train.py  +0  -68
model_zoo/official/cv/lenet_quant/train_quant.py  +5  -3
model_zoo/official/cv/mobilenetv2_quant/eval.py  +3  -1
model_zoo/official/cv/mobilenetv2_quant/src/utils.py  +0  -33
model_zoo/official/cv/mobilenetv2_quant/train.py  +3  -3
model_zoo/official/cv/resnet50_quant/eval.py  +4  -3
model_zoo/official/cv/resnet50_quant/src/utils.py  +0  -46
model_zoo/official/cv/resnet50_quant/train.py  +2  -2
mindspore/train/quant/quant_utils.py

@@ -250,3 +250,35 @@ def without_fold_batchnorm(weight, cell_quant):
     weight = weight * _gamma / _sigma
     bias = beta - gamma * mean / sigma
     return weight, bias
+
+
+def load_nonquant_param_into_quant_net(quant_model, params_dict):
+    """
+    load fp32 model parameters to quantization model.
+
+    Args:
+        quant_model: quantization model
+        params_dict: f32 param
+
+    Returns:
+        None
+    """
+    iterable_dict = {
+        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
+        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
+        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
+        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
+        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
+        'moving_variance': iter([item for item in params_dict.items() if item[0].endswith('moving_variance')]),
+        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
+        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
+    }
+    for name, param in quant_model.parameters_and_names():
+        key_name = name.split(".")[-1]
+        if key_name not in iterable_dict.keys():
+            raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
+        value_param = next(iterable_dict[key_name], None)
+        if value_param is not None:
+            param.set_parameter_data(value_param[1].data)
+            print(f'init model param {name} with checkpoint param {value_param[0]}')
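
For orientation, this is roughly how the updated model_zoo scripts in this commit call the new helper; a minimal sketch only, with the checkpoint filename taken from the lenet_quant Readme as a placeholder and the config/network imports mirroring the lenet_quant scripts:

```python
# Sketch of the call pattern used by the updated train_quant.py scripts.
# "checkpoint_lenet.ckpt" is a placeholder path, not part of this commit's diff.
from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net

from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion

network = LeNet5Fusion(cfg.num_classes)
param_dict = load_checkpoint("checkpoint_lenet.ckpt")  # fp32 (non-quantized) checkpoint
# Copy the fp32 weights, biases and BN statistics into the fusion network by name
# suffix before it is converted to a quantization aware network.
load_nonquant_param_into_quant_net(network, param_dict)
```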
mindspore/train/serialization.py

@@ -308,6 +308,7 @@ def load_param_into_net(net, parameter_dict):
            logger.debug("%s", param_name)

     logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load)))
+    return param_not_load


 def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
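
The eval scripts touched by this commit rely on that new return value to fail fast when a checkpoint does not match the network (see the eval_quant.py, mobilenetv2_quant/eval.py and resnet50_quant/eval.py hunks below). A minimal sketch of the pattern, wrapped in a made-up helper name for clarity:

```python
# load_or_fail is a hypothetical wrapper; the body copies the pattern from the eval hunks below.
from mindspore.train.serialization import load_checkpoint, load_param_into_net

def load_or_fail(network, ckpt_path):
    """Load a checkpoint into the network and raise if any parameter was not loaded."""
    param_dict = load_checkpoint(ckpt_path)
    not_load_param = load_param_into_net(network, param_dict)  # now returns the parameters that were not loaded
    if not_load_param:
        raise ValueError("Load param into net fail!")
```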
model_zoo/official/cv/lenet_quant/Readme.md

@@ -93,65 +93,6 @@ Get the MNIST from scratch dataset.
 ds_train = create_dataset(os.path.join(args.data_path, "train"),
                           cfg.batch_size, cfg.epoch_size)
 step_size = ds_train.get_dataset_size()
 ```
-
-### Train model
-
-Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
-
-```Python
-# Define the network
-network = LeNet5Fusion(cfg.num_classes)
-# Define the loss
-net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
-# Define optimization
-net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
-# Define model using loss and optimization.
-time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
-config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                             keep_checkpoint_max=cfg.keep_checkpoint_max)
-ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
-model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-```
-
-Now we can start training.
-
-```Python
-model.train(cfg['epoch_size'], ds_train,
-            callbacks=[time_cb, ckpoint_cb, LossMonitor()],
-            dataset_sink_mode=args.dataset_sink_mode)
-```
-
-After all the following we will get the loss value of each step as following:
-
-```bash
->>> Epoch: [ 1/ 10] step: [  1/ 900], loss: [2.3040/2.5234], time: [1.300234]
->>> ...
->>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
->>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
->>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
-```
-
-Also, you can just run this command instead.
-
-```python
-python train.py --data_path MNIST_Data --device_target Ascend
-```
-
-### Evaluate fusion model
-
-After training epoch stop. We can get the fusion model checkpoint file like `checkpoint_lenet.ckpt`. Meanwhile, we can evaluate this fusion model.
-
-```python
-python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
-```
-
-The top1 accuracy would display on shell.
-
-```bash
->>> Accuracy: 98.53.
-```
-
 ## Train quantization aware model
model_zoo/official/cv/lenet_quant/eval.py (deleted, 100644 → 0)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval lenet example ########################
eval lenet according to model file:
python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""

import os
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion

parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
                    help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
                    help='dataset_sink_mode is False or True')
args = parser.parse_args()

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
    step_size = ds_eval.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # define loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # call back and monitor
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # load check point into network
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)

    print("============== Starting Testing ==============")
    acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
    print("============== {} ==============".format(acc))
model_zoo/official/cv/lenet_quant/eval_quant.py

@@ -63,7 +63,9 @@ if __name__ == "__main__":
     # load quantization aware network checkpoint
     param_dict = load_checkpoint(args.ckpt_path)
-    load_param_into_net(network, param_dict)
+    not_load_param = load_param_into_net(network, param_dict)
+    if not_load_param:
+        raise ValueError("Load param into net fail!")

     print("============== Starting Testing ==============")
     acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
model_zoo/official/cv/lenet_quant/src/lenet.py (deleted, 100644 → 0)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet."""

import mindspore.nn as nn


class LeNet5(nn.Cell):
    """
    Lenet network

    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet(num_class=10)
    """
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class

        self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.bn2 = nn.BatchNorm2d(16)

        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, self.num_class)

        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
model_zoo/official/cv/lenet_quant/src/lenet_fusion.py

@@ -36,8 +36,8 @@ class LeNet5(nn.Cell):
         self.num_class = num_class

         # change `nn.Conv2d` to `nn.Conv2dBnAct`
-        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', has_bn=True, activation='relu')
+        self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
-        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', has_bn=True, activation='relu')
+        self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
         # change `nn.Dense` to `nn.DenseBnAct`
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
model_zoo/official/cv/lenet_quant/train.py (deleted, 100644 → 0)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt) :
python train.py --data_path /YourDataPath
"""

import os
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
from src.loss_monitor import LossMonitor

parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
                    help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
                    help='dataset_sink_mode is False or True')
args = parser.parse_args()

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    # define network optimization
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

    # call back and monitor
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

    # define model
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Training ==============")
    model.train(cfg['epoch_size'], ds_train,
                callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=args.dataset_sink_mode)
    print("============== End Training ==============")
model_zoo/official/cv/lenet_quant/train_quant.py

@@ -22,11 +22,12 @@ import os
 import argparse

 import mindspore.nn as nn
 from mindspore import context
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.serialization import load_checkpoint
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
 from mindspore.train.quant import quant
+from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
 from src.dataset import create_dataset
 from src.config import mnist_cfg as cfg
 from src.lenet_fusion import LeNet5 as LeNet5Fusion

@@ -54,10 +55,11 @@ if __name__ == "__main__":
     # load quantization aware network checkpoint
     param_dict = load_checkpoint(args.ckpt_path)
-    load_param_into_net(network, param_dict)
+    load_nonquant_param_into_quant_net(network, param_dict)

     # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=900, per_channel=[True, False], symmetric=[False, False])
+    network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False,
+                                          per_channel=[True, False], symmetric=[False, False])

     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
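
The commit title ties the bn_fold options on the two sides together: training now converts the LeNet fusion network with bn_fold=False, and, per the title, evaluation should convert with bn_fold=True. The eval-side call is not shown in the hunks on this page, so the sketch below treats its arguments as an assumption:

```python
from mindspore.train.quant import quant
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion

# Training-side conversion, as in the train_quant.py hunk above.
train_network = quant.convert_quant_network(LeNet5Fusion(cfg.num_classes), quant_delay=900, bn_fold=False,
                                            per_channel=[True, False], symmetric=[False, False])

# Evaluation-side conversion, inferred from the commit title
# "eval should set bn_fold as true"; the exact eval argument list is an assumption here.
eval_network = quant.convert_quant_network(LeNet5Fusion(cfg.num_classes), quant_delay=900, bn_fold=True,
                                           per_channel=[True, False], symmetric=[False, False])
```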
model_zoo/official/cv/mobilenetv2_quant/eval.py

@@ -68,7 +68,9 @@ if __name__ == '__main__':
     # load checkpoint
     if args_opt.checkpoint_path:
         param_dict = load_checkpoint(args_opt.checkpoint_path)
-        load_param_into_net(network, param_dict)
+        not_load_param = load_param_into_net(network, param_dict)
+        if not_load_param:
+            raise ValueError("Load param into net fail!")
     network.set_train(False)

     # define model
model_zoo/official/cv/mobilenetv2_quant/src/utils.py

@@ -25,39 +25,6 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype


-def _load_param_into_net(model, params_dict):
-    """
-    load fp32 model parameters to quantization model.
-
-    Args:
-        model: quantization model
-        params_dict: f32 param
-
-    Returns:
-        None
-    """
-    iterable_dict = {
-        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
-        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
-        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
-        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
-        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
-        'moving_variance': iter([item for item in params_dict.items() if item[0].endswith('moving_variance')]),
-        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
-        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
-    }
-    for name, param in model.parameters_and_names():
-        key_name = name.split(".")[-1]
-        if key_name not in iterable_dict.keys():
-            raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
-        value_param = next(iterable_dict[key_name], None)
-        if value_param is not None:
-            param.set_parameter_data(value_param[1].data)
-            print(f'init model param {name} with checkpoint param {value_param[0]}')
-
-
 class Monitor(Callback):
     """
     Monitor loss and time.
model_zoo/official/cv/mobilenetv2_quant/train.py

@@ -28,6 +28,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
 from mindspore.train.serialization import load_checkpoint
 from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.quant import quant
+from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
 import mindspore.dataset.engine as de
 from src.dataset import create_dataset

@@ -35,7 +36,6 @@ from src.lr_generator import get_lr
 from src.utils import Monitor, CrossEntropyWithLabelSmooth
 from src.config import config_ascend_quant, config_gpu_quant
 from src.mobilenetV2 import mobilenetV2
-from src.utils import _load_param_into_net

 random.seed(1)
 np.random.seed(1)

@@ -101,7 +101,7 @@ def train_on_ascend():
     # load pre trained ckpt
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        _load_param_into_net(network, param_dict)
+        load_nonquant_param_into_quant_net(network, param_dict)
     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network,
                                           bn_fold=True,

@@ -163,7 +163,7 @@ def train_on_gpu():
     # resume
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        _load_param_into_net(network, param_dict)
+        load_nonquant_param_into_quant_net(network, param_dict)
     # convert fusion network to quantization aware network
     network = quant.convert_quant_network(network,
model_zoo/official/cv/resnet50_quant/eval.py

@@ -20,12 +20,11 @@ import argparse
 from src.config import quant_set, config_quant, config_noquant
 from src.dataset import create_dataset
 from src.crossentropy import CrossEntropy
-from src.utils import _load_param_into_net
 from models.resnet_quant import resnet50_quant

 from mindspore import context
 from mindspore.train.model import Model
-from mindspore.train.serialization import load_checkpoint
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.train.quant import quant

 parser = argparse.ArgumentParser(description='Image classification')

@@ -66,7 +65,9 @@ if __name__ == '__main__':
     # load checkpoint
     if args_opt.checkpoint_path:
         param_dict = load_checkpoint(args_opt.checkpoint_path)
-        _load_param_into_net(net, param_dict)
+        not_load_param = load_param_into_net(net, param_dict)
+        if not_load_param:
+            raise ValueError("Load param into net fail!")
     net.set_train(False)

     # define model
model_zoo/official/cv/resnet50_quant/src/utils.py (deleted, 100644 → 0)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utils script"""

def _load_param_into_net(model, params_dict):
    """
    load fp32 model parameters to quantization model.

    Args:
        model: quantization model
        params_dict: f32 param

    Returns:
        None
    """
    iterable_dict = {
        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
        'moving_variance': iter([item for item in params_dict.items() if item[0].endswith('moving_variance')]),
        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
    }
    for name, param in model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict.keys():
            continue
        value_param = next(iterable_dict[key_name], None)
        if value_param is not None:
            param.set_parameter_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')
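
Note one behavioral difference against the shared helper added in mindspore/train/quant/quant_utils.py above: this deleted ResNet-50 version silently skipped (continue) parameter names it could not classify, whereas the shared load_nonquant_param_into_quant_net raises a ValueError for them, matching the MobileNetV2 helper removed in this same commit.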
model_zoo/official/cv/resnet50_quant/train.py

@@ -26,6 +26,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint
 from mindspore.train.quant import quant
+from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
 from mindspore.communication.management import init
 import mindspore.nn as nn
 import mindspore.common.initializer as weight_init

@@ -35,7 +36,6 @@ from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.config import config_quant
 from src.crossentropy import CrossEntropy
-from src.utils import _load_param_into_net

 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')

@@ -85,7 +85,7 @@ if __name__ == '__main__':
     # weight init and load checkpoint file
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        _load_param_into_net(net, param_dict)
+        load_nonquant_param_into_quant_net(net, param_dict)
         epoch_size = config.epoch_size - config.pretrained_epoch_size
     else:
         for _, cell in net.cells_and_names():