Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2a82c565
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2a82c565
编写于
1月 25, 2019
作者:
H
Haihao Shen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine INT8 calibration API; shorten the iteration number to reduce test time; test=develop
上级
e043ea96
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
42 addition
and
27 deletion
+42
-27
python/paddle/fluid/contrib/int8_inference/utility.py
python/paddle/fluid/contrib/int8_inference/utility.py
+29
-5
python/paddle/fluid/contrib/tests/test_calibration.py
python/paddle/fluid/contrib/tests/test_calibration.py
+13
-22
未找到文件。
python/paddle/fluid/contrib/int8_inference/utility.py
浏览文件 @
2a82c565
...
...
@@ -32,10 +32,13 @@ class Calibrator(object):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
program
=
kwargs
[
'program'
]
self
.
iterations
=
kwargs
[
'iterations'
]
self
.
pretrained_model
=
kwargs
[
'pretrained_model'
]
self
.
debug
=
kwargs
[
'debug'
]
self
.
debug
=
kwargs
[
'debug'
]
if
'debug'
in
kwargs
else
False
self
.
algo
=
kwargs
[
'algo'
]
self
.
output
=
kwargs
[
'output'
]
self
.
feed_var_names
=
kwargs
[
'feed_var_names'
]
self
.
fetch_list
=
kwargs
[
'fetch_list'
]
self
.
exe
=
kwargs
[
'exe'
]
self
.
_conv_input_var_name
=
[]
self
.
_conv_output_var_name
=
[]
...
...
@@ -54,17 +57,38 @@ class Calibrator(object):
self
.
_u8_output_var
=
[]
self
.
_s8_output_var
=
[]
self
.
_persistable_vars
=
[]
self
.
_sampling_data
=
{}
def
generate_sampling_program
(
self
):
self
.
__init_analysis
()
self
.
__generate_output_program
()
def
generate_quantized_data
(
self
,
sampling_data
):
self
.
__sampling
(
sampling_data
)
def
save_int8_model
(
self
):
self
.
__sampling
(
s
elf
.
_s
ampling_data
)
self
.
__save_scale
()
self
.
__update_program
()
self
.
__update_output_program_attr
()
self
.
__display_debug
()
self
.
__save_offline_model
()
def
sample_data
(
self
):
'''
Sampling the tensor data of variable.
'''
for
i
in
self
.
sampling_program
.
list_vars
():
if
i
.
name
in
self
.
sampling_vars
:
np_data
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
i
.
name
)
.
get_tensor
())
if
i
.
name
not
in
self
.
_sampling_data
:
self
.
_sampling_data
[
i
.
name
]
=
[]
self
.
_sampling_data
[
i
.
name
].
append
(
np_data
)
def
__save_offline_model
(
self
):
'''
Save the quantized model to the disk.
'''
fluid
.
io
.
save_inference_model
(
self
.
output
,
self
.
feed_var_names
,
self
.
fetch_list
,
self
.
exe
,
self
.
sampling_program
)
def
__display_debug
(
self
):
if
self
.
debug
:
...
...
python/paddle/fluid/contrib/tests/test_calibration.py
浏览文件 @
2a82c565
...
...
@@ -26,7 +26,7 @@ import paddle.fluid.profiler as profiler
from
PIL
import
Image
,
ImageEnhance
import
math
sys
.
path
.
append
(
'..'
)
import
int8_inference.utility
as
ut
import
int8_inference.utility
as
int8_utility
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
...
...
@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase):
def
setUp
(
self
):
# TODO(guomingz): Put the download process in the cmake.
# Download and unzip test data set
imagenet_dl_url
=
'http://paddle-inference-dist.
bj
.bcebos.com/int8/calibration_test_data.tar.gz'
imagenet_dl_url
=
'http://paddle-inference-dist.
cdn
.bcebos.com/int8/calibration_test_data.tar.gz'
zip_file_name
=
imagenet_dl_url
.
split
(
'/'
)[
-
1
]
cmd
=
'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'
.
format
(
zip_file_name
,
imagenet_dl_url
,
zip_file_name
)
os
.
system
(
cmd
)
# resnet50 fp32 data
resnet50_fp32_model_url
=
'http://paddle-inference-dist.
bj
.bcebos.com/int8/resnet50_int8_model.tar.gz'
resnet50_fp32_model_url
=
'http://paddle-inference-dist.
cdn
.bcebos.com/int8/resnet50_int8_model.tar.gz'
resnet50_zip_name
=
resnet50_fp32_model_url
.
split
(
'/'
)[
-
1
]
resnet50_unzip_folder_name
=
'resnet50_fp32'
cmd
=
'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'
.
format
(
...
...
@@ -135,8 +135,7 @@ class TestCalibration(unittest.TestCase):
resnet50_zip_name
,
resnet50_unzip_folder_name
)
os
.
system
(
cmd
)
self
.
iterations
=
100
self
.
skip_batch_num
=
5
self
.
iterations
=
50
def
run_program
(
self
,
model_path
,
generate_int8
=
False
,
algo
=
'direct'
):
image_shape
=
[
3
,
224
,
224
]
...
...
@@ -163,16 +162,15 @@ class TestCalibration(unittest.TestCase):
print
(
"Start calibration ..."
)
calibrator
=
ut
.
Calibrator
(
calibrator
=
int8_utility
.
Calibrator
(
program
=
infer_program
,
pretrained_model
=
model_path
,
iterations
=
100
,
debug
=
Fals
e
,
algo
=
algo
)
sampling_data
=
{}
algo
=
algo
,
exe
=
ex
e
,
output
=
int8_model
,
feed_var_names
=
feed_dict
,
fetch_list
=
fetch_targets
)
calibrator
.
generate_sampling_program
()
test_info
=
[]
cnt
=
0
for
batch_id
,
data
in
enumerate
(
val_reader
()):
...
...
@@ -192,13 +190,7 @@ class TestCalibration(unittest.TestCase):
feed_dict
[
1
]:
label
},
fetch_list
=
fetch_targets
)
if
generate_int8
:
for
i
in
calibrator
.
sampling_program
.
list_vars
():
if
i
.
name
in
calibrator
.
sampling_vars
:
np_data
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
i
.
name
)
.
get_tensor
())
if
i
.
name
not
in
sampling_data
:
sampling_data
[
i
.
name
]
=
[]
sampling_data
[
i
.
name
].
append
(
np_data
)
calibrator
.
sample_data
()
test_info
.
append
(
np
.
mean
(
acc1
)
*
len
(
data
))
cnt
+=
len
(
data
)
...
...
@@ -209,9 +201,8 @@ class TestCalibration(unittest.TestCase):
break
if
generate_int8
:
calibrator
.
generate_quantized_data
(
sampling_data
)
fluid
.
io
.
save_inference_model
(
int8_model
,
feed_dict
,
fetch_targets
,
exe
,
calibrator
.
sampling_program
)
calibrator
.
save_int8_model
()
print
(
"Calibration is done and the corresponding files were generated at {}"
.
format
(
os
.
path
.
abspath
(
"calibration_out"
)))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录