Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
8e1691b4
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8e1691b4
编写于
12月 07, 2022
作者:
C
Chang Xu
提交者:
GitHub
12月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Tests for Analysis & Support EvalFunc is None (#1574)
上级
ef6a8f25
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
852 addition
and
102 deletion
+852
-102
docs/zh_cn/tutorials/quant/AnalysisPTQ.md
docs/zh_cn/tutorials/quant/AnalysisPTQ.md
+58
-12
docs/zh_cn/tutorials/quant/AnalysisQAT.md
docs/zh_cn/tutorials/quant/AnalysisQAT.md
+55
-13
example/quantization_analysis/GPT/README.md
example/quantization_analysis/GPT/README.md
+15
-2
example/quantization_analysis/GPT/analysis.py
example/quantization_analysis/GPT/analysis.py
+2
-2
paddleslim/quant/analysis_ptq.py
paddleslim/quant/analysis_ptq.py
+122
-56
paddleslim/quant/analysis_qat.py
paddleslim/quant/analysis_qat.py
+66
-17
tests/test_analysis_ptq.py
tests/test_analysis_ptq.py
+73
-0
tests/test_analysis_ptq_eval_func.py
tests/test_analysis_ptq_eval_func.py
+171
-0
tests/test_analysis_qat.py
tests/test_analysis_qat.py
+96
-0
tests/test_analysis_qat_eval_func.py
tests/test_analysis_qat_eval_func.py
+194
-0
未找到文件。
docs/zh_cn/tutorials/quant/AnalysisPTQ.md
浏览文件 @
8e1691b4
...
...
@@ -19,7 +19,7 @@
| model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数
;若不传入,精度误差分析将根据Cosine Similarity计算得出
|
| data_loader | 模型校准时使用的数据,DataLoader继承自
`paddle.io.DataLoader`
。可以直接使用模型套件中的DataLoader,或者根据
[
paddle.io.DataLoader
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader
)
自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为
`analysis_results`
|
| resume | 是否加载中间分析文件,默认为False|
...
...
@@ -31,19 +31,65 @@
## 3. 量化分析工具的使用
**创建量化分析工具**
:
```
shell
# 下载Inference模型
wget
-q
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
tar
-xf
MobileNetV1_infer.tar
# 下载demo数据集
wget
-q
https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz
tar
-xf
ILSVRC2012_data_demo.tar.gz
```
```
shell
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
paddle.enable_static
()
class ImageNetDataset
(
DatasetFolder
)
:
def __init__
(
self, path,
image_size
=
224
)
:
super
(
ImageNetDataset, self
)
.__init__
(
path
)
normalize
=
transforms.Normalize
(
mean
=[
123.675, 116.28, 103.53],
std
=[
58.395, 57.120, 57.375]
)
self.transform
=
transforms.Compose
([
transforms.Resize
(
256
)
, transforms.CenterCrop
(
image_size
)
,
transforms.Transpose
()
, normalize
])
def __getitem__
(
self, idx
)
:
img_path, _
=
self.samples[idx]
return
self.transform
(
Image.open
(
img_path
)
.convert
(
'RGB'
))
def __len__
(
self
)
:
return
len
(
self.samples
)
train_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/train/"
)
image
=
paddle.static.data
(
name
=
'inputs'
,
shape
=[
None] +
[
3, 224, 224],
dtype
=
'float32'
)
train_loader
=
paddle.io.DataLoader
(
train_dataset,
feed_list
=[
image],
batch_size
=
8,
return_list
=
False
)
analyzer
=
AnalysisPTQ
(
model_dir=config["model_dir"],
model_filename=config["model_filename"],
params_filename=config["params_filename"],
eval_function=eval_function,
data_loader=data_loader,
save_dir=config['save_dir'],
ptq_config=config['PTQ'])
model_dir
=
"./MobileNetV1_infer"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
save_dir
=
"MobileNetV1_analysis"
,
ptq_config
={
'quantizable_op_type'
:
[
"conv2d"
,
"depthwise_conv2d"
]
,
'weight_quantize_type'
:
'abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'is_full_quantize'
: False,
'batch_size'
: 8,
'batch_nums'
: 1,
}
,
data_loader
=
train_loader
)
```
**统计分析**
```
```
shell
analyzer.statistical_analyse
()
```
...
...
@@ -75,7 +121,7 @@ analyzer.statistical_analyse()
**精度误差分析**
```
```
shell
analyzer.metric_error_analyse
()
```
调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在
`./analysis_results/analysis.txt`
中。
...
...
@@ -83,8 +129,8 @@ analyzer.metric_error_analyse()
**直接产出符合预期精度的目标量化模型**
```
analyzer.get_target_quant_model(target_metric)
```
shell
analyzer.get_target_quant_model
(
target_metric
=
70.0
)
```
## 4. 根据分析结果执行离线量化
...
...
docs/zh_cn/tutorials/quant/AnalysisQAT.md
浏览文件 @
8e1691b4
...
...
@@ -14,7 +14,7 @@
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称 |
| quantizable_op_type | 需分析的量化的op类型,默认为
`conv2d`
,
`depthwise_conv2d`
,
`mul`
|
| qat_metric | 量化模型的精度,可不传入,默认为None,不传入时会自动计算 |
| eval_function |
需要传入自定义的验证函数
|
| eval_function |
若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出
|
| data_loader | 模型校准时使用的数据,DataLoader继承自
`paddle.io.DataLoader`
。可以直接使用模型套件中的DataLoader,或者根据
[
paddle.io.DataLoader
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader
)
自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为
`analysis_results`
|
| resume | 是否加载中间分析文件,默认为False|
...
...
@@ -25,24 +25,66 @@
## 3. 量化分析工具的使用
**创建量化分析工具**
:
```
shell
# 下载Inference模型
wget
-q
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar
tar
-xf
MobileNetV1_infer.tar
wget
-q
https://paddle-slim-models.bj.bcebos.com/act/MobileNetV1_QAT.tar
tar
-xf
MobileNetV1_QAT.tar
# 下载demo数据集
wget
-q
https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz
tar
-xf
ILSVRC2012_data_demo.tar.gz
```
```
shell
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddleslim.quant.analysis_qat import AnalysisQAT
paddle.enable_static
()
class ImageNetDataset
(
DatasetFolder
)
:
def __init__
(
self, path,
image_size
=
224
)
:
super
(
ImageNetDataset, self
)
.__init__
(
path
)
normalize
=
transforms.Normalize
(
mean
=[
123.675, 116.28, 103.53],
std
=[
58.395, 57.120, 57.375]
)
self.transform
=
transforms.Compose
([
transforms.Resize
(
256
)
, transforms.CenterCrop
(
image_size
)
,
transforms.Transpose
()
, normalize
])
def __getitem__
(
self, idx
)
:
img_path, _
=
self.samples[idx]
return
self.transform
(
Image.open
(
img_path
)
.convert
(
'RGB'
))
def __len__
(
self
)
:
return
len
(
self.samples
)
train_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/train/"
)
image
=
paddle.static.data
(
name
=
'inputs'
,
shape
=[
None] +
[
3, 224, 224],
dtype
=
'float32'
)
train_loader
=
paddle.io.DataLoader
(
train_dataset,
feed_list
=[
image],
batch_size
=
8,
return_list
=
False
)
analyzer
=
AnalysisQAT
(
quant_model_dir=config["quant_model_dir"],
float_model_dir=config["float_model_dir"],
model_filename=config["model_filename"],
params_filename=config["params_filename"],
quantizable_op_type=config['quantizable_op_type'],
qat_metric=config['qat_metric'],
eval_function=eval_function,
data_loader=eval_loader,
save_dir=config['save_dir'],
resume=config['resume'],
)
float_model_dir
=
"./MobileNetV1_infer"
,
quant_model_dir
=
"./MobileNetV1_QAT"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
save_dir
=
"MobileNetV1_analysis"
,
data_loader
=
train_loader
)
```
**精度误差分析**
```
```
shell
analyzer.metric_error_analyse
()
```
调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在
`./analysis_results/analysis.txt`
中。具体使用可参考
[
GPT量化训练敏感度分析DEMO
](
../../../../example/quantization_analysis/GPT/README.md
)
。
...
...
example/quantization_analysis/GPT/README.md
浏览文件 @
8e1691b4
...
...
@@ -24,10 +24,22 @@
量化敏感度分析基于验证集获得每层的敏感度,可下载和使用
[
LAMBADA
](
https://raw.githubusercontent.com/cybertronai/bflm/master/lambada_test.jsonl
)
或者
[
WikiText
](
https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
)
数据集。本示例使用LAMBADA数据集来进行敏感度分析。
```
shell
# 下载验证数据
wget https://raw.githubusercontent.com/cybertronai/bflm/master/lambada_test.jsonl
```
#### 3.3 准备预测模型
-
[
GPT-345M
](
https://bj.bcebos.com/v1/paddle-slim-models/GPT_345M_Baseline.tar
)
:Base模型
-
[
GPT-345M
](
https://bj.bcebos.com/v1/paddle-slim-models/GPT_345_QAT_wo_analysis.tar
)
:分析前量化训练后的模型
-
下载量化前Base模型
```
shell
wget https://bj.bcebos.com/v1/paddle-slim-models/GPT_345M_Baseline.tar
```
-
下载分析前量化训练后的模型
```
shell
wget https://bj.bcebos.com/v1/paddle-slim-models/GPT_345_QAT_wo_analysis.tar
```
如想自行导出,可参考
[
GPT模型量化训练
](
https://github.com/PaddlePaddle/PaddleFleetX/blob/release/2.4/projects/gpt/docs/quantization_aware_training.md
)
。
#### 3.4 量化敏感度分析
量化敏感度分析示例通过analysis.py脚本启动,会使用接口
```paddleslim.quant.AnalysisQAT```
对模型进行敏感度分析。配置config文件中模型路径、数据路径和量化相关的参数,配置完成后便可对模型进行敏感度分析。具体运行命令为:
...
...
@@ -44,3 +56,4 @@ python analysis.py --config_path=./configs/gpt_345M_analysis.yaml
#### 3.5 重新量化训练
根据分析结果,重新量化训练时,去掉了```linear_31```,```linear_27```,```linear_22```,```linear_43```,```linear_83```,```linear_15```,```linear_87```七层Linear的量化,最后量化模型精度达到44.94。
重新量化训练的过程在 PaddleFleetX 中实现,可参考[GPT模型量化训练](https://github.com/PaddlePaddle/PaddleFleetX/blob/release/2.4/projects/gpt/docs/quantization_aware_training.md)。
example/quantization_analysis/GPT/analysis.py
浏览文件 @
8e1691b4
...
...
@@ -104,8 +104,8 @@ def eval_function(exe, program, feed_names, fetch_list):
total_score
+=
acc
.
numpy
()[
0
]
if
eval_step
!=
0
and
(
eval_step
%
10
==
0
):
print
(
"[eval] step: %d,
batch: %d,
%s: %.9f, speed: %.2f step/s"
%
(
eval_step
,
eval_step
,
score_name
,
total_score
,
print
(
"[eval] step: %d, %s: %.9f, speed: %.2f step/s"
%
(
eval_step
,
score_name
,
total_score
,
1.
/
(
time
.
time
()
-
tic_eval
)))
tic_eval
=
time
.
time
()
paddle
.
enable_static
()
...
...
paddleslim/quant/analysis_ptq.py
浏览文件 @
8e1691b4
...
...
@@ -24,7 +24,7 @@ import numpy as np
import
random
import
tempfile
import
paddle
from
.quanter
import
quant_post
import
paddle.nn.functional
as
F
from
..core
import
GraphWrapper
from
..common
import
get_logger
from
..common
import
get_feed_vars
,
wrap_dataloader
,
load_inference_model
,
get_model_dir
...
...
@@ -80,6 +80,7 @@ class AnalysisPTQ(object):
'is_full_quantize'
]
if
'is_full_quantize'
in
ptq_config
else
False
self
.
onnx_format
=
ptq_config
[
'onnx_format'
]
if
'onnx_format'
in
ptq_config
else
False
ptq_config
[
'onnx_format'
]
=
self
.
onnx_format
if
'algo'
not
in
ptq_config
:
ptq_config
[
'algo'
]
=
'avg'
...
...
@@ -134,9 +135,8 @@ class AnalysisPTQ(object):
# load tobe_analyized_layer from checkpoint
if
resume
:
self
.
load_checkpoint
()
self
.
tobe_analyized_layer
=
set
(
self
.
support_quant_val_name_list
)
-
set
(
list
(
self
.
quant_layer_metrics
.
keys
()))
self
.
tobe_analyized_layer
=
sorted
(
list
(
self
.
tobe_analyized_layer
))
self
.
tobe_analyized_layer
=
sorted
(
list
(
self
.
support_quant_val_name_list
))
def
save_checkpoint
(
self
):
if
not
os
.
path
.
exists
(
self
.
save_dir
):
...
...
@@ -172,7 +172,6 @@ class AnalysisPTQ(object):
model_filename
=
self
.
model_filename
,
params_filename
=
self
.
params_filename
,
skip_tensor_list
=
skip_tensor_list
,
onnx_format
=
self
.
onnx_format
,
**
self
.
ptq_config
)
def
sampling
(
self
,
executor
,
program
,
scope
):
...
...
@@ -187,65 +186,125 @@ class AnalysisPTQ(object):
if
batch_id
>=
self
.
batch_nums
:
break
def
eval_quant_model
(
self
,
skip_list
):
def
fp_int_cosine_similarity
(
self
,
executor
,
float_program
,
quant_program
,
float_scope
,
quant_scope
):
cosine_similarity
=
[]
for
step
,
data
in
enumerate
(
self
.
data_loader
()):
with
paddle
.
static
.
scope_guard
(
float_scope
):
float_preds
=
executor
.
run
(
program
=
float_program
,
feed
=
data
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
False
)
float_preds
=
float_preds
[
0
]
with
paddle
.
static
.
scope_guard
(
quant_scope
):
quant_preds
=
executor
.
run
(
program
=
quant_program
,
feed
=
data
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
False
)
quant_preds
=
quant_preds
[
0
]
paddle
.
disable_static
()
float_preds
=
paddle
.
to_tensor
(
float_preds
)
quant_preds
=
paddle
.
to_tensor
(
quant_preds
)
cos_sim
=
F
.
cosine_similarity
(
float_preds
,
quant_preds
).
mean
()
cos_sim
=
cos_sim
.
numpy
()
cosine_similarity
.
append
(
cos_sim
)
if
step
!=
0
and
(
step
%
10
==
0
):
_logger
.
info
(
"[step]: %d, cosine similarity: %.9f"
%
(
step
,
np
.
array
(
cosine_similarity
).
mean
()))
paddle
.
enable_static
()
return
np
.
array
(
cosine_similarity
).
mean
()
def
get_sensitive_metric
(
self
,
skip_list
,
layer_name
):
executor
=
paddle
.
static
.
Executor
(
self
.
places
)
post_training_quantization
=
self
.
create_ptq
(
executor
,
skip_list
)
program
=
post_training_quantization
.
quantize
()
_logger
.
info
(
'Evaluating...'
)
if
self
.
onnx_format
:
post_training_quantization
.
save_quantized_model
(
self
.
temp_save_path
,
model_filename
=
'model.pdmodel'
,
params_filename
=
'model.pdiparams'
)
program
,
_
,
_
=
load_inference_model
(
self
.
temp_save_path
,
executor
,
model_filename
=
'model.pdmodel'
,
params_filename
=
'model.pdiparams'
)
quant_metric
=
self
.
eval_function
(
executor
,
program
,
self
.
feed_list
,
self
.
fetch_list
)
if
self
.
eval_function
is
not
None
:
post_training_quantization
=
self
.
create_ptq
(
executor
,
skip_list
)
program
=
post_training_quantization
.
quantize
()
_logger
.
info
(
'Evaluating...'
)
if
self
.
onnx_format
:
post_training_quantization
.
save_quantized_model
(
self
.
temp_save_path
,
model_filename
=
'model.pdmodel'
,
params_filename
=
'model.pdiparams'
)
program
,
_
,
_
=
load_inference_model
(
self
.
temp_save_path
,
executor
,
model_filename
=
'model.pdmodel'
,
params_filename
=
'model.pdiparams'
)
metric
=
self
.
eval_function
(
executor
,
program
,
self
.
feed_list
,
self
.
fetch_list
)
if
skip_list
is
None
:
executor
.
close
()
return
metric
sensitive_metric
=
self
.
base_metric
-
metric
_logger
.
info
(
"Quantized layer name: %s, the accuracy: %.4f, the sensitive metric: %.4f"
%
(
layer_name
,
metric
,
sensitive_metric
))
else
:
float_scope
=
paddle
.
static
.
Scope
()
quant_scope
=
paddle
.
static
.
Scope
()
with
paddle
.
static
.
scope_guard
(
float_scope
):
[
float_program
,
_
,
_
]
=
load_inference_model
(
self
.
model_dir
,
executor
=
executor
,
model_filename
=
self
.
model_filename
,
params_filename
=
self
.
params_filename
)
with
paddle
.
static
.
scope_guard
(
quant_scope
):
post_training_quantization
=
self
.
create_ptq
(
executor
,
skip_list
)
quant_program
=
post_training_quantization
.
quantize
()
metric
=
self
.
fp_int_cosine_similarity
(
executor
,
float_program
,
quant_program
,
float_scope
,
quant_scope
)
sensitive_metric
=
1.0
-
metric
_logger
.
info
(
"Quantized layer name: %s, the cosine similarity: %.4f, the sensitive metric: %.4f"
%
(
layer_name
,
metric
,
sensitive_metric
))
executor
.
close
()
return
quant
_metric
return
sensitive
_metric
def
metric_error_analyse
(
self
):
'''
Evaluate the quantized models, which are generated by quantizing each weight operator one by one. The results will be saved into analysis.txt.
'''
assert
self
.
data_loader
is
not
None
,
"When computing the sensitivity of quantized layers, the data loader is needed"
assert
self
.
eval_function
is
not
None
,
"When computing the sensitivity of quantized layers, the eval function is needed"
# evaluate before quant
_logger
.
info
(
'Start to evaluate the base model.'
)
executor
=
paddle
.
static
.
Executor
(
self
.
places
)
[
program
,
feed_list
,
fetch_list
]
=
load_inference_model
(
\
self
.
model_di
r
,
\
executor
=
executor
,
\
model_filename
=
self
.
model_filename
,
\
params_filename
=
self
.
params_filename
)
self
.
base_metric
=
self
.
eval_function
(
executor
,
program
,
feed_list
,
fetch_list
)
_logger
.
info
(
'Before quantized, the accuracy of the model is: {}'
.
format
(
self
.
base_metric
)
)
if
self
.
eval_function
is
not
None
:
# evaluate before quant
_logger
.
info
(
'Start to evaluate the base model.'
)
executor
=
paddle
.
static
.
Executor
(
self
.
places
)
[
program
,
feed_list
,
fetch_list
]
=
load_inference_model
(
\
self
.
model_dir
,
\
executor
=
executo
r
,
\
model_filename
=
self
.
model_filename
,
\
params_filename
=
self
.
params_filename
)
self
.
base_metric
=
self
.
eval_function
(
executor
,
program
,
feed_list
,
fetch_list
)
_logger
.
info
(
'Before quantized, the accuracy of the model is: {}'
.
format
(
self
.
base_metric
))
executor
.
close
(
)
# evaluate before quant
_logger
.
info
(
'Start to evaluate the quantized model.'
)
self
.
quant_metric
=
self
.
eval_quant_model
(
None
)
_logger
.
info
(
'After quantized, the accuracy of the model is: {}'
.
format
(
self
.
quant_metric
))
# evaluate before quant
_logger
.
info
(
'Start to evaluate the quantized model.'
)
self
.
quant_metric
=
self
.
get_sensitive_metric
(
None
,
'all quantizable layers'
)
_logger
.
info
(
'After quantized, the accuracy of the model is: {}'
.
format
(
self
.
quant_metric
))
# For each layer, quantize the weight op and evaluate the quantized model.
for
i
,
layer_name
in
enumerate
(
self
.
tobe_analyized_layer
):
if
layer_name
in
self
.
quant_layer_metrics
:
continue
_logger
.
info
(
'Checking {}/{} quant model: quant layer {}'
.
format
(
i
+
1
,
len
(
self
.
tobe_analyized_layer
),
layer_name
))
skip_list
=
copy
.
copy
(
list
(
self
.
support_quant_val_name_list
))
skip_list
.
remove
(
layer_name
)
quant_metric
=
self
.
eval_quant_model
(
skip_list
)
_logger
.
info
(
"Quantized layer name: {}, eval metric: {}, the loss caused by this layer: {}"
.
format
(
layer_name
,
round
(
quant_metric
,
4
),
round
(
self
.
base_metric
-
quant_metric
,
4
)))
self
.
quant_layer_metrics
[
layer_name
]
=
quant_metric
sensitive_metric
=
self
.
get_sensitive_metric
(
skip_list
,
layer_name
)
self
.
quant_layer_metrics
[
layer_name
]
=
sensitive_metric
self
.
save_checkpoint
()
if
self
.
onnx_format
:
...
...
@@ -254,18 +313,18 @@ class AnalysisPTQ(object):
self
.
sensitivity_ranklist
=
sorted
(
self
.
quant_layer_metrics
,
key
=
self
.
quant_layer_metrics
.
get
,
reverse
=
Fals
e
)
reverse
=
Tru
e
)
_logger
.
info
(
'Finished computing the sensitivity of the model.'
)
for
name
in
self
.
sensitivity_ranklist
:
_logger
.
info
(
"
quant layer name: {}, eval metric: {}"
.
format
(
name
,
self
.
quant_layer_metrics
[
name
]))
_logger
.
info
(
"
Quantized layer name: {}, sensitivity metric: {}"
.
format
(
name
,
self
.
quant_layer_metrics
[
name
]))
analysis_file
=
os
.
path
.
join
(
self
.
save_dir
,
"analysis.txt"
)
with
open
(
analysis_file
,
"w"
)
as
analysis_ret_f
:
for
name
in
self
.
sensitivity_ranklist
:
analysis_ret_f
.
write
(
"
quant layer name: {}, eval
metric: {}
\n
"
.
format
(
"
Quantized layer name: {}, sensitivity
metric: {}
\n
"
.
format
(
name
,
self
.
quant_layer_metrics
[
name
]))
_logger
.
info
(
'Analysis file is saved in {}'
.
format
(
analysis_file
))
...
...
@@ -285,7 +344,7 @@ class AnalysisPTQ(object):
executor
=
executor
,
\
model_filename
=
self
.
model_filename
,
\
params_filename
=
self
.
params_filename
)
scope
=
paddle
.
static
.
Executor
.
global_scope
()
scope
=
paddle
.
static
.
global_scope
()
persistable_var_names
=
[]
for
var
in
program
.
list_vars
():
if
var
.
persistable
:
...
...
@@ -294,6 +353,9 @@ class AnalysisPTQ(object):
self
.
acts_weight_map
=
self
.
get_weight_act_map
(
program
,
self
.
weight_names
,
persistable_var_names
)
activations_names
=
list
(
self
.
acts_weight_map
.
keys
())
for
var
in
program
.
list_vars
():
if
var
.
name
in
activations_names
:
var
.
persistable
=
True
# sample
self
.
sampling
(
executor
,
program
,
scope
)
...
...
@@ -305,7 +367,7 @@ class AnalysisPTQ(object):
def
collect_quant_stat
(
self
):
_logger
.
info
(
'Collecting Statistic After PTQ...'
)
executor
=
paddle
.
static
.
Executor
(
self
.
places
)
scope
=
paddle
.
static
.
Executor
.
global_scope
()
scope
=
paddle
.
static
.
global_scope
()
post_training_quantization
=
self
.
create_ptq
(
executor
,
None
)
program
=
post_training_quantization
.
quantize
()
...
...
@@ -525,13 +587,13 @@ class AnalysisPTQ(object):
rank_list
=
sorted
(
self
.
quant_layer_metrics
,
key
=
self
.
quant_layer_metrics
.
get
,
reverse
=
Fals
e
)
reverse
=
Tru
e
)
else
:
_logger
.
info
(
'Analyse metric error before get target quantized model.'
)
self
.
metric_error_analyse
()
while
True
:
while
len
(
rank_list
)
>
0
:
skip_list
.
append
(
rank_list
.
pop
(
0
))
_logger
.
info
(
'Skip Ops: {}'
.
format
(
skip_list
))
executor
=
paddle
.
static
.
Executor
(
self
.
places
)
...
...
@@ -559,3 +621,7 @@ class AnalysisPTQ(object):
'The quantized model does not satisfy the target metric. Skip next Op...'
)
executor
.
close
()
else
:
_logger
.
info
(
'Sorry, the target quantized model cannot be found. Please set lower target metric.'
)
paddleslim/quant/analysis_qat.py
浏览文件 @
8e1691b4
...
...
@@ -20,6 +20,7 @@ import logging
import
numpy
as
np
import
paddle
import
paddle.nn.functional
as
F
from
paddle.framework
import
core
from
paddle.fluid.framework
import
IrGraph
from
..common
import
get_logger
,
load_inference_model
...
...
@@ -69,6 +70,7 @@ class AnalysisQAT(object):
self
.
quantizable_op_type
=
quantizable_op_type
self
.
qat_metric
=
qat_metric
self
.
eval_function
=
eval_function
self
.
data_loader
=
data_loader
self
.
save_dir
=
save_dir
self
.
checkpoint_name
=
os
.
path
.
join
(
save_dir
,
'analysis_checkpoint.pkl'
)
self
.
nonquant_layer_metrics
=
{}
...
...
@@ -98,8 +100,13 @@ class AnalysisQAT(object):
if
'quantized'
in
input_name
:
self
.
inputs_of_quantized_op
.
append
(
input_names
)
break
if
self
.
eval_function
is
None
:
assert
self
.
data_loader
is
not
None
,
"DataLoader cannot be None if Eval Fuction is None."
_logger
.
info
(
'The sensitivity will measured by cosine similarity of the outputs from float model and quantized model.'
)
if
self
.
qat_metric
is
None
:
if
self
.
qat_metric
is
None
and
self
.
eval_function
is
not
None
:
_logger
.
info
(
'Calculating the metric of QAT model...'
)
self
.
qat_metric
=
self
.
eval_function
(
executor
,
program
,
self
.
feed_list
,
self
.
fetch_list
)
*
100
...
...
@@ -107,6 +114,9 @@ class AnalysisQAT(object):
round
(
self
.
qat_metric
,
4
)))
executor
.
close
()
if
resume
:
self
.
load_checkpoint
()
def
save_checkpoint
(
self
):
if
not
os
.
path
.
exists
(
self
.
save_dir
):
os
.
makedirs
(
self
.
save_dir
)
...
...
@@ -199,6 +209,35 @@ class AnalysisQAT(object):
return
graph
.
to_program
()
def
fp_int_cosine_similarity
(
self
,
executor
,
float_program
,
quant_program
,
float_scope
,
quant_scope
):
cosine_similarity
=
[]
for
step
,
data
in
enumerate
(
self
.
data_loader
()):
with
paddle
.
static
.
scope_guard
(
float_scope
):
float_preds
=
executor
.
run
(
program
=
float_program
,
feed
=
data
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
False
)
float_preds
=
float_preds
[
0
]
with
paddle
.
static
.
scope_guard
(
quant_scope
):
quant_preds
=
executor
.
run
(
program
=
quant_program
,
feed
=
data
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
False
)
quant_preds
=
quant_preds
[
0
]
paddle
.
disable_static
()
float_preds
=
paddle
.
to_tensor
(
float_preds
)
quant_preds
=
paddle
.
to_tensor
(
quant_preds
)
cos_sim
=
F
.
cosine_similarity
(
float_preds
,
quant_preds
).
mean
()
cos_sim
=
cos_sim
.
numpy
()
cosine_similarity
.
append
(
cos_sim
)
if
step
!=
0
and
(
step
%
10
==
0
):
_logger
.
info
(
"[step]: %d, cosine similarity: %.9f"
%
(
step
,
np
.
array
(
cosine_similarity
).
mean
()))
paddle
.
enable_static
()
return
np
.
array
(
cosine_similarity
).
mean
()
def
metric_error_analyse
(
self
):
executor
=
paddle
.
static
.
Executor
(
self
.
places
)
...
...
@@ -207,12 +246,14 @@ class AnalysisQAT(object):
for
idx
,
input_list
in
enumerate
(
self
.
inputs_of_quantized_op
):
weight_name
=
self
.
get_weight_name
(
input_list
)
if
weight_name
in
self
.
nonquant_layer_metrics
:
continue
_logger
.
info
(
'Checking {}/{} quant model: without quant layer {}'
.
format
(
idx
+
1
,
len
(
self
.
inputs_of_quantized_op
),
weight_name
))
with
paddle
.
static
.
scope_guard
(
float_scope
):
load_inference_model
(
[
float_program
,
_
,
_
]
=
load_inference_model
(
self
.
float_model_dir
,
executor
=
executor
,
model_filename
=
self
.
model_filename
,
...
...
@@ -232,18 +273,26 @@ class AnalysisQAT(object):
input_list
,
graph
,
float_scope
,
quant_scope
)
saved_program
=
self
.
relink_graph
(
graph
,
input_rename_map
,
output_rename_map
,
removed_ops
)
with
paddle
.
static
.
scope_guard
(
quant_scope
):
_logger
.
info
(
'Skip quant {}, evaluating....'
.
format
(
weight_name
))
metric
=
self
.
eval_function
(
executor
,
saved_program
,
self
.
feed_list
,
self
.
fetch_list
)
*
100
self
.
nonquant_layer_metrics
[
weight_name
]
=
metric
if
self
.
eval_function
is
not
None
:
with
paddle
.
static
.
scope_guard
(
quant_scope
):
_logger
.
info
(
'Skip quant {}, evaluating....'
.
format
(
weight_name
))
metric
=
self
.
eval_function
(
executor
,
saved_program
,
self
.
feed_list
,
self
.
fetch_list
)
*
100
self
.
nonquant_layer_metrics
[
weight_name
]
=
metric
-
self
.
qat_metric
_logger
.
info
(
'When skip quant %s, the eval metric is %.4f, the sensitive metric is %.4f'
%
(
weight_name
,
metric
,
metric
-
self
.
qat_metric
))
else
:
metric
=
self
.
fp_int_cosine_similarity
(
executor
,
float_program
,
saved_program
,
float_scope
,
quant_scope
)
self
.
nonquant_layer_metrics
[
weight_name
]
=
1
-
metric
_logger
.
info
(
'When skip quant {}, the metric is {}, the diff is {}'
.
format
(
weight_name
,
round
(
metric
,
4
),
round
(
metric
-
self
.
qat_metric
,
4
)))
'When skip quant %s, the cosine similarity is %.4f, the sensitive metric is %.4f'
%
(
weight_name
,
metric
,
1
-
metric
))
self
.
save_checkpoint
()
executor
.
close
()
...
...
@@ -254,13 +303,13 @@ class AnalysisQAT(object):
reverse
=
True
)
_logger
.
info
(
'Finished computing the sensitivity of the model.'
)
for
name
in
self
.
sensitivity_ranklist
:
_logger
.
info
(
"
without quant layer name: {}, eval metric: {}"
.
format
(
name
,
self
.
nonquant_layer_metrics
[
name
]))
_logger
.
info
(
"
Without quant layer name: {}, sensitive metric: {}"
.
format
(
name
,
self
.
nonquant_layer_metrics
[
name
]))
analysis_file
=
os
.
path
.
join
(
self
.
save_dir
,
"analysis.txt"
)
with
open
(
analysis_file
,
"w"
)
as
analysis_ret_f
:
for
name
in
self
.
sensitivity_ranklist
:
analysis_ret_f
.
write
(
"
without layer name: {}, eval metric: {}
\n
"
.
format
(
name
,
self
.
nonquant_layer_metrics
[
name
]))
"
Without quant layer name: {}, sensitive metric: {}
\n
"
.
format
(
name
,
self
.
nonquant_layer_metrics
[
name
]))
_logger
.
info
(
'Analysis file is saved in {}'
.
format
(
analysis_file
))
tests/test_analysis_ptq.py
0 → 100644
浏览文件 @
8e1691b4
import
os
import
sys
import
unittest
sys
.
path
.
append
(
"../"
)
import
paddle
from
PIL
import
Image
from
paddle.vision.datasets
import
DatasetFolder
from
paddle.vision.transforms
import
transforms
from
paddleslim.quant.analysis_ptq
import
AnalysisPTQ
paddle
.
enable_static
()
class
ImageNetDataset
(
DatasetFolder
):
def
__init__
(
self
,
path
,
image_size
=
224
):
super
(
ImageNetDataset
,
self
).
__init__
(
path
)
normalize
=
transforms
.
Normalize
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.120
,
57.375
])
self
.
transform
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
CenterCrop
(
image_size
),
transforms
.
Transpose
(),
normalize
])
def
__getitem__
(
self
,
idx
):
img_path
,
_
=
self
.
samples
[
idx
]
return
self
.
transform
(
Image
.
open
(
img_path
).
convert
(
'RGB'
))
def
__len__
(
self
):
return
len
(
self
.
samples
)
class
AnalysisPTQDemo
(
unittest
.
TestCase
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
AnalysisPTQDemo
,
self
).
__init__
(
*
args
,
**
kwargs
)
if
not
os
.
path
.
exists
(
'MobileNetV1_infer'
):
os
.
system
(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os
.
system
(
'tar -xf MobileNetV1_infer.tar'
)
if
not
os
.
path
.
exists
(
'ILSVRC2012_data_demo'
):
os
.
system
(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os
.
system
(
'tar -xf ILSVRC2012_data_demo.tar.gz'
)
def
test_demo
(
self
):
train_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/train/"
)
image
=
paddle
.
static
.
data
(
name
=
'inputs'
,
shape
=
[
None
]
+
[
3
,
224
,
224
],
dtype
=
'float32'
)
train_loader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
feed_list
=
[
image
],
batch_size
=
8
,
return_list
=
False
)
analyzer
=
AnalysisPTQ
(
model_dir
=
"./MobileNetV1_infer"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
save_dir
=
"MobileNetV1_analysis"
,
ptq_config
=
{
'quantizable_op_type'
:
[
"conv2d"
,
"depthwise_conv2d"
],
'weight_quantize_type'
:
'abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'is_full_quantize'
:
False
,
'batch_size'
:
8
,
'batch_nums'
:
1
,
},
data_loader
=
train_loader
)
analyzer
.
statistical_analyse
()
analyzer
.
metric_error_analyse
()
os
.
system
(
'rm -rf MobileNetV1_analysis'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/test_analysis_ptq_eval_func.py
0 → 100644
浏览文件 @
8e1691b4
import
os
import
sys
import
unittest
import
numpy
as
np
sys
.
path
.
append
(
"../"
)
import
paddle
from
PIL
import
Image
from
paddle.vision.datasets
import
DatasetFolder
from
paddle.vision.transforms
import
transforms
from
paddleslim.quant.analysis_ptq
import
AnalysisPTQ
paddle
.
enable_static
()
class
ImageNetDataset
(
DatasetFolder
):
def
__init__
(
self
,
data_dir
,
image_size
=
224
,
mode
=
'train'
):
super
(
ImageNetDataset
,
self
).
__init__
(
data_dir
)
self
.
data_dir
=
data_dir
normalize
=
transforms
.
Normalize
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.120
,
57.375
])
self
.
transform
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
CenterCrop
(
image_size
),
transforms
.
Transpose
(),
normalize
])
self
.
mode
=
mode
train_file_list
=
os
.
path
.
join
(
data_dir
,
'train_list.txt'
)
val_file_list
=
os
.
path
.
join
(
data_dir
,
'val_list.txt'
)
self
.
mode
=
mode
if
mode
==
'train'
:
with
open
(
train_file_list
)
as
flist
:
full_lines
=
[
line
.
strip
()
for
line
in
flist
]
np
.
random
.
shuffle
(
full_lines
)
lines
=
full_lines
self
.
samples
=
[
line
.
split
()
for
line
in
lines
]
else
:
with
open
(
val_file_list
)
as
flist
:
lines
=
[
line
.
strip
()
for
line
in
flist
]
self
.
samples
=
[
line
.
split
()
for
line
in
lines
]
def
__getitem__
(
self
,
idx
):
img_path
,
label
=
self
.
samples
[
idx
]
if
self
.
mode
==
'train'
:
return
self
.
transform
(
Image
.
open
(
os
.
path
.
join
(
self
.
data_dir
,
img_path
)).
convert
(
'RGB'
))
else
:
return
self
.
transform
(
Image
.
open
(
os
.
path
.
join
(
self
.
data_dir
,
img_path
)).
convert
(
'RGB'
)),
np
.
array
([
label
]).
astype
(
'int64'
)
def
__len__
(
self
):
return
len
(
self
.
samples
)
class
AnalysisPTQEvalFunction
(
unittest
.
TestCase
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
AnalysisPTQEvalFunction
,
self
).
__init__
(
*
args
,
**
kwargs
)
if
not
os
.
path
.
exists
(
'MobileNetV1_infer'
):
os
.
system
(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os
.
system
(
'tar -xf MobileNetV1_infer.tar'
)
if
not
os
.
path
.
exists
(
'ILSVRC2012_data_demo'
):
os
.
system
(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os
.
system
(
'tar -xf ILSVRC2012_data_demo.tar.gz'
)
def
test_demo
(
self
):
train_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/"
)
image
=
paddle
.
static
.
data
(
name
=
'inputs'
,
shape
=
[
None
]
+
[
3
,
224
,
224
],
dtype
=
'float32'
)
label
=
paddle
.
static
.
data
(
name
=
'labels'
,
shape
=
[
None
]
+
[
1
],
dtype
=
'float32'
)
train_loader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
feed_list
=
[
image
],
batch_size
=
8
,
return_list
=
False
)
def
reader_wrapper
(
reader
,
input_name
):
def
gen
():
for
i
,
(
imgs
,
label
)
in
enumerate
(
reader
()):
yield
{
input_name
:
imgs
}
return
gen
def
eval_reader
(
data_dir
,
batch_size
,
crop_size
,
resize_size
,
place
=
None
):
val_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/"
,
mode
=
'val'
)
val_loader
=
paddle
.
io
.
DataLoader
(
val_dataset
,
feed_list
=
[
image
,
label
],
batch_size
=
batch_size
,
shuffle
=
False
,
drop_last
=
False
,
num_workers
=
0
,
return_list
=
False
)
return
val_loader
def
eval_function
(
exe
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
val_loader
=
eval_reader
(
'./ILSVRC2012_data_demo/ILSVRC2012/'
,
batch_size
=
32
,
crop_size
=
224
,
resize_size
=
256
)
results
=
[]
print
(
'Evaluating...'
)
for
batch_id
,
data
in
enumerate
(
val_loader
):
image
=
data
[
0
][
'inputs'
]
label
=
data
[
0
][
'labels'
]
# top1_acc, top5_acc
if
len
(
test_feed_names
)
==
1
:
image
=
np
.
array
(
image
)
label
=
np
.
array
(
label
).
astype
(
'int64'
)
pred
=
exe
.
run
(
compiled_test_program
,
feed
=
{
test_feed_names
[
0
]:
image
},
fetch_list
=
test_fetch_list
)
pred
=
np
.
array
(
pred
[
0
])
label
=
np
.
array
(
label
)
sort_array
=
pred
.
argsort
(
axis
=
1
)
top_1_pred
=
sort_array
[:,
-
1
:][:,
::
-
1
]
top_1
=
np
.
mean
(
label
==
top_1_pred
)
top_5_pred
=
sort_array
[:,
-
5
:][:,
::
-
1
]
acc_num
=
0
for
i
in
range
(
len
(
label
)):
if
label
[
i
][
0
]
in
top_5_pred
[
i
]:
acc_num
+=
1
top_5
=
float
(
acc_num
)
/
len
(
label
)
results
.
append
([
top_1
,
top_5
])
else
:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image
=
np
.
array
(
image
)
label
=
np
.
array
(
label
).
astype
(
'int64'
)
result
=
exe
.
run
(
compiled_test_program
,
feed
=
{
test_feed_names
[
0
]:
image
,
test_feed_names
[
1
]:
label
},
fetch_list
=
test_fetch_list
)
result
=
[
np
.
mean
(
r
)
for
r
in
result
]
results
.
append
(
result
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter: '
,
batch_id
)
result
=
np
.
mean
(
np
.
array
(
results
),
axis
=
0
)
return
result
[
0
]
analyzer
=
AnalysisPTQ
(
model_dir
=
"./MobileNetV1_infer"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
save_dir
=
"MobileNetV1_analysis"
,
ptq_config
=
{
'quantizable_op_type'
:
[
"conv2d"
,
"depthwise_conv2d"
],
'weight_quantize_type'
:
'abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'is_full_quantize'
:
False
,
'batch_size'
:
8
,
'batch_nums'
:
10
,
},
data_loader
=
train_loader
,
eval_function
=
eval_function
)
analyzer
.
metric_error_analyse
()
analyzer
.
get_target_quant_model
(
69.5
)
os
.
system
(
'rm -rf MobileNetV1_analysis'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/test_analysis_qat.py
0 → 100644
浏览文件 @
8e1691b4
import
os
import
sys
import
unittest
sys
.
path
.
append
(
"../"
)
import
paddle
from
PIL
import
Image
from
paddle.vision.datasets
import
DatasetFolder
from
paddle.vision.transforms
import
transforms
from
paddle.fluid.contrib.slim.quantization
import
PostTrainingQuantization
from
paddleslim.quant.analysis_qat
import
AnalysisQAT
paddle
.
enable_static
()
class
ImageNetDataset
(
DatasetFolder
):
def
__init__
(
self
,
path
,
image_size
=
224
):
super
(
ImageNetDataset
,
self
).
__init__
(
path
)
normalize
=
transforms
.
Normalize
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.120
,
57.375
])
self
.
transform
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
CenterCrop
(
image_size
),
transforms
.
Transpose
(),
normalize
])
def
__getitem__
(
self
,
idx
):
img_path
,
_
=
self
.
samples
[
idx
]
return
self
.
transform
(
Image
.
open
(
img_path
).
convert
(
'RGB'
))
def
__len__
(
self
):
return
len
(
self
.
samples
)
class
AnalysisQATDemo
(
unittest
.
TestCase
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
AnalysisQATDemo
,
self
).
__init__
(
*
args
,
**
kwargs
)
if
not
os
.
path
.
exists
(
'MobileNetV1_infer'
):
os
.
system
(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os
.
system
(
'tar -xf MobileNetV1_infer.tar'
)
if
not
os
.
path
.
exists
(
'ILSVRC2012_data_demo'
):
os
.
system
(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os
.
system
(
'tar -xf ILSVRC2012_data_demo.tar.gz'
)
def
test_demo
(
self
):
train_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/train/"
)
image
=
paddle
.
static
.
data
(
name
=
'inputs'
,
shape
=
[
None
]
+
[
3
,
224
,
224
],
dtype
=
'float32'
)
train_loader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
feed_list
=
[
image
],
batch_size
=
8
,
return_list
=
False
)
place
=
paddle
.
CUDAPlace
(
0
)
if
paddle
.
is_compiled_with_cuda
(
)
else
paddle
.
CPUPlace
()
executor
=
paddle
.
static
.
Executor
(
place
)
ptq_config
=
{
'quantizable_op_type'
:
[
"conv2d"
,
"depthwise_conv2d"
],
'weight_quantize_type'
:
'abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'is_full_quantize'
:
False
,
'batch_size'
:
8
,
'batch_nums'
:
10
,
}
post_training_quantization
=
PostTrainingQuantization
(
executor
=
executor
,
data_loader
=
train_loader
,
model_dir
=
"./MobileNetV1_infer"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
onnx_format
=
True
,
algo
=
'avg'
,
**
ptq_config
)
post_training_quantization
.
quantize
()
post_training_quantization
.
save_quantized_model
(
"./MobileNetV1_quant"
,
model_filename
=
'inference.pdmodel'
,
params_filename
=
'inference.pdiparams'
)
analyzer
=
AnalysisQAT
(
float_model_dir
=
"./MobileNetV1_infer"
,
quant_model_dir
=
"./MobileNetV1_quant"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
save_dir
=
"analysis_result"
,
data_loader
=
train_loader
)
analyzer
.
metric_error_analyse
()
os
.
system
(
'rm -rf analysis_result'
)
os
.
system
(
'rm -rf MobileNetV1_quant'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/test_analysis_qat_eval_func.py
0 → 100644
浏览文件 @
8e1691b4
import
os
import
sys
import
unittest
import
numpy
as
np
sys
.
path
.
append
(
"../"
)
import
paddle
from
PIL
import
Image
from
paddle.vision.datasets
import
DatasetFolder
from
paddle.vision.transforms
import
transforms
from
paddleslim.quant.analysis_qat
import
AnalysisQAT
from
paddle.fluid.contrib.slim.quantization
import
PostTrainingQuantization
paddle
.
enable_static
()
class
ImageNetDataset
(
DatasetFolder
):
def
__init__
(
self
,
data_dir
,
image_size
=
224
,
mode
=
'train'
):
super
(
ImageNetDataset
,
self
).
__init__
(
data_dir
)
self
.
data_dir
=
data_dir
normalize
=
transforms
.
Normalize
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.120
,
57.375
])
self
.
transform
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
CenterCrop
(
image_size
),
transforms
.
Transpose
(),
normalize
])
self
.
mode
=
mode
train_file_list
=
os
.
path
.
join
(
data_dir
,
'train_list.txt'
)
val_file_list
=
os
.
path
.
join
(
data_dir
,
'val_list.txt'
)
self
.
mode
=
mode
if
mode
==
'train'
:
with
open
(
train_file_list
)
as
flist
:
full_lines
=
[
line
.
strip
()
for
line
in
flist
]
np
.
random
.
shuffle
(
full_lines
)
lines
=
full_lines
self
.
samples
=
[
line
.
split
()
for
line
in
lines
]
else
:
with
open
(
val_file_list
)
as
flist
:
lines
=
[
line
.
strip
()
for
line
in
flist
]
self
.
samples
=
[
line
.
split
()
for
line
in
lines
]
def
__getitem__
(
self
,
idx
):
img_path
,
label
=
self
.
samples
[
idx
]
if
self
.
mode
==
'train'
:
return
self
.
transform
(
Image
.
open
(
os
.
path
.
join
(
self
.
data_dir
,
img_path
)).
convert
(
'RGB'
))
else
:
return
self
.
transform
(
Image
.
open
(
os
.
path
.
join
(
self
.
data_dir
,
img_path
)).
convert
(
'RGB'
)),
np
.
array
([
label
]).
astype
(
'int64'
)
def
__len__
(
self
):
return
len
(
self
.
samples
)
class
AnalysisQATEvalFunction
(
unittest
.
TestCase
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
AnalysisQATEvalFunction
,
self
).
__init__
(
*
args
,
**
kwargs
)
if
not
os
.
path
.
exists
(
'MobileNetV1_infer'
):
os
.
system
(
'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os
.
system
(
'tar -xf MobileNetV1_infer.tar'
)
if
not
os
.
path
.
exists
(
'ILSVRC2012_data_demo'
):
os
.
system
(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os
.
system
(
'tar -xf ILSVRC2012_data_demo.tar.gz'
)
def
test_demo
(
self
):
train_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/"
)
image
=
paddle
.
static
.
data
(
name
=
'inputs'
,
shape
=
[
None
]
+
[
3
,
224
,
224
],
dtype
=
'float32'
)
label
=
paddle
.
static
.
data
(
name
=
'labels'
,
shape
=
[
None
]
+
[
1
],
dtype
=
'float32'
)
train_loader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
feed_list
=
[
image
],
batch_size
=
8
,
return_list
=
False
)
def
reader_wrapper
(
reader
,
input_name
):
def
gen
():
for
i
,
(
imgs
,
label
)
in
enumerate
(
reader
()):
yield
{
input_name
:
imgs
}
return
gen
def
eval_reader
(
data_dir
,
batch_size
,
crop_size
,
resize_size
,
place
=
None
):
val_dataset
=
ImageNetDataset
(
"./ILSVRC2012_data_demo/ILSVRC2012/"
,
mode
=
'val'
)
val_loader
=
paddle
.
io
.
DataLoader
(
val_dataset
,
feed_list
=
[
image
,
label
],
batch_size
=
batch_size
,
shuffle
=
False
,
drop_last
=
False
,
num_workers
=
0
,
return_list
=
False
)
return
val_loader
def
eval_function
(
exe
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
val_loader
=
eval_reader
(
'./ILSVRC2012_data_demo/ILSVRC2012/'
,
batch_size
=
32
,
crop_size
=
224
,
resize_size
=
256
)
results
=
[]
print
(
'Evaluating...'
)
for
batch_id
,
data
in
enumerate
(
val_loader
):
image
=
data
[
0
][
'inputs'
]
label
=
data
[
0
][
'labels'
]
# top1_acc, top5_acc
if
len
(
test_feed_names
)
==
1
:
image
=
np
.
array
(
image
)
label
=
np
.
array
(
label
).
astype
(
'int64'
)
pred
=
exe
.
run
(
compiled_test_program
,
feed
=
{
test_feed_names
[
0
]:
image
},
fetch_list
=
test_fetch_list
)
pred
=
np
.
array
(
pred
[
0
])
label
=
np
.
array
(
label
)
sort_array
=
pred
.
argsort
(
axis
=
1
)
top_1_pred
=
sort_array
[:,
-
1
:][:,
::
-
1
]
top_1
=
np
.
mean
(
label
==
top_1_pred
)
top_5_pred
=
sort_array
[:,
-
5
:][:,
::
-
1
]
acc_num
=
0
for
i
in
range
(
len
(
label
)):
if
label
[
i
][
0
]
in
top_5_pred
[
i
]:
acc_num
+=
1
top_5
=
float
(
acc_num
)
/
len
(
label
)
results
.
append
([
top_1
,
top_5
])
else
:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image
=
np
.
array
(
image
)
label
=
np
.
array
(
label
).
astype
(
'int64'
)
result
=
exe
.
run
(
compiled_test_program
,
feed
=
{
test_feed_names
[
0
]:
image
,
test_feed_names
[
1
]:
label
},
fetch_list
=
test_fetch_list
)
result
=
[
np
.
mean
(
r
)
for
r
in
result
]
results
.
append
(
result
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter: '
,
batch_id
)
result
=
np
.
mean
(
np
.
array
(
results
),
axis
=
0
)
return
result
[
0
]
place
=
paddle
.
CUDAPlace
(
0
)
if
paddle
.
is_compiled_with_cuda
(
)
else
paddle
.
CPUPlace
()
executor
=
paddle
.
static
.
Executor
(
place
)
ptq_config
=
{
'quantizable_op_type'
:
[
"conv2d"
,
"depthwise_conv2d"
],
'weight_quantize_type'
:
'abs_max'
,
'activation_quantize_type'
:
'moving_average_abs_max'
,
'is_full_quantize'
:
False
,
'batch_size'
:
8
,
'batch_nums'
:
10
,
}
post_training_quantization
=
PostTrainingQuantization
(
executor
=
executor
,
data_loader
=
train_loader
,
model_dir
=
"./MobileNetV1_infer"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
onnx_format
=
True
,
algo
=
'avg'
,
**
ptq_config
)
post_training_quantization
.
quantize
()
post_training_quantization
.
save_quantized_model
(
"./MobileNetV1_QAT"
,
model_filename
=
'inference.pdmodel'
,
params_filename
=
'inference.pdiparams'
)
analyzer
=
AnalysisQAT
(
float_model_dir
=
"./MobileNetV1_infer"
,
quant_model_dir
=
"./MobileNetV1_QAT"
,
model_filename
=
"inference.pdmodel"
,
params_filename
=
"inference.pdiparams"
,
save_dir
=
"MobileNetV1_analysis"
,
data_loader
=
train_loader
,
eval_function
=
eval_function
)
analyzer
.
metric_error_analyse
()
os
.
system
(
'rm -rf MobileNetV1_analysis'
)
os
.
system
(
'rm -rf MobileNetV1_QAT'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录