Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
43f3d0cc
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
43f3d0cc
编写于
7月 27, 2020
作者:
W
Wojciech Uss
提交者:
GitHub
7月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add an option to choose inference targets in Quant tests (#25582)
test=develop
上级
b158a21b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
113 addition
and
67 deletion
+113
-67
python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py
...slim/tests/quant2_int8_image_classification_comparison.py
+57
-35
python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py
...le/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py
+56
-32
未找到文件。
python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py
浏览文件 @
43f3d0cc
...
...
@@ -68,6 +68,12 @@ def parse_args():
type
=
str
,
default
=
''
,
help
=
'A comma separated list of operator ids to skip in quantization.'
)
parser
.
add_argument
(
'--targets'
,
type
=
str
,
default
=
'quant,int8,fp32'
,
help
=
'A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"'
)
parser
.
add_argument
(
'--debug'
,
action
=
'store_true'
,
...
...
@@ -310,6 +316,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
assert
int8_acc1
>
0.5
assert
quant_acc1
-
int8_acc1
<=
threshold
def
_strings_from_csv
(
self
,
string
):
return
set
(
s
.
strip
()
for
s
in
string
.
split
(
','
))
def
_ints_from_csv
(
self
,
string
):
return
set
(
map
(
int
,
string
.
split
(
','
)))
def
test_graph_transformation
(
self
):
if
not
fluid
.
core
.
is_compiled_with_mkldnn
():
return
...
...
@@ -326,14 +338,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
self
.
_debug
=
test_case_args
.
debug
self
.
_quantized_ops
=
set
()
if
len
(
test_case_args
.
ops_to_quantize
)
>
0
:
self
.
_quantized_ops
=
se
t
(
op
.
strip
()
for
op
in
test_case_args
.
ops_to_quantize
.
split
(
','
)
)
if
test_case_args
.
ops_to_quantize
:
self
.
_quantized_ops
=
se
lf
.
_strings_from_csv
(
test_case_args
.
ops_to_quantize
)
self
.
_op_ids_to_skip
=
set
([
-
1
])
if
len
(
test_case_args
.
op_ids_to_skip
)
>
0
:
self
.
_op_ids_to_skip
=
set
(
map
(
int
,
test_case_args
.
op_ids_to_skip
.
split
(
','
)))
if
test_case_args
.
op_ids_to_skip
:
self
.
_op_ids_to_skip
=
self
.
_ints_from_csv
(
test_case_args
.
op_ids_to_skip
)
self
.
_targets
=
self
.
_strings_from_csv
(
test_case_args
.
targets
)
assert
self
.
_targets
.
intersection
(
{
'quant'
,
'int8'
,
'fp32'
}
),
'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".'
_logger
.
info
(
'Quant & INT8 prediction run.'
)
_logger
.
info
(
'Quant model: {}'
.
format
(
quant_model_path
))
...
...
@@ -348,35 +365,38 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
_logger
.
info
(
'Op ids to skip quantization: {}.'
.
format
(
','
.
join
(
map
(
str
,
self
.
_op_ids_to_skip
))
if
test_case_args
.
op_ids_to_skip
else
'none'
))
_logger
.
info
(
'Targets: {}.'
.
format
(
','
.
join
(
self
.
_targets
)))
_logger
.
info
(
'--- Quant prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
quant_output
,
quant_acc1
,
quant_acc5
,
quant_fps
,
quant_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'quant'
)
self
.
_print_performance
(
'Quant'
,
quant_fps
,
quant_lat
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc1
,
quant_acc5
)
if
'quant'
in
self
.
_targets
:
_logger
.
info
(
'--- Quant prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
quant_output
,
quant_acc1
,
quant_acc5
,
quant_fps
,
quant_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'quant'
)
self
.
_print_performance
(
'Quant'
,
quant_fps
,
quant_lat
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc1
,
quant_acc5
)
_logger
.
info
(
'--- INT8 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
int8_output
,
int8_acc1
,
int8_acc5
,
int8_fps
,
int8_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'int8'
)
self
.
_print_performance
(
'INT8'
,
int8_fps
,
int8_lat
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc1
,
int8_acc5
)
if
'int8'
in
self
.
_targets
:
_logger
.
info
(
'--- INT8 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
int8_output
,
int8_acc1
,
int8_acc5
,
int8_fps
,
int8_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'int8'
)
self
.
_print_performance
(
'INT8'
,
int8_fps
,
int8_lat
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc1
,
int8_acc5
)
fp32_acc1
=
fp32_acc5
=
fp32_fps
=
fp32_lat
=
-
1
if
fp32_model_path
:
if
'fp32'
in
self
.
_targets
and
fp32_model_path
:
_logger
.
info
(
'--- FP32 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
...
...
@@ -390,10 +410,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
self
.
_print_performance
(
'FP32'
,
fp32_fps
,
fp32_lat
)
self
.
_print_accuracy
(
'FP32'
,
fp32_acc1
,
fp32_acc5
)
self
.
_summarize_performance
(
int8_fps
,
int8_lat
,
fp32_fps
,
fp32_lat
)
self
.
_summarize_accuracy
(
quant_acc1
,
quant_acc5
,
int8_acc1
,
int8_acc5
,
fp32_acc1
,
fp32_acc5
)
self
.
_compare_accuracy
(
acc_diff_threshold
,
quant_acc1
,
int8_acc1
)
if
{
'int8'
,
'fp32'
}.
issubset
(
self
.
_targets
):
self
.
_summarize_performance
(
int8_fps
,
int8_lat
,
fp32_fps
,
fp32_lat
)
if
{
'int8'
,
'quant'
}.
issubset
(
self
.
_targets
):
self
.
_summarize_accuracy
(
quant_acc1
,
quant_acc5
,
int8_acc1
,
int8_acc5
,
fp32_acc1
,
fp32_acc5
)
self
.
_compare_accuracy
(
acc_diff_threshold
,
quant_acc1
,
int8_acc1
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py
浏览文件 @
43f3d0cc
...
...
@@ -72,6 +72,12 @@ def parse_args():
type
=
str
,
default
=
''
,
help
=
'A comma separated list of operator ids to skip in quantization.'
)
parser
.
add_argument
(
'--targets'
,
type
=
str
,
default
=
'quant,int8,fp32'
,
help
=
'A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"'
)
parser
.
add_argument
(
'--debug'
,
action
=
'store_true'
,
...
...
@@ -256,6 +262,12 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
assert
int8_acc
>
0.5
assert
quant_acc
-
int8_acc
<=
threshold
def
_strings_from_csv
(
self
,
string
):
return
set
(
s
.
strip
()
for
s
in
string
.
split
(
','
))
def
_ints_from_csv
(
self
,
string
):
return
set
(
map
(
int
,
string
.
split
(
','
)))
def
test_graph_transformation
(
self
):
if
not
fluid
.
core
.
is_compiled_with_mkldnn
():
return
...
...
@@ -274,13 +286,18 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
self
.
_quantized_ops
=
set
()
if
test_case_args
.
ops_to_quantize
:
self
.
_quantized_ops
=
se
t
(
op
.
strip
()
for
op
in
test_case_args
.
ops_to_quantize
.
split
(
','
)
)
self
.
_quantized_ops
=
se
lf
.
_strings_from_csv
(
test_case_args
.
ops_to_quantize
)
self
.
_op_ids_to_skip
=
set
([
-
1
])
if
test_case_args
.
op_ids_to_skip
:
self
.
_op_ids_to_skip
=
set
(
map
(
int
,
test_case_args
.
op_ids_to_skip
.
split
(
','
)))
self
.
_op_ids_to_skip
=
self
.
_ints_from_csv
(
test_case_args
.
op_ids_to_skip
)
self
.
_targets
=
self
.
_strings_from_csv
(
test_case_args
.
targets
)
assert
self
.
_targets
.
intersection
(
{
'quant'
,
'int8'
,
'fp32'
}
),
'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".'
_logger
.
info
(
'Quant & INT8 prediction run.'
)
_logger
.
info
(
'Quant model: {}'
.
format
(
quant_model_path
))
...
...
@@ -296,35 +313,40 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
_logger
.
info
(
'Op ids to skip quantization: {}.'
.
format
(
','
.
join
(
map
(
str
,
self
.
_op_ids_to_skip
))
if
test_case_args
.
op_ids_to_skip
else
'none'
))
_logger
.
info
(
'Targets: {}.'
.
format
(
','
.
join
(
self
.
_targets
)))
_logger
.
info
(
'--- Quant prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
quant_acc
,
quant_pps
,
quant_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'quant'
)
self
.
_print_performance
(
'Quant'
,
quant_pps
,
quant_lat
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc
)
if
'quant'
in
self
.
_targets
:
_logger
.
info
(
'--- Quant prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
quant_acc
,
quant_pps
,
quant_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'quant'
)
self
.
_print_performance
(
'Quant'
,
quant_pps
,
quant_lat
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc
)
_logger
.
info
(
'--- INT8 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
int8_acc
,
int8_pps
,
int8_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'int8'
)
self
.
_print_performance
(
'INT8'
,
int8_pps
,
int8_lat
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc
)
if
'int8'
in
self
.
_targets
:
_logger
.
info
(
'--- INT8 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
int8_acc
,
int8_pps
,
int8_lat
=
self
.
_predict
(
val_reader
,
quant_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'int8'
)
self
.
_print_performance
(
'INT8'
,
int8_pps
,
int8_lat
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc
)
fp32_acc
=
fp32_pps
=
fp32_lat
=
-
1
if
fp32_model_path
:
if
'fp32'
in
self
.
_targets
and
fp32_model_path
:
_logger
.
info
(
'--- FP32 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
...
...
@@ -339,9 +361,11 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
self
.
_print_performance
(
'FP32'
,
fp32_pps
,
fp32_lat
)
self
.
_print_accuracy
(
'FP32'
,
fp32_acc
)
self
.
_summarize_performance
(
int8_pps
,
int8_lat
,
fp32_pps
,
fp32_lat
)
self
.
_summarize_accuracy
(
quant_acc
,
int8_acc
,
fp32_acc
)
self
.
_compare_accuracy
(
acc_diff_threshold
,
quant_acc
,
int8_acc
)
if
{
'int8'
,
'fp32'
}.
issubset
(
self
.
_targets
):
self
.
_summarize_performance
(
int8_pps
,
int8_lat
,
fp32_pps
,
fp32_lat
)
if
{
'int8'
,
'quant'
}.
issubset
(
self
.
_targets
):
self
.
_summarize_accuracy
(
quant_acc
,
int8_acc
,
fp32_acc
)
self
.
_compare_accuracy
(
acc_diff_threshold
,
quant_acc
,
int8_acc
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录