Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
e406740a
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e406740a
编写于
12月 27, 2021
作者:
I
itminner
提交者:
GitHub
12月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
quant aware with infer model (#947)
quant aware with infer model
上级
bf123166
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
1159 addition
and
0 deletion
+1159
-0
demo/quant/quant_aware_with_infermodel/README.md
demo/quant/quant_aware_with_infermodel/README.md
+113
-0
demo/quant/quant_aware_with_infermodel/export_quantmodel.py
demo/quant/quant_aware_with_infermodel/export_quantmodel.py
+91
-0
demo/quant/quant_aware_with_infermodel/quant_aware_with_infermodel.py
...uant_aware_with_infermodel/quant_aware_with_infermodel.py
+148
-0
demo/quant/quant_post_hpo/README.md
demo/quant/quant_post_hpo/README.md
+72
-0
paddleslim/dist/single_distiller.py
paddleslim/dist/single_distiller.py
+11
-0
paddleslim/quant/__init__.py
paddleslim/quant/__init__.py
+1
-0
paddleslim/quant/quant_aware_with_infermodel.py
paddleslim/quant/quant_aware_with_infermodel.py
+474
-0
paddleslim/quant/quanter.py
paddleslim/quant/quanter.py
+28
-0
tests/test_quant_aware_with_infermodel.py
tests/test_quant_aware_with_infermodel.py
+221
-0
未找到文件。
demo/quant/quant_aware_with_infermodel/README.md
0 → 100644
浏览文件 @
e406740a
# 使用预测模型进行量化训练示例
预测模型获取
动态图使用paddle.jit.save保存;
静态图使用paddle.static.save_inference_model保存。
本示例将介绍如何使用预测模型进行蒸馏量化训练,
首先使用接口
``paddleslim.quant.quant_aware_with_infermodel``
训练量化模型,
训练完成后,使用接口
``paddleslim.quant.export_quant_infermodel``
将训好的量化模型导出为预测模型。
## 分类模型量化训练流程
###1. 准备数据
在
``demo``
文件夹下创建
``data``
文件夹,将
``ImageNet``
数据集解压在
``data``
文件夹下,解压后
``data/ILSVRC2012``
文件夹下应包含以下文件:
-
``'train'``
文件夹,训练图片
-
``'train_list.txt'``
文件
-
``'val'``
文件夹,验证图片
-
``'val_list.txt'``
文件
### 2. 准备需要量化的模型
飞桨图像识别套件PaddleClas是飞桨为工业界和学术界所准备的一个图像识别任务的工具集,本示例使用该套件产出imagenet分类模型。
####① 下载MobileNetV2预训练模型
预训练模型库地址
``https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md``
MobileNetV2预训练模型地址
``https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams``
在PaddleClas代码库根目录创建pretrained文件夹,MobileNetV2预训练参数保存在该文件夹中。
#### ② 导出预测模型
PaddleClas代码库根目录执行如下命令,导出预测模型
```
python tools/export_model.py \
-c ppcls/configs/ImageNet/MobileNetV2/MobileNetV2.yaml \
-o Global.pretrained_model=pretrained/MobileNetV2_pretrained \
-o Global.save_inference_dir=infermodel_mobilenetv2 \
```
#### ③ 测试模型精度
使用
[
eval.py
](
../quant_post/eval.py
)
脚本得到模型的分类精度:
```
python ../quant_post/eval.py --model_path infermodel_mobilenetv2 --model_name inference.pdmodel --params_name inference.pdiparams
```
精度输出为:
```
top1_acc/top5_acc= [0.71918 0.90568]
```
### 3. 进行量化蒸馏训练
蒸馏量化训练示例脚本为
[
quant_aware_with_infermodel.py
](
./quant_aware_with_infermodel.py
)
,使用接口
``paddleslim.quant.quant_aware_with_infermodel``
对模型进行量化训练。运行命令为:
```
python quant_aware_with_infermodel.py \
--batch_size=2 \
--num_epoch=30 \
--save_iter_step=100 \
--learning_rate=0.0001 \
--weight_decay=0.00004 \
--use_pact=True \
--checkpoint_path="./inference_model/MobileNet_quantaware_ckpt/" \
--model_path="./infermodel_mobilenetv2/" \
--model_filename="inference.pdmodel" \
--params_filename="inference.pdiparams" \
--teacher_model_path="./infermodel_mobilenetv2/" \
--teacher_model_filename="inference.pdmodel" \
--teacher_params_filename="inference.pdiparams" \
--distill_node_name_list "teacher_conv2d_54.tmp_0" "conv2d_54.tmp_0" "teacher_conv2d_55.tmp_0" "conv2d_55.tmp_0" \
"teacher_conv2d_57.tmp_0" "conv2d_57.tmp_0" "teacher_elementwise_add_0" "elementwise_add_0" \
"teacher_conv2d_61.tmp_0" "conv2d_61.tmp_0" "teacher_elementwise_add_1" "elementwise_add_1" \
"teacher_elementwise_add_2" "elementwise_add_2" "teacher_conv2d_67.tmp_0" "conv2d_67.tmp_0" \
"teacher_elementwise_add_3" "elementwise_add_3" "teacher_elementwise_add_4" "elementwise_add_4" \
"teacher_elementwise_add_5" "elementwise_add_5" "teacher_conv2d_75.tmp_0" "conv2d_75.tmp_0" \
"teacher_elementwise_add_6" "elementwise_add_6" "teacher_elementwise_add_7" "elementwise_add_7" \
"teacher_conv2d_81.tmp_0" "conv2d_81.tmp_0" "teacher_elementwise_add_8" "elementwise_add_8" \
"teacher_elementwise_add_9" "elementwise_add_9" "teacher_conv2d_87.tmp_0" "conv2d_87.tmp_0" \
"teacher_linear_1.tmp_0" "linear_1.tmp_0"
```
-
``batch_size``
: 量化训练batch size。
-
``num_epoch``
: 量化训练epoch数。
-
``save_iter_step``
: 每隔save_iter_step保存一次checkpoint。
-
``learning_rate``
: 量化训练学习率,推荐使用float模型训练最小一级学习率。
-
``weight_decay``
: 推荐使用float模型训练weight decay设置。
-
``use_pact``
: 是否使用pact量化算法, 推荐使用。
-
``checkpoint_path``
: 量化训练模型checkpoint保存路径。
-
``model_path``
: 需要量化的预测模型所在路径。
-
``model_filename``
: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的模型文件名称,如果参数文件保存在多个文件中,则不需要设置。
-
``params_filename``
: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的参数文件名称,如果参数文件保存在多个文件中,则不需要设置。
-
``teacher_model_path``
: teacher模型所在路径, 可以和量化模型是同一个,即自蒸馏。
-
``teacher_model_filename``
: teacher模型model文件名字。
-
``teacher_params_filename``
: teacher模型参数文件名字。
-
``distill_node_name_list``
: 蒸馏节点名字列表,每两个节点组成一对,分别属于teacher模型和量化模型。
运行以上命令后,可在
``${checkpoint_path}``
下看到量化后模型的checkpoint。
### 4. 量化模型导出
量化模型checkpoint导出为预测模型。
```
python export_quantmodel.py \
--use_gpu=True \
--checkpoint_path="./MobileNetV2_checkpoints/epoch_0_iter_2000" \
--infermodel_save_path="./quant_infermodel_mobilenetv2/" \
```
###5. 测试精度
使用
[
eval.py
](
../quant_post/eval.py
)
脚本对量化后的模型进行精度测试:
```
python ../quant_post/eval.py --model_path ./quant_infermodel_mobilenetv2/ --model_name model --params_name params
```
精度输出为:
```
top1_acc/top5_acc= [0.71764 0.90418]
```
demo/quant/quant_aware_with_infermodel/export_quantmodel.py
0 → 100755
浏览文件 @
e406740a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
math
import
time
import
numpy
as
np
import
paddle
import
logging
import
argparse
import
functools
sys
.
path
[
0
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
)
sys
.
path
[
1
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
,
os
.
path
.
pardir
)
from
paddleslim.common
import
get_logger
from
paddleslim.quant
import
export_quant_infermodel
from
utility
import
add_arguments
,
print_arguments
import
imagenet_reader
as
reader
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
add_arg
(
'use_gpu'
,
bool
,
True
,
"Whether to use GPU or not."
)
add_arg
(
'batch_size'
,
int
,
4
,
"train batch size."
)
add_arg
(
'num_epoch'
,
int
,
1
,
"train epoch num."
)
add_arg
(
'save_iter_step'
,
int
,
1
,
"save train checkpoint every save_iter_step iter num."
)
add_arg
(
'learning_rate'
,
float
,
0.0001
,
"learning rate."
)
add_arg
(
'weight_decay'
,
float
,
0.00004
,
"weight decay."
)
add_arg
(
'use_pact'
,
bool
,
True
,
"whether use pact quantization."
)
add_arg
(
'checkpoint_path'
,
str
,
None
,
"model dir to save quanted model checkpoints"
)
add_arg
(
'model_path_prefix'
,
str
,
None
,
"storage directory of model + model name (excluding suffix)"
)
add_arg
(
'teacher_model_path_prefix'
,
str
,
None
,
"storage directory of teacher model + teacher model name (excluding suffix)"
)
add_arg
(
'distill_node_name_list'
,
str
,
None
,
"distill node name list"
,
nargs
=
"+"
)
add_arg
(
'checkpoint_filename'
,
str
,
None
,
"checkpoint filename to export inference model"
)
add_arg
(
'export_inference_model_path_prefix'
,
str
,
None
,
"inference model export path prefix"
)
def
export
(
args
):
place
=
paddle
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
quant_config
=
{
'weight_quantize_type'
:
'channel_wise_abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'not_quant_pattern'
:
[
'skip_quant'
],
'quantize_op_types'
:
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
}
train_config
=
{
"num_epoch"
:
args
.
num_epoch
,
# training epoch num
"max_iter"
:
-
1
,
"save_iter_step"
:
args
.
save_iter_step
,
"learning_rate"
:
args
.
learning_rate
,
"weight_decay"
:
args
.
weight_decay
,
"use_pact"
:
args
.
use_pact
,
"quant_model_ckpt_path"
:
args
.
checkpoint_path
,
"teacher_model_path_prefix"
:
args
.
teacher_model_path_prefix
,
"model_path_prefix"
:
args
.
model_path_prefix
,
"distill_node_pair"
:
args
.
distill_node_name_list
}
export_quant_infermodel
(
exe
,
place
,
scope
=
None
,
quant_config
=
quant_config
,
train_config
=
train_config
,
checkpoint_path
=
os
.
path
.
join
(
args
.
checkpoint_path
,
args
.
checkpoint_filename
),
export_inference_model_path_prefix
=
args
.
export_inference_model_path_prefix
)
def
main
():
args
=
parser
.
parse_args
()
args
.
use_pact
=
bool
(
args
.
use_pact
)
print_arguments
(
args
)
export
(
args
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
main
()
demo/quant/quant_aware_with_infermodel/quant_aware_with_infermodel.py
0 → 100755
浏览文件 @
e406740a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
math
import
time
import
numpy
as
np
import
paddle
import
logging
import
argparse
import
functools
sys
.
path
[
0
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
)
sys
.
path
[
1
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
,
os
.
path
.
pardir
)
from
paddleslim.common
import
get_logger
from
paddleslim.quant
import
quant_aware_with_infermodel
from
utility
import
add_arguments
,
print_arguments
import
imagenet_reader
as
reader
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
add_arg
=
functools
.
partial
(
add_arguments
,
argparser
=
parser
)
# yapf: disable
add_arg
(
'use_gpu'
,
bool
,
True
,
"whether to use GPU or not."
)
add_arg
(
'batch_size'
,
int
,
1
,
"train batch size."
)
add_arg
(
'num_epoch'
,
int
,
1
,
"train epoch num."
)
add_arg
(
'save_iter_step'
,
int
,
1
,
"save train checkpoint every save_iter_step iter num."
)
add_arg
(
'learning_rate'
,
float
,
0.0001
,
"learning rate."
)
add_arg
(
'weight_decay'
,
float
,
0.00004
,
"weight decay."
)
add_arg
(
'use_pact'
,
bool
,
True
,
"whether use pact quantization."
)
add_arg
(
'checkpoint_path'
,
str
,
None
,
"model dir to save quanted model checkpoints"
)
add_arg
(
'model_path_prefix'
,
str
,
None
,
"storage directory of model + model name (excluding suffix)"
)
add_arg
(
'teacher_model_path_prefix'
,
str
,
None
,
"storage directory of teacher model + teacher model name (excluding suffix)"
)
add_arg
(
'distill_node_name_list'
,
str
,
None
,
"distill node name list"
,
nargs
=
"+"
)
DATA_DIR
=
"../../data/ILSVRC2012/"
def
eval
(
exe
,
place
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
val_reader
=
paddle
.
batch
(
reader
.
val
(),
batch_size
=
1
)
image
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
None
,
3
,
224
,
224
],
dtype
=
'float32'
)
label
=
paddle
.
static
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
results
=
[]
for
batch_id
,
data
in
enumerate
(
val_reader
()):
# top1_acc, top5_acc
if
len
(
test_feed_names
)
==
1
:
# eval "infer model", which input is image, output is classification probability
image
=
data
[
0
][
0
].
reshape
((
1
,
3
,
224
,
224
))
label
=
[[
d
[
1
]]
for
d
in
data
]
pred
=
exe
.
run
(
compiled_test_program
,
feed
=
{
test_feed_names
[
0
]:
image
},
fetch_list
=
test_fetch_list
)
pred
=
np
.
array
(
pred
[
0
])
label
=
np
.
array
(
label
)
sort_array
=
pred
.
argsort
(
axis
=
1
)
top_1_pred
=
sort_array
[:,
-
1
:][:,
::
-
1
]
top_1
=
np
.
mean
(
label
==
top_1_pred
)
top_5_pred
=
sort_array
[:,
-
5
:][:,
::
-
1
]
acc_num
=
0
for
i
in
range
(
len
(
label
)):
if
label
[
i
][
0
]
in
top_5_pred
[
i
]:
acc_num
+=
1
top_5
=
float
(
acc_num
)
/
len
(
label
)
results
.
append
([
top_1
,
top_5
])
else
:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image
=
data
[
0
][
0
].
reshape
((
1
,
3
,
224
,
224
))
label
=
[[
d
[
1
]]
for
d
in
data
]
result
=
exe
.
run
(
compiled_test_program
,
feed
=
{
test_feed_names
[
0
]:
image
,
test_feed_names
[
1
]:
label
},
fetch_list
=
test_fetch_list
)
result
=
[
np
.
mean
(
r
)
for
r
in
result
]
results
.
append
(
result
)
result
=
np
.
mean
(
np
.
array
(
results
),
axis
=
0
)
return
result
def
quantize
(
args
):
place
=
paddle
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
paddle
.
CPUPlace
()
#place = paddle.CPUPlace()
exe
=
paddle
.
static
.
Executor
(
place
)
quant_config
=
{
'weight_quantize_type'
:
'channel_wise_abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'not_quant_pattern'
:
[
'skip_quant'
],
'quantize_op_types'
:
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
}
train_config
=
{
"num_epoch"
:
args
.
num_epoch
,
# training epoch num
"max_iter"
:
-
1
,
"save_iter_step"
:
args
.
save_iter_step
,
"learning_rate"
:
args
.
learning_rate
,
"weight_decay"
:
args
.
weight_decay
,
"use_pact"
:
args
.
use_pact
,
"quant_model_ckpt_path"
:
args
.
checkpoint_path
,
"teacher_model_path_prefix"
:
args
.
teacher_model_path_prefix
,
"model_path_prefix"
:
args
.
model_path_prefix
,
"distill_node_pair"
:
args
.
distill_node_name_list
}
def
test_callback
(
compiled_test_program
,
feed_names
,
fetch_list
,
checkpoint_name
):
ret
=
eval
(
exe
,
place
,
compiled_test_program
,
feed_names
,
fetch_list
)
print
(
"{0} top1_acc/top5_acc= {1}"
.
format
(
checkpoint_name
,
ret
))
train_reader
=
paddle
.
batch
(
reader
.
train
(),
batch_size
=
args
.
batch_size
)
def
train_reader_wrapper
():
def
gen
():
for
i
,
data
in
enumerate
(
train_reader
()):
imgs
=
np
.
float32
([
item
[
0
]
for
item
in
data
])
yield
{
"x"
:
imgs
}
return
gen
quant_aware_with_infermodel
(
exe
,
place
,
scope
=
None
,
train_reader
=
train_reader_wrapper
(),
quant_config
=
quant_config
,
train_config
=
train_config
,
test_callback
=
test_callback
)
def
main
():
args
=
parser
.
parse_args
()
args
.
use_pact
=
bool
(
args
.
use_pact
)
print
(
"args.use_pact"
,
args
.
use_pact
)
print_arguments
(
args
)
quantize
(
args
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
main
()
demo/quant/quant_post_hpo/README.md
0 → 100755
浏览文件 @
e406740a
# 静态离线量化超参搜索示例
本示例将介绍如何使用离线量化超参搜索接口
``paddleslim.quant.quant_post_hpo``
来对训练好的分类模型进行离线量化超参搜索。
## 分类模型的离线量化超参搜索流程
### 准备数据
在
``demo``
文件夹下创建
``data``
文件夹,将
``ImageNet``
数据集解压在
``data``
文件夹下,解压后
``data/ILSVRC2012``
文件夹下应包含以下文件:
-
``'train'``
文件夹,训练图片
-
``'train_list.txt'``
文件
-
``'val'``
文件夹,验证图片
-
``'val_list.txt'``
文件
### 准备需要量化的模型
离线量化接口只支持加载通过
``paddle.static.save_inference_model``
接口保存的模型。因此如果您的模型是通过其他接口保存的,需要先将模型进行转化。本示例将以分类模型为例进行说明。
首先在
[
imagenet分类模型
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD
)
中下载训练好的
``mobilenetv1``
模型。
在当前文件夹下创建
``'pretrain'``
文件夹,将
``mobilenetv1``
模型在该文件夹下解压,解压后的目录为
``pretrain/MobileNetV1_pretrained``
### 导出模型
通过运行以下命令可将模型转化为离线量化接口可用的模型:
```
python ../quant_post/export_model.py --model "MobileNet" --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
```
转化之后的模型存储在
``inference_model/MobileNet/``
文件夹下,可看到该文件夹下有
``'model'``
,
``'weights'``
两个文件。
### 静态离线量化
接下来对导出的模型文件进行静态离线量化,静态离线量化的脚本为
[
quant_post_hpo.py
](
./quant_post_hpo.py
)
,脚本中使用接口
``paddleslim.quant.quant_post_hpo``
对模型进行离线量化。运行命令为:
```
python quant_post_hpo.py \
--use_gpu=True \
--model_path="./inference_model/MobileNet/" \
--save_path="./inference_model/MobileNet_quant/" \
--model_filename="model" \
--params_filename="weights" \
--max_model_quant_count=26
```
-
``model_path``
: 需要量化的模型所在路径
-
``save_path``
: 量化后的模型保存的路径
-
``model_filename``
: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的模型文件名称,如果参数文件保存在多个文件中,则不需要设置。
-
``params_filename``
: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的参数文件名称,如果参数文件保存在多个文件中,则不需要设置。
-
``max_model_quant_count``
: 最大离线量化搜索次数,次数越多产出高精度量化模型概率越大,耗时也会相应增加。建议值:大于20小于30。
运行以上命令后,可在
``${save_path}``
下看到量化后的模型文件和参数文件。
### 测试精度
使用
[
eval.py
](
../quant_post/eval.py
)
脚本对量化前后的模型进行测试,得到模型的分类精度进行对比。
首先测试量化前的模型的精度,运行以下命令:
```
python ../quant_post/eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
```
精度输出为:
```
top1_acc/top5_acc= [0.70898 0.89534]
```
使用以下命令测试离线量化后的模型的精度:
```
python ../quant_post/eval.py --model_path ./inference_model/MobileNet_quant/ --model_name __model__ --params_name __params__
```
精度输出为
```
top1_acc/top5_acc= [0.70653 0.89369]
```
paddleslim/dist/single_distiller.py
浏览文件 @
e406740a
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
from
paddleslim.core
import
GraphWrapper
def
merge
(
teacher_program
,
def
merge
(
teacher_program
,
...
@@ -94,6 +95,16 @@ def merge(teacher_program,
...
@@ -94,6 +95,16 @@ def merge(teacher_program,
student_program
.
global_block
().
append_op
(
student_program
.
global_block
().
append_op
(
type
=
op
.
type
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
type
=
op
.
type
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
student_graph
=
GraphWrapper
(
student_program
)
for
op
in
student_graph
.
ops
():
belongsto_teacher
=
False
for
inp
in
op
.
all_inputs
():
if
'teacher'
in
inp
.
name
():
belongsto_teacher
=
True
break
if
belongsto_teacher
:
op
.
_op
.
_set_attr
(
"skip_quant"
,
True
)
def
fsp_loss
(
teacher_var1_name
,
def
fsp_loss
(
teacher_var1_name
,
teacher_var2_name
,
teacher_var2_name
,
...
...
paddleslim/quant/__init__.py
浏览文件 @
e406740a
...
@@ -31,6 +31,7 @@ try:
...
@@ -31,6 +31,7 @@ try:
],
"training-aware and post-training quant is not supported in 2.0 alpha version paddle"
],
"training-aware and post-training quant is not supported in 2.0 alpha version paddle"
from
.quanter
import
quant_aware
,
convert
,
quant_post_static
,
quant_post_dynamic
from
.quanter
import
quant_aware
,
convert
,
quant_post_static
,
quant_post_dynamic
from
.quanter
import
quant_post
,
quant_post_only_weight
from
.quanter
import
quant_post
,
quant_post_only_weight
from
.quant_aware_with_infermodel
import
quant_aware_with_infermodel
,
export_quant_infermodel
from
.quant_post_hpo
import
quant_post_hpo
from
.quant_post_hpo
import
quant_post_hpo
except
Exception
as
e
:
except
Exception
as
e
:
_logger
.
warning
(
e
)
_logger
.
warning
(
e
)
...
...
paddleslim/quant/quant_aware_with_infermodel.py
0 → 100644
浏览文件 @
e406740a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""train aware quant with infermodel"""
import
copy
import
os
import
argparse
import
json
import
six
from
collections
import
namedtuple
import
time
import
shutil
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
unique_name
from
paddle.fluid
import
core
from
paddle.fluid.framework
import
Parameter
from
paddleslim.dist
import
merge
,
l2_loss
,
soft_label_loss
,
fsp_loss
from
paddleslim.core
import
GraphWrapper
from
paddleslim.quant
import
quant_aware
,
convert
from
.quanter
import
_quant_config_default
,
_parse_configs
,
pact
,
get_pact_optimizer
import
logging
logging
.
getLogger
().
setLevel
(
logging
.
INFO
)
from
..common
import
get_logger
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
############################################################################################################
# quantization training configs
############################################################################################################
_train_config_default
=
{
# configs of training aware quantization with infermodel
"num_epoch"
:
1000
,
# training epoch num
"max_iter"
:
-
1
,
# max training iteration num
"save_iter_step"
:
1000
,
# save quant model checkpoint every save_iter_step iteration
"learning_rate"
:
0.0001
,
# learning rate
"weight_decay"
:
0.0001
,
# weight decay
"use_pact"
:
False
,
# use pact quantization or not
# quant model checkpoints save path
"quant_model_ckpt_path"
:
"./quant_model_checkpoints/"
,
# storage directory of teacher model + teacher model name (excluding suffix)
"teacher_model_path_prefix"
:
None
,
# storage directory of model + model name (excluding suffix)
"model_path_prefix"
:
None
,
""" distillation node configuration:
the name of the distillation supervision nodes is configured as a list,
and the teacher node and student node are arranged in pairs.
for example, ["teacher_fc_0.tmp_0", "fc_0.tmp_0", "teacher_batch_norm_24.tmp_4", "batch_norm_24.tmp_4"]
"""
"distill_node_pair"
:
None
}
def
_parse_train_configs
(
train_config
):
"""
check if user's train configs are valid.
Args:
train_config(dict): user's train config.
Return:
configs(dict): final configs will be used.
"""
configs
=
copy
.
deepcopy
(
_train_config_default
)
configs
.
update
(
train_config
)
assert
isinstance
(
configs
[
'num_epoch'
],
int
),
\
"'num_epoch' must be int value"
assert
isinstance
(
configs
[
'max_iter'
],
int
),
\
"'max_iter' must be int value"
assert
isinstance
(
configs
[
'save_iter_step'
],
int
),
\
"'save_iter_step' must be int value"
assert
isinstance
(
configs
[
'learning_rate'
],
float
),
\
"'learning_rate' must be float"
assert
isinstance
(
configs
[
'weight_decay'
],
float
),
\
"'weight_decay' must be float"
assert
isinstance
(
configs
[
'use_pact'
],
bool
),
\
"'use_pact' must be bool"
assert
isinstance
(
configs
[
'quant_model_ckpt_path'
],
str
),
\
"'quant_model_ckpt_path' must be str"
assert
isinstance
(
configs
[
'teacher_model_path_prefix'
],
str
),
\
"'teacher_model_path_prefix' must both be string"
assert
isinstance
(
configs
[
'model_path_prefix'
],
str
),
\
"'model_path_prefix' must both be str"
assert
isinstance
(
configs
[
'distill_node_pair'
],
list
),
\
"'distill_node_pair' must both be list"
assert
len
(
configs
[
'distill_node_pair'
])
>
0
,
\
"'distill_node_pair' not configured with distillation nodes"
assert
len
(
configs
[
'distill_node_pair'
])
%
2
==
0
,
\
"'distill_node_pair' distillation nodes need to be configured in pairs"
return
train_config
def
_create_optimizer
(
train_config
):
"""create optimizer"""
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
train_config
[
"learning_rate"
],
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
train_config
[
"weight_decay"
]))
return
optimizer
def
_remove_fetch_node
(
program
):
"""remove fetch node in program"""
for
block
in
program
.
blocks
:
removed
=
0
ops
=
list
(
block
.
ops
)
for
op
in
ops
:
if
op
.
type
==
"fetch"
:
idx
=
ops
.
index
(
op
)
block
.
_remove_op
(
idx
-
removed
)
removed
+=
1
def
_recover_reserve_space_with_bn
(
program
):
"""Add the outputs which is only used for training and not saved in
inference program."""
for
block_idx
in
six
.
moves
.
range
(
program
.
num_blocks
):
block
=
program
.
block
(
block_idx
)
for
op
in
block
.
ops
:
if
op
.
type
==
"batch_norm"
:
if
"ReserveSpace"
not
in
op
.
output_names
or
len
(
op
.
output
(
"ReserveSpace"
))
==
0
:
reserve_space
=
block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
[
"reserve_space"
,
'tmp'
])),
dtype
=
block
.
var
(
op
.
input
(
"X"
)[
0
]).
dtype
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
persistable
=
False
,
stop_gradient
=
True
)
op
.
desc
.
set_output
(
"ReserveSpace"
,
[
reserve_space
.
name
])
return
program
def
_recover_param_attr
(
program
):
"""recover parameters attribute.
Params in infermodel are stored in the form of variable, which can not be trained."""
all_weights
=
[
param
for
param
in
program
.
list_vars
()
\
if
param
.
persistable
is
True
and
param
.
name
!=
'feed'
and
param
.
name
!=
'fetch'
]
for
w
in
all_weights
:
new_w
=
Parameter
(
block
=
program
.
block
(
0
),
shape
=
w
.
shape
,
dtype
=
w
.
dtype
,
type
=
w
.
type
,
name
=
w
.
name
)
new_w
.
set_value
(
w
.
get_value
())
program
.
block
(
0
).
vars
[
w
.
name
]
=
new_w
return
program
def
_parse_distill_loss
(
train_config
):
"""parse distill loss config"""
assert
len
(
train_config
[
"distill_node_pair"
])
%
2
==
0
,
\
"distill_node_pair config wrong, the length needs to be an even number"
print
(
"train config.distill_node_pair: "
,
train_config
[
"distill_node_pair"
])
distill_loss
=
0
for
i
in
range
(
len
(
train_config
[
"distill_node_pair"
])
//
2
):
print
(
train_config
[
"distill_node_pair"
][
i
*
2
],
train_config
[
"distill_node_pair"
][
i
*
2
+
1
])
distill_loss
+=
l2_loss
(
train_config
[
"distill_node_pair"
][
i
*
2
],
train_config
[
"distill_node_pair"
][
i
*
2
+
1
])
return
distill_loss
DistillProgramInfo
=
namedtuple
(
"DistillProgramInfo"
,
\
"startup_program train_program train_feed_names train_fetch_list
\
optimizer test_program test_feed_names test_fetch_list"
)
def
build_distill_prog_with_infermodel
(
executor
,
place
,
train_config
):
"""build distill program with infermodel"""
[
train_program
,
feed_target_names
,
fetch_targets
]
=
paddle
.
static
.
load_inference_model
(
\
path_prefix
=
train_config
[
"model_path_prefix"
],
\
executor
=
executor
)
_remove_fetch_node
(
train_program
)
[
teacher_program
,
teacher_feed_target_names
,
teacher_fetch_targets
]
=
paddle
.
static
.
load_inference_model
(
\
path_prefix
=
train_config
[
"teacher_model_path_prefix"
],
\
executor
=
executor
)
_remove_fetch_node
(
teacher_program
)
test_program
=
train_program
.
clone
(
for_test
=
True
)
train_program
=
_recover_param_attr
(
train_program
)
train_program
=
_recover_reserve_space_with_bn
(
train_program
)
for
var
in
train_program
.
list_vars
():
var
.
stop_gradient
=
False
train_graph
=
GraphWrapper
(
train_program
)
for
op
in
train_graph
.
ops
():
op
.
_op
.
_set_attr
(
"is_test"
,
False
)
############################################################################
# distill
############################################################################
data_name_map
=
{}
assert
len
(
feed_target_names
)
==
len
(
teacher_feed_target_names
),
\
"the number of feed nodes in the teacher model is not equal to the student model"
for
i
,
name
in
enumerate
(
feed_target_names
):
data_name_map
[
teacher_feed_target_names
[
i
]]
=
name
merge
(
teacher_program
,
train_program
,
data_name_map
,
place
)
# all feed node should set stop_gradient is False, for using pact quant algo.
for
var
in
train_program
.
list_vars
():
if
var
.
name
in
data_name_map
.
values
()
or
var
.
name
in
data_name_map
.
keys
(
):
var
.
stop_gradient
=
False
train_fetch_list
=
[]
train_fetch_name_list
=
[]
startup_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
train_program
,
startup_program
):
with
fluid
.
unique_name
.
guard
(
'merge'
):
optimizer
=
_create_optimizer
(
train_config
)
distill_loss
=
_parse_distill_loss
(
train_config
)
loss
=
paddle
.
mean
(
distill_loss
)
loss
.
stop_gradient
=
False
p_g_list
=
paddle
.
static
.
append_backward
(
loss
=
loss
)
opts
=
optimizer
.
apply_gradients
(
p_g_list
)
train_fetch_list
.
append
(
loss
)
train_fetch_name_list
.
append
(
loss
.
name
)
return
DistillProgramInfo
(
startup_program
,
train_program
,
\
feed_target_names
,
train_fetch_list
,
optimizer
,
\
test_program
,
feed_target_names
,
fetch_targets
)
def
_compile_program
(
program
,
fetch_var_name
):
"""compiling program"""
compiled_prog
=
paddle
.
static
.
CompiledProgram
(
program
)
build_strategy
=
paddle
.
static
.
BuildStrategy
()
build_strategy
.
memory_optimize
=
False
build_strategy
.
enable_inplace
=
False
build_strategy
.
fuse_all_reduce_ops
=
False
build_strategy
.
sync_batch_norm
=
False
exec_strategy
=
paddle
.
static
.
ExecutionStrategy
()
compiled_prog
=
compiled_prog
.
with_data_parallel
(
loss_name
=
fetch_var_name
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
return
compiled_prog
def
quant_aware_with_infermodel
(
executor
,
place
,
scope
=
None
,
train_reader
=
None
,
quant_config
=
None
,
train_config
=
None
,
test_callback
=
None
):
"""train aware quantization with infermodel
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
place(paddle.CPUPlace or paddle.CUDAPlace): This parameter represents
the executor run on which device.
scope(paddle.static.Scope, optional): Scope records the mapping between
variable names and variables, similar to brackets in
programming languages. Usually users can use
`paddle.static.global_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_.
When ``None`` will use
`paddle.static.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_
. Default: ``None``.
train_reader(data generator): data generator, yield feed_dictionary, {feed_name[0]:data[0], feed_name[1]:data[1]}.
quant_config(dict, optional): configs for convert. if set None, will use
default config. It must be same with config that used in
'quant_aware'. Default is None.
train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate,
weight_decay, use_pact, quant_model_ckpt_path,
model_path_prefix, teacher_model_path_prefix,
distill_node_pair(teacher_node_name1, node_name1, teacher_node_name2, teacher_node_name2, ...)
test_callback(callback function): callback function include two params: compiled test quant program and checkpoint save filename.
user can implement test logic.
Returns:
None
"""
scope
=
paddle
.
static
.
global_scope
()
if
not
scope
else
scope
# parse quant config
if
quant_config
is
None
:
quant_config
=
_quant_config_default
else
:
assert
isinstance
(
quant_config
,
dict
),
"quant config must be dict"
quant_config
=
_parse_configs
(
quant_config
)
_logger
.
info
(
"quant_aware config {}"
.
format
(
quant_config
))
train_config
=
_parse_train_configs
(
train_config
)
distill_program_info
=
build_distill_prog_with_infermodel
(
executor
,
place
,
train_config
)
startup_program
=
distill_program_info
.
startup_program
train_program
=
distill_program_info
.
train_program
train_feed_names
=
distill_program_info
.
train_feed_names
train_fetch_list
=
distill_program_info
.
train_fetch_list
optimizer
=
distill_program_info
.
optimizer
test_program
=
distill_program_info
.
test_program
test_feed_names
=
distill_program_info
.
test_feed_names
test_fetch_list
=
distill_program_info
.
test_fetch_list
############################################################################
# quant
############################################################################
use_pact
=
train_config
[
"use_pact"
]
if
use_pact
:
act_preprocess_func
=
pact
optimizer_func
=
get_pact_optimizer
pact_executor
=
executor
else
:
act_preprocess_func
=
None
optimizer_func
=
None
pact_executor
=
None
test_program
=
quant_aware
(
test_program
,
place
,
quant_config
,
scope
=
scope
,
act_preprocess_func
=
act_preprocess_func
,
optimizer_func
=
optimizer_func
,
executor
=
pact_executor
,
for_test
=
True
)
train_program
=
quant_aware
(
train_program
,
place
,
quant_config
,
scope
=
scope
,
act_preprocess_func
=
act_preprocess_func
,
optimizer_func
=
optimizer_func
,
executor
=
pact_executor
,
for_test
=
False
,
return_program
=
True
)
executor
.
run
(
startup_program
)
compiled_train_prog
=
_compile_program
(
train_program
,
train_fetch_list
[
0
].
name
)
compiled_test_prog
=
_compile_program
(
test_program
,
test_fetch_list
[
0
].
name
)
num_epoch
=
train_config
[
"num_epoch"
]
save_iter_step
=
train_config
[
"save_iter_step"
]
iter_sum
=
0
for
epoch
in
range
(
num_epoch
):
for
iter_num
,
feed_dict
in
enumerate
(
train_reader
()):
np_probs_float
=
executor
.
run
(
compiled_train_prog
,
\
feed
=
feed_dict
,
\
fetch_list
=
train_fetch_list
)
print
(
"loss: "
,
np_probs_float
)
if
iter_num
>
0
and
iter_num
%
save_iter_step
==
0
:
checkpoint_name
=
"epoch_"
+
str
(
epoch
)
+
"_iter_"
+
str
(
iter_num
)
paddle
.
static
.
save
(
program
=
test_program
,
model_path
=
os
.
path
.
join
(
train_config
[
"quant_model_ckpt_path"
],
checkpoint_name
))
test_callback
(
compiled_test_prog
,
test_feed_names
,
test_fetch_list
,
checkpoint_name
)
iter_sum
+=
1
if
train_config
[
"max_iter"
]
>=
0
and
iter_sum
>
train_config
[
"max_iter"
]:
return
def
export_quant_infermodel
(
executor
,
place
=
None
,
scope
=
None
,
quant_config
=
None
,
train_config
=
None
,
checkpoint_path
=
None
,
export_inference_model_path_prefix
=
"./export_quant_infermodel"
):
"""export quant model checkpoints to infermodel.
Args:
executor(paddle.static.Executor): The executor to load, run and save the
quantized model.
place(paddle.CPUPlace or paddle.CUDAPlace): This parameter represents
the executor run on which device.
scope(paddle.static.Scope, optional): Scope records the mapping between
variable names and variables, similar to brackets in
programming languages. Usually users can use
`paddle.static.global_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_.
When ``None`` will use
`paddle.static.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_
. Default: ``None``.
quant_config(dict, optional): configs for convert. if set None, will use
default config. It must be same with config that used in
'quant_aware'. Default is None.
train_config(dict):train aware configs, include num_epoch, save_iter_step, learning_rate,
weight_decay, use_pact, quant_model_ckpt_path,
model_path_prefix, teacher_model_path_prefix,
distill_node_pair(teacher_node_name1, node_name1, teacher_node_name2, teacher_node_name2, ...)
checkpoint_path(str): checkpoint path need to export quant infer model.
export_inference_model_path_prefix(str): export infer model path prefix, storage directory of model + model name (excluding suffix).
Returns:
None
"""
scope
=
paddle
.
static
.
global_scope
()
if
not
scope
else
scope
# parse quant config
if
quant_config
is
None
:
quant_config
=
_quant_config_default
else
:
assert
isinstance
(
quant_config
,
dict
),
"quant config must be dict"
quant_config
=
_parse_configs
(
quant_config
)
_logger
.
info
(
"quant_aware config {}"
.
format
(
quant_config
))
train_config
=
_parse_train_configs
(
train_config
)
distill_program_info
=
build_distill_prog_with_infermodel
(
executor
,
place
,
train_config
)
test_program
=
distill_program_info
.
test_program
test_feed_names
=
distill_program_info
.
test_feed_names
test_fetch_list
=
distill_program_info
.
test_fetch_list
############################################################################
# quant
############################################################################
use_pact
=
train_config
[
"use_pact"
]
if
use_pact
:
act_preprocess_func
=
pact
optimizer_func
=
get_pact_optimizer
pact_executor
=
executor
else
:
act_preprocess_func
=
None
optimizer_func
=
None
pact_executor
=
None
test_program
=
quant_aware
(
test_program
,
place
,
quant_config
,
scope
=
scope
,
act_preprocess_func
=
act_preprocess_func
,
optimizer_func
=
optimizer_func
,
executor
=
pact_executor
,
for_test
=
True
)
paddle
.
static
.
load
(
executor
=
executor
,
model_path
=
os
.
path
.
join
(
checkpoint_path
),
program
=
test_program
)
############################################################################################################
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
# The dtype of float_program's weights is float32, but in int8 range.
############################################################################################################
float_program
,
int8_program
=
convert
(
test_program
,
place
,
quant_config
,
\
scope
=
scope
,
\
save_int8
=
True
)
############################################################################################################
# 4. Save inference model
############################################################################################################
export_model_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
export_inference_model_path_prefix
,
os
.
path
.
pardir
))
if
not
os
.
path
.
exists
(
export_model_dir
):
os
.
makedirs
(
export_model_dir
)
feed_vars
=
[]
for
name
in
test_feed_names
:
for
var
in
float_program
.
list_vars
():
if
var
.
name
==
name
:
feed_vars
.
append
(
var
)
break
assert
len
(
feed_vars
)
>
0
,
"can not find feed vars in quant program"
paddle
.
static
.
save_inference_model
(
path_prefix
=
export_inference_model_path_prefix
,
feed_vars
=
feed_vars
,
fetch_vars
=
test_fetch_list
,
executor
=
executor
,
program
=
float_program
)
paddleslim/quant/quanter.py
浏览文件 @
e406740a
...
@@ -29,6 +29,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
...
@@ -29,6 +29,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from
paddle.fluid.contrib.slim.quantization
import
OutScaleForInferencePass
from
paddle.fluid.contrib.slim.quantization
import
OutScaleForInferencePass
from
paddle.fluid
import
core
from
paddle.fluid
import
core
from
paddle.fluid.contrib.slim.quantization
import
WeightQuantization
from
paddle.fluid.contrib.slim.quantization
import
WeightQuantization
from
paddle.fluid.layer_helper
import
LayerHelper
from
..common
import
get_logger
from
..common
import
get_logger
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
...
@@ -561,3 +562,30 @@ def quant_post_dynamic(model_dir,
...
@@ -561,3 +562,30 @@ def quant_post_dynamic(model_dir,
# For compatibility, we keep quant_post_only_weight api for now,
# For compatibility, we keep quant_post_only_weight api for now,
# and it will be deprecated in the future.
# and it will be deprecated in the future.
quant_post_only_weight
=
quant_post_dynamic
quant_post_only_weight
=
quant_post_dynamic
def
pact
(
x
,
name
=
None
):
helper
=
LayerHelper
(
"pact"
,
**
locals
())
dtype
=
'float32'
init_thres
=
20
u_param_attr
=
paddle
.
fluid
.
ParamAttr
(
name
=
x
.
name
+
'_pact'
,
initializer
=
paddle
.
fluid
.
initializer
.
ConstantInitializer
(
value
=
init_thres
),
regularizer
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
0.0001
),
learning_rate
=
1
)
u_param
=
helper
.
create_parameter
(
attr
=
u_param_attr
,
shape
=
[
1
],
dtype
=
dtype
)
x
=
paddle
.
fluid
.
layers
.
elementwise_sub
(
x
,
paddle
.
fluid
.
layers
.
relu
(
paddle
.
fluid
.
layers
.
elementwise_sub
(
x
,
u_param
)))
x
=
paddle
.
fluid
.
layers
.
elementwise_add
(
x
,
paddle
.
fluid
.
layers
.
relu
(
paddle
.
fluid
.
layers
.
elementwise_sub
(
-
u_param
,
x
)))
return
x
def
get_pact_optimizer
():
return
paddle
.
fluid
.
optimizer
.
MomentumOptimizer
(
0.0001
,
0.9
)
tests/test_quant_aware_with_infermodel.py
0 → 100644
浏览文件 @
e406740a
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
os
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"."
)
sys
.
path
[
0
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
)
import
unittest
import
paddle
from
paddleslim.quant
import
quant_aware
,
convert
from
paddleslim.quant
import
quant_aware_with_infermodel
,
export_quant_infermodel
from
static_case
import
StaticCase
sys
.
path
.
append
(
"../demo"
)
from
models
import
MobileNet
from
layers
import
conv_bn_layer
import
paddle.dataset.mnist
as
reader
from
paddle.fluid.framework
import
IrGraph
from
paddle.fluid
import
core
import
numpy
as
np
class
TestQuantAwareWithInferModelCase1
(
StaticCase
):
def
test_accuracy
(
self
):
float_infer_model_path_prefix
=
"./mv1_float_inference"
image
=
paddle
.
static
.
data
(
name
=
'image'
,
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
paddle
.
static
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
model
=
MobileNet
()
out
=
model
.
net
(
input
=
image
,
class_dim
=
10
)
cost
=
paddle
.
nn
.
functional
.
loss
.
cross_entropy
(
input
=
out
,
label
=
label
)
avg_cost
=
paddle
.
mean
(
x
=
cost
)
acc_top1
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
momentum
=
0.9
,
learning_rate
=
0.01
,
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
4e-5
))
optimizer
.
minimize
(
avg_cost
)
main_prog
=
paddle
.
static
.
default_main_program
()
val_prog
=
main_prog
.
clone
(
for_test
=
True
)
#place = paddle.CPUPlace()
place
=
paddle
.
CUDAPlace
(
0
)
if
paddle
.
is_compiled_with_cuda
(
)
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
def
transform
(
x
):
return
np
.
reshape
(
x
,
[
1
,
28
,
28
])
train_dataset
=
paddle
.
vision
.
datasets
.
MNIST
(
mode
=
'train'
,
backend
=
'cv2'
,
transform
=
transform
)
test_dataset
=
paddle
.
vision
.
datasets
.
MNIST
(
mode
=
'test'
,
backend
=
'cv2'
,
transform
=
transform
)
train_loader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
places
=
place
,
feed_list
=
[
image
,
label
],
drop_last
=
True
,
batch_size
=
64
,
return_list
=
False
)
valid_loader
=
paddle
.
io
.
DataLoader
(
test_dataset
,
places
=
place
,
feed_list
=
[
image
,
label
],
batch_size
=
64
,
return_list
=
False
)
def
sample_generator_creator
():
def
__reader__
():
for
data
in
test_dataset
:
image
,
label
=
data
yield
image
,
label
return
__reader__
def
train
(
program
):
iter
=
0
for
data
in
train_loader
():
cost
,
top1
,
top5
=
exe
.
run
(
program
,
feed
=
data
,
fetch_list
=
[
avg_cost
,
acc_top1
,
acc_top5
])
iter
+=
1
if
iter
%
100
==
0
:
print
(
'train iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'
.
format
(
iter
,
cost
,
top1
,
top5
))
def
test
(
program
,
outputs
=
[
avg_cost
,
acc_top1
,
acc_top5
]):
iter
=
0
result
=
[[],
[],
[]]
for
data
in
valid_loader
():
cost
,
top1
,
top5
=
exe
.
run
(
program
,
feed
=
data
,
fetch_list
=
outputs
)
iter
+=
1
if
iter
%
100
==
0
:
print
(
'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'
.
format
(
iter
,
cost
,
top1
,
top5
))
result
[
0
].
append
(
cost
)
result
[
1
].
append
(
top1
)
result
[
2
].
append
(
top5
)
print
(
' avg loss {}, acc_top1 {}, acc_top5 {}'
.
format
(
np
.
mean
(
result
[
0
]),
np
.
mean
(
result
[
1
]),
np
.
mean
(
result
[
2
])))
return
np
.
mean
(
result
[
1
]),
np
.
mean
(
result
[
2
])
train
(
main_prog
)
top1_1
,
top5_1
=
test
(
val_prog
)
paddle
.
static
.
save_inference_model
(
path_prefix
=
float_infer_model_path_prefix
,
feed_vars
=
[
image
,
label
],
fetch_vars
=
[
avg_cost
,
acc_top1
,
acc_top5
],
executor
=
exe
,
program
=
val_prog
)
quant_config
=
{
'weight_quantize_type'
:
'channel_wise_abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'not_quant_pattern'
:
[
'skip_quant'
],
'quantize_op_types'
:
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
}
train_config
=
{
"num_epoch"
:
1
,
# training epoch num
"max_iter"
:
20
,
"save_iter_step"
:
10
,
"learning_rate"
:
0.0001
,
"weight_decay"
:
0.0001
,
"use_pact"
:
False
,
"quant_model_ckpt_path"
:
"./quantaware_with_infermodel_checkpoints/"
,
"teacher_model_path_prefix"
:
float_infer_model_path_prefix
,
"model_path_prefix"
:
float_infer_model_path_prefix
,
"distill_node_pair"
:
[
"teacher_fc_0.tmp_0"
,
"fc_0.tmp_0"
,
"teacher_batch_norm_24.tmp_4"
,
"batch_norm_24.tmp_4"
,
"teacher_batch_norm_22.tmp_4"
,
"batch_norm_22.tmp_4"
,
"teacher_batch_norm_18.tmp_4"
,
"batch_norm_18.tmp_4"
,
"teacher_batch_norm_13.tmp_4"
,
"batch_norm_13.tmp_4"
,
"teacher_batch_norm_5.tmp_4"
,
"batch_norm_5.tmp_4"
]
}
def
test_callback
(
compiled_test_program
,
feed_names
,
fetch_list
,
checkpoint_name
):
outputs
=
fetch_list
iter
=
0
result
=
[[],
[],
[]]
for
data
in
valid_loader
():
cost
,
top1
,
top5
=
exe
.
run
(
compiled_test_program
,
feed
=
data
,
fetch_list
=
fetch_list
)
iter
+=
1
if
iter
%
100
==
0
:
print
(
'eval iter={}, avg loss {}, acc_top1 {}, acc_top5 {}'
.
format
(
iter
,
cost
,
top1
,
top5
))
result
[
0
].
append
(
cost
)
result
[
1
].
append
(
top1
)
result
[
2
].
append
(
top5
)
print
(
"quant model checkpoint: "
+
checkpoint_name
+
' avg loss {}, acc_top1 {}, acc_top5 {}'
.
format
(
np
.
mean
(
result
[
0
]),
np
.
mean
(
result
[
1
]),
np
.
mean
(
result
[
2
])))
return
np
.
mean
(
result
[
1
]),
np
.
mean
(
result
[
2
])
def
test_quant_aware_with_infermodel
(
exe
,
place
):
quant_aware_with_infermodel
(
exe
,
place
,
scope
=
None
,
train_reader
=
train_loader
,
quant_config
=
quant_config
,
train_config
=
train_config
,
test_callback
=
test_callback
)
def
test_export_quant_infermodel
(
exe
,
place
,
checkpoint_path
,
quant_infermodel_save_path
):
export_quant_infermodel
(
exe
,
place
,
scope
=
None
,
quant_config
=
quant_config
,
train_config
=
train_config
,
checkpoint_path
=
checkpoint_path
,
export_inference_model_path_prefix
=
quant_infermodel_save_path
)
#place = paddle.CPUPlace()
place
=
paddle
.
CUDAPlace
(
0
)
if
paddle
.
is_compiled_with_cuda
(
)
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
test_quant_aware_with_infermodel
(
exe
,
place
)
checkpoint_path
=
"./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
quant_infermodel_save_path
=
"./quantaware_with_infermodel_export"
test_export_quant_infermodel
(
exe
,
place
,
checkpoint_path
,
quant_infermodel_save_path
)
train_config
[
"use_pact"
]
=
True
test_quant_aware_with_infermodel
(
exe
,
place
)
train_config
[
"use_pact"
]
=
False
checkpoint_path
=
"./quantaware_with_infermodel_checkpoints/epoch_0_iter_10"
quant_infermodel_save_path
=
"./quantaware_with_infermodel_export"
test_export_quant_infermodel
(
exe
,
place
,
checkpoint_path
,
quant_infermodel_save_path
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录