magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 56806d22
Authored on Aug 17, 2020 by chenfei
Parent: c7f461ac

standardization of mobilenetv2 and resnet50 quant networks

6 changed files with 57 additions and 100 deletions (+57 -100)
model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh   +1   -3
model_zoo/official/cv/mobilenetv2_quant/src/config.py                +1   -41
model_zoo/official/cv/mobilenetv2_quant/src/utils.py                 +32  -0
model_zoo/official/cv/mobilenetv2_quant/train.py                     +17  -18
model_zoo/official/cv/resnet50_quant/src/config.py                   +1   -28
model_zoo/official/cv/resnet50_quant/train.py                        +5   -10
model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh

@@ -43,7 +43,6 @@ run_ascend()
         --training_script=${BASEPATH}/../train.py \
         --dataset_path=$5 \
         --pre_trained=$6 \
-        --quantization_aware=True \
         --device_target=$1 &> train.log &  # dataset train folder
 }

@@ -75,8 +74,7 @@ run_gpu()
     python ${BASEPATH}/../train.py \
         --dataset_path=$4 \
         --device_target=$1 \
-        --pre_trained=$5 \
-        --quantization_aware=True &> ../train.log &  # dataset train folder
+        --pre_trained=$5 &> ../train.log &  # dataset train folder
 }

 if [ $# -gt 6 ] || [ $# -lt 5 ]
model_zoo/official/cv/mobilenetv2_quant/src/config.py

@@ -16,34 +16,12 @@
 network config setting, will be used in train.py and eval.py
 """
 from easydict import EasyDict as ed

-config_ascend = ed({
-    "num_classes": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "batch_size": 256,
-    "data_load_mode": "mindrecord",
-    "epoch_size": 200,
-    "start_epoch": 0,
-    "warmup_epochs": 4,
-    "lr": 0.4,
-    "momentum": 0.9,
-    "weight_decay": 4e-5,
-    "label_smooth": 0.1,
-    "loss_scale": 1024,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 300,
-    "save_checkpoint_path": "./checkpoint",
-    "quantization_aware": False,
-})
 config_ascend_quant = ed({
     "num_classes": 1000,
     "image_height": 224,
     "image_width": 224,
     "batch_size": 192,
-    "data_load_mode": "mindrecord",
+    "data_load_mode": "mindata",
     "epoch_size": 60,
     "start_epoch": 200,
     "warmup_epochs": 1,

@@ -59,24 +37,6 @@ config_ascend_quant = ed({
     "quantization_aware": True,
 })
-config_gpu = ed({
-    "num_classes": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "batch_size": 150,
-    "epoch_size": 200,
-    "warmup_epochs": 4,
-    "lr": 0.8,
-    "momentum": 0.9,
-    "weight_decay": 4e-5,
-    "label_smooth": 0.1,
-    "loss_scale": 1024,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 300,
-    "save_checkpoint_path": "./checkpoint",
-})
 config_gpu_quant = ed({
     "num_classes": 1000,
     "image_height": 224,
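For reference, the remaining configs are plain EasyDict objects, so train.py reads each entry as an attribute. A minimal sketch, not part of the commit, with values copied from config_ascend_quant above and easydict assumed to be installed:

    # Minimal sketch: EasyDict exposes the keys above as attributes, which is how train.py uses them.
    from easydict import EasyDict as ed

    config_ascend_quant = ed({
        "batch_size": 192,            # values taken from the hunk above
        "data_load_mode": "mindata",
        "epoch_size": 60,
        "start_epoch": 200,
    })

    config = config_ascend_quant      # selected unconditionally after this commit
    steps_per_epoch = 10              # placeholder for dataset.get_dataset_size()
    global_step = config.start_epoch * steps_per_epoch   # attribute access, as in get_lr(global_step=..., ...)
    print(config.batch_size, config.data_load_mode, global_step)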
model_zoo/official/cv/mobilenetv2_quant/src/utils.py

@@ -26,6 +26,38 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype


+def _load_param_into_net(model, params_dict):
+    """
+    load fp32 model parameters to quantization model.
+
+    Args:
+        model: quantization model
+        params_dict: f32 param
+
+    Returns:
+        None
+    """
+    iterable_dict = {
+        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
+        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
+        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
+        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
+        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
+        'moving_variance': iter(
+            [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
+        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
+        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
+    }
+    for name, param in model.parameters_and_names():
+        key_name = name.split(".")[-1]
+        if key_name not in iterable_dict.keys():
+            raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
+        value_param = next(iterable_dict[key_name], None)
+        if value_param is not None:
+            param.set_parameter_data(value_param[1].data)
+            print(f'init model param {name} with checkpoint param {value_param[0]}')
+
+
 class Monitor(Callback):
     """
     Monitor loss and time.
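The helper added above pairs checkpoint entries with network parameters purely by name suffix (weight, bias, gamma, beta, moving_mean, moving_variance, minq, maxq), consuming each suffix group in checkpoint order and raising ValueError for any parameter whose suffix it does not recognize. A minimal usage sketch, assuming the model_zoo package layout and a MindSpore install; the constructor argument and checkpoint path are placeholders, mirroring how train.py calls it after this commit:

    # Usage sketch (assumptions: MindSpore installed, fp32 checkpoint available on disk).
    from mindspore.train.serialization import load_checkpoint

    from src.mobilenetV2 import mobilenetV2          # fusion network, as imported in train.py
    from src.utils import _load_param_into_net

    network = mobilenetV2(num_classes=1000)          # hypothetical constructor argument
    param_dict = load_checkpoint("/path/to/mobilenetv2_fp32.ckpt")   # placeholder path
    # Copies fp32 values into the quant-ready network, pairing parameters by suffix.
    _load_param_into_net(network, param_dict)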
model_zoo/official/cv/mobilenetv2_quant/train.py

@@ -25,7 +25,7 @@ from mindspore import nn
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.serialization import load_checkpoint
 from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.quant import quant
 import mindspore.dataset.engine as de

@@ -33,8 +33,9 @@ import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.utils import Monitor, CrossEntropyWithLabelSmooth
-from src.config import config_ascend_quant, config_ascend, config_gpu_quant, config_gpu
+from src.config import config_ascend_quant, config_gpu_quant
 from src.mobilenetV2 import mobilenetV2
+from src.utils import _load_param_into_net

 random.seed(1)
 np.random.seed(1)

@@ -44,7 +45,6 @@ parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
 parser.add_argument('--device_target', type=str, default=None, help='Run device target')
-parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
 args_opt = parser.parse_args()

 if args_opt.device_target == "Ascend":

@@ -69,7 +69,7 @@ else:
 def train_on_ascend():
-    config = config_ascend_quant if args_opt.quantization_aware else config_ascend
+    config = config_ascend_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))
     print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))

@@ -101,14 +101,12 @@ def train_on_ascend():
     # load pre trained ckpt
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
+        _load_param_into_net(network, param_dict)
     # convert fusion network to quantization aware network
-    if config.quantization_aware:
-        network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
     # get learning rate
     lr = Tensor(get_lr(global_step=config.start_epoch * step_size,

@@ -141,7 +139,7 @@ def train_on_ascend():
 def train_on_gpu():
-    config = config_gpu_quant if args_opt.quantization_aware else config_gpu
+    config = config_gpu_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))

@@ -165,14 +163,15 @@ def train_on_gpu():
     # resume
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
+        _load_param_into_net(network, param_dict)
     # convert fusion network to quantization aware network
-    if config.quantization_aware:
-        network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, True])
+    network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, True],
+                                          freeze_bn=1000000, quant_delay=step_size * 2)
     # get learning rate
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
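On GPU the standardized conversion also passes freeze_bn and quant_delay, both expressed in optimizer steps; since step_size comes from the training dataset size, quant_delay=step_size * 2 defers fake quantization for roughly the first two epochs. A short sketch of just that call, with the keyword arguments copied from the hunk above; the wrapper function and the comments on freeze_bn and per_channel ordering are assumptions, not part of the commit:

    # Sketch only: train.py performs this call inline inside train_on_gpu().
    from mindspore.train.quant import quant

    def convert_for_gpu(network, step_size):
        """Hypothetical wrapper around the conversion used on GPU after this commit."""
        return quant.convert_quant_network(network,
                                           bn_fold=True,               # fold BatchNorm into the preceding layers
                                           per_channel=[True, False],  # assumed order: weights per-channel, activations per-tensor
                                           symmetric=[True, True],     # symmetric quantization on GPU
                                           freeze_bn=1000000,          # step after which BN statistics freeze (assumption)
                                           quant_delay=step_size * 2)  # start fake-quant after ~2 epochs of steps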
model_zoo/official/cv/resnet50_quant/src/config.py

@@ -16,33 +16,6 @@
 network config setting, will be used in train.py and eval.py
 """
 from easydict import EasyDict as ed

-quant_set = ed({
-    "quantization_aware": True,
-})
-config_noquant = ed({
-    "class_num": 1001,
-    "batch_size": 32,
-    "loss_scale": 1024,
-    "momentum": 0.9,
-    "weight_decay": 1e-4,
-    "epoch_size": 90,
-    "pretrained_epoch_size": 1,
-    "buffer_size": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "data_load_mode": "mindrecord",
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 50,
-    "save_checkpoint_path": "./",
-    "warmup_epochs": 0,
-    "lr_decay_mode": "cosine",
-    "use_label_smooth": True,
-    "label_smooth_factor": 0.1,
-    "lr_init": 0,
-    "lr_max": 0.1,
-})
 config_quant = ed({
     "class_num": 1001,
     "batch_size": 32,

@@ -54,7 +27,7 @@ config_quant = ed({
     "buffer_size": 1000,
     "image_height": 224,
     "image_width": 224,
-    "data_load_mode": "mindrecord",
+    "data_load_mode": "mindata",
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,
     "keep_checkpoint_max": 50,
model_zoo/official/cv/resnet50_quant/train.py

@@ -33,7 +33,7 @@ import mindspore.common.initializer as weight_init
 from models.resnet_quant import resnet50_quant
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
-from src.config import quant_set, config_quant, config_noquant
+from src.config import config_quant
 from src.crossentropy import CrossEntropy
 from src.utils import _load_param_into_net

@@ -44,7 +44,7 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
 parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pertained checkpoint path')
 args_opt = parser.parse_args()
-config = config_quant if quant_set.quantization_aware else config_noquant
+config = config_quant

 if args_opt.device_target == "Ascend":
     device_id = int(os.getenv('DEVICE_ID'))

@@ -110,9 +110,8 @@ if __name__ == '__main__':
                              target=args_opt.device_target)
     step_size = dataset.get_dataset_size()

-    if quant_set.quantization_aware:
-        # convert fusion network to quantization aware network
-        net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    # convert fusion network to quantization aware network
+    net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])

     # get learning rate
     lr = get_lr(lr_init=config.lr_init,

@@ -131,11 +130,7 @@ if __name__ == '__main__':
                    config.weight_decay, config.loss_scale)
     # define model
-    if quant_set.quantization_aware:
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
-    else:
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2")
+    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})

     print("============== Starting Training ==============")
     time_callback = TimeMonitor(data_size=step_size)
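Taken together, the two training scripts now share one unconditional pattern: load the fp32 checkpoint through _load_param_into_net, convert the fusion network with quant.convert_quant_network, and build a single Model without the amp_level branch. A condensed sketch of that flow for ResNet-50; illustrative only, where the wrapper function, checkpoint path, and constructor argument are assumptions, and loss, opt, loss_scale, dataset, and callbacks stand in for objects defined elsewhere in train.py:

    # Condensed sketch of resnet50_quant/train.py after this commit, not the full script.
    from mindspore.train.model import Model
    from mindspore.train.quant import quant
    from mindspore.train.serialization import load_checkpoint

    from models.resnet_quant import resnet50_quant
    from src.config import config_quant as config
    from src.utils import _load_param_into_net

    def build_and_train(dataset, loss, opt, loss_scale, callbacks):
        """Hypothetical wrapper; its arguments are created elsewhere in train.py."""
        net = resnet50_quant(class_num=config.class_num)               # constructor argument assumed
        param_dict = load_checkpoint("/path/to/resnet50_fp32.ckpt")    # placeholder path
        _load_param_into_net(net, param_dict)
        # convert fusion network to quantization aware network (now unconditional)
        net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
        model.train(config.epoch_size, dataset, callbacks=callbacks)   # epoch_size assumed to be defined in config_quant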