Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0af1a87b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0af1a87b
编写于
1月 05, 2022
作者:
J
Jiaqi Liu
提交者:
GitHub
1月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make post training quant API support dataloader (#38686)
* make post training quant API support dataloader
上级
60c51de5
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
52 addition
and
8 deletion
+52
-8
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
...d/contrib/slim/quantization/post_training_quantization.py
+14
-3
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_while.py
...ntrib/slim/tests/test_post_training_quantization_while.py
+38
-5
未找到文件。
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
浏览文件 @
0af1a87b
...
...
@@ -17,6 +17,7 @@ import re
import
logging
import
numpy
as
np
import
shutil
from
inspect
import
isgeneratorfunction
from
....
import
io
from
....
import
core
from
....
import
framework
...
...
@@ -136,6 +137,7 @@ class PostTrainingQuantization(object):
params_filename
=
None
,
batch_generator
=
None
,
sample_generator
=
None
,
data_loader
=
None
,
batch_size
=
10
,
batch_nums
=
None
,
algo
=
"KL"
,
...
...
@@ -175,6 +177,9 @@ class PostTrainingQuantization(object):
calibrate data for DataLoader, and it only returns a sample every
time. Note that, sample_generator and batch_generator, only one
should be set. Beisdes, sample_generator dose not support lod tensor.
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
Generator or Dataloader provides calibrate data, and it could
return a batch every time.
batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use
...
...
@@ -279,8 +284,11 @@ class PostTrainingQuantization(object):
assert
executor
is
not
None
,
"The executor cannot be None."
assert
model_dir
is
not
None
,
"The model_dir cannot be None."
assert
any
([
gen
is
not
None
]
for
gen
in
[
sample_generator
,
batch_generator
]),
"The sample_generator and batch_generator "
\
"cannot be None in the same time."
batch_generator
,
data_loader
]),
"The sample_generator, batch_generator "
\
"and data_loader cannot be None in the same time."
if
data_loader
is
not
None
:
assert
isinstance
(
data_loader
,
(
io
.
DataLoader
,
type
(
isgeneratorfunction
))),
\
"data_loader only accepts `paddle.io.DataLoader` or Generator instance."
assert
batch_size
>
0
,
"The batch_size should be greater than 0."
assert
algo
in
self
.
_support_algo_type
,
\
"The algo should be KL, hist, mse, avg, abs_max or min_max."
...
...
@@ -323,7 +331,7 @@ class PostTrainingQuantization(object):
self
.
_program
=
None
self
.
_feed_list
=
None
self
.
_fetch_list
=
None
self
.
_data_loader
=
None
self
.
_data_loader
=
data_loader
self
.
_out_scale_op_list
=
_out_scale_op_list
self
.
_quantized_weight_var_name
=
set
()
...
...
@@ -473,6 +481,9 @@ class PostTrainingQuantization(object):
feed_vars
=
[
framework
.
_get_var
(
str
(
var_name
),
self
.
_program
)
\
for
var_name
in
self
.
_feed_list
]
if
self
.
_data_loader
is
not
None
:
return
self
.
_data_loader
=
io
.
DataLoader
.
from_generator
(
feed_list
=
feed_vars
,
capacity
=
3
*
self
.
_batch_size
,
iterable
=
True
)
if
self
.
_sample_generator
is
not
None
:
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_while.py
浏览文件 @
0af1a87b
...
...
@@ -115,19 +115,30 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_use_cache_file
=
False
,
is_optimize_model
=
False
,
batch_size
=
10
,
batch_nums
=
10
):
batch_nums
=
10
,
is_data_loader
=
False
):
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
scope
=
fluid
.
global_scope
()
val_reader
=
paddle
.
dataset
.
mnist
.
train
()
def
val_data_generator
():
batches
=
[]
for
data
in
val_reader
():
batches
.
append
(
data
[
0
].
reshape
(
1
,
28
,
28
))
if
len
(
batches
)
==
batch_size
:
batches
=
np
.
asarray
(
batches
)
yield
{
"x"
:
batches
}
batches
=
[]
ptq
=
PostTrainingQuantization
(
executor
=
exe
,
model_dir
=
model_path
,
model_filename
=
'model.pdmodel'
,
params_filename
=
'model.pdiparams'
,
sample_generator
=
val_reader
,
sample_generator
=
val_reader
if
not
is_data_loader
else
None
,
data_loader
=
val_data_generator
if
is_data_loader
else
None
,
batch_size
=
batch_size
,
batch_nums
=
batch_nums
,
algo
=
algo
,
...
...
@@ -153,7 +164,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
diff_threshold
,
batch_size
=
10
,
infer_iterations
=
10
,
quant_iterations
=
5
):
quant_iterations
=
5
,
is_data_loader
=
False
):
origin_model_path
=
self
.
download_model
(
data_url
,
data_md5
,
model_name
)
#origin_model_path = os.path.join(origin_model_path, model_name)
...
...
@@ -166,8 +178,15 @@ class TestPostTrainingQuantization(unittest.TestCase):
print
(
"Start INT8 post training quantization for {0} on {1} images ..."
.
format
(
model_name
,
quant_iterations
*
batch_size
))
self
.
generate_quantized_model
(
origin_model_path
,
algo
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
batch_size
,
quant_iterations
)
origin_model_path
,
algo
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
batch_size
,
quant_iterations
,
is_data_loader
=
is_data_loader
)
print
(
"Start INT8 inference for {0} on {1} images ..."
.
format
(
model_name
,
infer_iterations
*
batch_size
))
...
...
@@ -307,6 +326,20 @@ class TestPostTrainingAbsMaxForWhile(TestPostTrainingQuantization):
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
,
is_data_loader
=
True
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录