Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d0a921ba
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d0a921ba
编写于
7月 06, 2020
作者:
W
Wojciech Uss
提交者:
GitHub
7月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Quant2 updates and fixes (#25313)
上级
869d59cc
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
228 addition
and
149 deletion
+228
-149
paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
+11
-5
python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
...luid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+39
-20
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+18
-18
python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py
...slim/tests/quant2_int8_image_classification_comparison.py
+88
-54
python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py
...le/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py
+65
-32
python/paddle/fluid/contrib/slim/tests/save_quant_model.py
python/paddle/fluid/contrib/slim/tests/save_quant_model.py
+7
-20
未找到文件。
paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
浏览文件 @
d0a921ba
...
@@ -46,10 +46,8 @@ void LogCannotQuantizeOp(Node* op, const char* details = nullptr) {
...
@@ -46,10 +46,8 @@ void LogCannotQuantizeOp(Node* op, const char* details = nullptr) {
}
}
void
LogScaleIsMissingForVar
(
Node
*
var
)
{
void
LogScaleIsMissingForVar
(
Node
*
var
)
{
std
::
stringstream
msg_ss
;
VLOG
(
4
)
<<
"Quantization scale for the variable "
<<
var
->
Name
()
msg_ss
<<
"Quantization scale for the variable "
<<
var
->
Name
()
<<
" is missing."
;
<<
" is missing."
;
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
}
void
LogQuantizationDisabled
(
Node
*
op
)
{
void
LogQuantizationDisabled
(
Node
*
op
)
{
...
@@ -256,6 +254,14 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
...
@@ -256,6 +254,14 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
auto
has_output_scale
=
AreScalesPresentForNodes
(
conv_op
,
{
conv_output
});
if
(
with_residual_data
&&
!
has_output_scale
)
{
LogCannotQuantizeOp
(
conv_op
,
"Conv op with ResidualData input cannot be quantized "
"without output scale."
);
return
;
}
if
(
with_residual_data
)
{
if
(
with_residual_data
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
conv_residual_data
,
conv_residual_data
,
GET_IR_NODE_FROM_SUBGRAPH
(
conv_residual_data
,
conv_residual_data
,
conv_pattern
);
conv_pattern
);
...
@@ -294,7 +300,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
...
@@ -294,7 +300,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
conv_op
->
Op
()
->
SetAttr
(
"Scale_weights"
,
filter_scale
);
conv_op
->
Op
()
->
SetAttr
(
"Scale_weights"
,
filter_scale
);
// if quantization scale is missing for output tensor, return fp32 data
// if quantization scale is missing for output tensor, return fp32 data
if
(
AreScalesPresentForNodes
(
conv_op
,
{
conv_output
})
)
{
if
(
has_output_scale
)
{
bool
is_output_unsigned
{
false
};
bool
is_output_unsigned
{
false
};
auto
output_scale
=
auto
output_scale
=
GetScaleValueForNode
(
conv_output
,
&
is_output_unsigned
);
GetScaleValueForNode
(
conv_output
,
&
is_output_unsigned
);
...
...
python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
浏览文件 @
d0a921ba
...
@@ -55,7 +55,7 @@ class Quant2Int8MkldnnPass(object):
...
@@ -55,7 +55,7 @@ class Quant2Int8MkldnnPass(object):
'fake_dequantize_max_abs'
,
'fake_channel_wise_dequantize_max_abs'
'fake_dequantize_max_abs'
,
'fake_channel_wise_dequantize_max_abs'
]
]
self
.
_ops_to_quantize
=
_ops_to_quantize
self
.
_ops_to_quantize
=
_ops_to_quantize
self
.
_op_ids_to_skip
=
_op_ids_to_skip
if
_op_ids_to_skip
!=
None
else
set
(
self
.
_op_ids_to_skip
=
_op_ids_to_skip
if
_op_ids_to_skip
is
not
None
else
set
(
[
-
1
])
[
-
1
])
self
.
_scale_immutable_ops
=
[
self
.
_scale_immutable_ops
=
[
'transpose2'
,
'reshape2'
,
'pool2d'
,
'scale'
'transpose2'
,
'reshape2'
,
'pool2d'
,
'scale'
...
@@ -71,11 +71,14 @@ class Quant2Int8MkldnnPass(object):
...
@@ -71,11 +71,14 @@ class Quant2Int8MkldnnPass(object):
self
.
_var_quant_scales
=
{}
self
.
_var_quant_scales
=
{}
self
.
_max_range
=
{}
self
.
_max_range
=
{}
self
.
_s8_max
=
127
self
.
_s8_max
=
127
self
.
_pass_idx
=
0
self
.
_pass_group
=
'int8'
def
apply
(
self
,
graph
):
def
apply
(
self
,
graph
):
assert
isinstance
(
graph
,
assert
isinstance
(
graph
,
IrGraph
),
'graph must be the instance of IrGraph.'
IrGraph
),
'graph must be the instance of IrGraph.'
self
.
_reset_pass_idx_and_group
(
'int8'
)
graph
=
self
.
_gather_weight_scales_from_fake
(
graph
)
graph
=
self
.
_gather_weight_scales_from_fake
(
graph
)
graph
=
self
.
_gather_output_scales_from_attr
(
graph
)
graph
=
self
.
_gather_output_scales_from_attr
(
graph
)
graph
=
self
.
_gather_input_scales_from_fake
(
graph
)
graph
=
self
.
_gather_input_scales_from_fake
(
graph
)
...
@@ -86,21 +89,24 @@ class Quant2Int8MkldnnPass(object):
...
@@ -86,21 +89,24 @@ class Quant2Int8MkldnnPass(object):
graph
=
self
.
_update_relu_output_scales
(
graph
)
graph
=
self
.
_update_relu_output_scales
(
graph
)
graph
=
self
.
_propagate_scales
(
graph
)
graph
=
self
.
_propagate_scales
(
graph
)
graph
=
self
.
_quantize_fp32_graph
(
graph
)
graph
=
self
.
_quantize_fp32_graph
(
graph
)
graph
=
self
.
_
optimize_int8_graph
(
graph
)
graph
=
self
.
_
final_optimizations
(
graph
)
graph
=
self
.
_cleanup
(
graph
)
graph
=
self
.
_cleanup
(
graph
)
return
graph
return
graph
def
apply
_fp32
(
self
,
graph
):
def
prepare_and_optimize
_fp32
(
self
,
graph
):
assert
isinstance
(
graph
,
assert
isinstance
(
graph
,
IrGraph
),
'graph must be the instance of IrGraph.'
IrGraph
),
'graph must be the instance of IrGraph.'
graph
=
self
.
_gather_weight_scales_from_fake
(
graph
)
self
.
_reset_pass_idx_and_group
(
'fp32'
)
graph
=
self
.
_remove_fake_ops
(
graph
)
graph
=
self
.
_dequantize_weights
(
graph
)
graph
=
self
.
_optimize_fp32_graph
(
graph
)
graph
=
self
.
_optimize_fp32_graph
(
graph
)
graph
=
self
.
_final_optimizations
(
graph
)
graph
=
self
.
_cleanup
(
graph
)
graph
=
self
.
_cleanup
(
graph
)
return
graph
return
graph
def
_reset_pass_idx_and_group
(
self
,
group
):
self
.
_pass_idx
=
0
self
.
_pass_group
=
group
def
_convert_scale2tensor
(
self
,
scale
):
def
_convert_scale2tensor
(
self
,
scale
):
tensor
=
core
.
LoDTensor
()
tensor
=
core
.
LoDTensor
()
tensor
.
set
(
scale
,
core
.
CPUPlace
())
tensor
.
set
(
scale
,
core
.
CPUPlace
())
...
@@ -333,20 +339,38 @@ class Quant2Int8MkldnnPass(object):
...
@@ -333,20 +339,38 @@ class Quant2Int8MkldnnPass(object):
def
_optimize_fp32_graph
(
self
,
graph
):
def
_optimize_fp32_graph
(
self
,
graph
):
graph
=
self
.
_update_activations
(
graph
)
graph
=
self
.
_update_activations
(
graph
)
graph
=
self
.
_remove_ctrl_vars
(
graph
)
graph
=
self
.
_remove_ctrl_vars
(
graph
)
graph
=
self
.
_apply_pass
(
graph
,
'attention_lstm_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'seqconv_eltadd_relu_fuse_pass'
)
# graph = self._apply_pass(graph, 'seqpool_concat_fuse_pass')
graph
=
self
.
_apply_pass
(
graph
,
'seqpool_cvm_concat_fuse_pass'
)
# graph = self._apply_pass(graph, 'embedding_fc_lstm_fuse_pass')
graph
=
self
.
_apply_pass
(
graph
,
'fc_lstm_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'mul_lstm_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'fc_gru_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'mul_gru_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'seq_concat_fc_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'squared_mat_sub_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'is_test_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'mkldnn_placement_pass'
,
graph
=
self
.
_apply_pass
(
graph
,
'mkldnn_placement_pass'
,
[
'mkldnn_enabled_op_types'
],
[
set
()])
[
'mkldnn_enabled_op_types'
],
[
set
()])
graph
=
self
.
_apply_pass
(
graph
,
'depthwise_conv_mkldnn_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'depthwise_conv_mkldnn_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_bn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_bn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_eltwiseadd_bn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_eltwiseadd_bn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_transpose_bn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_transpose_eltwiseadd_bn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_bias_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_bias_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_elementwise_add_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_elementwise_add_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_relu_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_relu_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_relu6_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'conv_relu6_mkldnn_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'fc_fuse_pass'
,
graph
=
self
.
_apply_pass
(
graph
,
'fc_fuse_pass'
,
[
'use_gpu'
,
'use_fc_padding'
],
[
False
,
False
])
[
'use_gpu'
,
'use_fc_padding'
],
[
False
,
False
])
graph
=
self
.
_apply_pass
(
graph
,
'repeated_fc_relu_fuse_pass'
)
if
self
.
_is_fc_quantized
(
graph
):
if
self
.
_is_fc_quantized
(
graph
):
graph
=
self
.
_apply_pass
(
graph
,
'fc_mkldnn_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'fc_mkldnn_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'matmul_transpose_reshape_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'matmul_transpose_reshape_fuse_pass'
)
# the following pass should be the last one since it will work on all fused ops.
graph
=
self
.
_apply_pass
(
graph
,
'runtime_context_cache_pass'
)
return
graph
return
graph
def
_apply_pass
(
self
,
graph
,
pass_name
,
attrs
=
None
,
attr_values
=
None
):
def
_apply_pass
(
self
,
graph
,
pass_name
,
attrs
=
None
,
attr_values
=
None
):
...
@@ -362,12 +386,13 @@ class Quant2Int8MkldnnPass(object):
...
@@ -362,12 +386,13 @@ class Quant2Int8MkldnnPass(object):
ir_pass
.
set
(
attr
,
value
)
ir_pass
.
set
(
attr
,
value
)
ir_pass
.
apply
(
cpp_graph
)
ir_pass
.
apply
(
cpp_graph
)
if
self
.
_debug
:
if
self
.
_debug
:
graph
.
draw
(
'.'
,
'
quant_fp32_{}'
.
format
(
pass_name
)
,
graph
.
draw
(
'.'
,
'
{}_{}_{}'
.
format
(
self
.
_pass_group
,
self
.
_pass_idx
,
graph
.
all_op_nodes
())
pass_name
),
graph
.
all_op_nodes
())
self
.
_remove_unused_var_nodes
(
graph
)
self
.
_remove_unused_var_nodes
(
graph
)
self
.
_pass_idx
+=
1
return
graph
return
graph
def
_
optimize_int8_graph
(
self
,
graph
):
def
_
final_optimizations
(
self
,
graph
):
# remove dropout ops
# remove dropout ops
graph
=
self
.
_apply_pass
(
graph
,
'simplify_with_basic_ops_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'simplify_with_basic_ops_pass'
)
# make some MKL-DNN ops working inplace
# make some MKL-DNN ops working inplace
...
@@ -448,8 +473,7 @@ class Quant2Int8MkldnnPass(object):
...
@@ -448,8 +473,7 @@ class Quant2Int8MkldnnPass(object):
self
.
_var_quant_scales
[
out_name
]
=
(
True
,
tensor
)
self
.
_var_quant_scales
[
out_name
]
=
(
True
,
tensor
)
return
graph
return
graph
conv_predicate
=
lambda
op
:
op
.
attr
(
"fuse_activation"
)
in
self
.
_relu_ops
and
\
conv_predicate
=
lambda
op
:
op
.
attr
(
"fuse_activation"
)
in
self
.
_relu_ops
op
.
attr
(
"fuse_residual_connection"
)
==
False
graph
=
_set_unsigned_scale
(
graph
,
self
.
_conv_ops
,
"Output"
,
graph
=
_set_unsigned_scale
(
graph
,
self
.
_conv_ops
,
"Output"
,
conv_predicate
)
conv_predicate
)
...
@@ -465,15 +489,10 @@ class Quant2Int8MkldnnPass(object):
...
@@ -465,15 +489,10 @@ class Quant2Int8MkldnnPass(object):
return
'NHWC'
if
self
.
_is_conv_quantized
(
graph
)
else
'NCHW'
return
'NHWC'
if
self
.
_is_conv_quantized
(
graph
)
else
'NCHW'
def
_quantize_fp32_graph
(
self
,
graph
):
def
_quantize_fp32_graph
(
self
,
graph
):
ir_pass
=
self
.
_core
.
get_pass
(
'cpu_quantize_placement_pass'
)
graph
=
self
.
_apply_pass
(
cpp_graph
=
graph
.
graph
graph
,
'cpu_quantize_placement_pass'
,
ir_pass
.
set
(
'quantize_enabled_op_types'
,
self
.
_ops_to_quantize
)
[
'quantize_enabled_op_types'
,
'quantize_excluded_op_ids'
],
ir_pass
.
set
(
'quantize_excluded_op_ids'
,
[
self
.
_ops_to_quantize
,
self
.
_find_avg_pooling_ids
(
graph
)])
self
.
_find_avg_pooling_ids
(
graph
))
ir_pass
.
apply
(
cpp_graph
)
if
self
.
_debug
:
graph
.
draw
(
'.'
,
'quant_int8_{}'
.
format
(
ir_pass
.
type
()),
graph
.
all_op_nodes
())
graph
=
self
.
_apply_pass
(
graph
,
'scale_matmul_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'scale_matmul_fuse_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
graph
=
self
.
_apply_pass
(
graph
,
'reshape_transpose_matmul_mkldnn_fuse_pass'
)
'reshape_transpose_matmul_mkldnn_fuse_pass'
)
...
...
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
浏览文件 @
d0a921ba
...
@@ -57,7 +57,7 @@ endfunction()
...
@@ -57,7 +57,7 @@ endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
function
(
inference_quant2_int8_image_classification_test target quant_model_dir fp32_model_dir dataset_path
ops_to_quantize
)
function
(
inference_quant2_int8_image_classification_test target quant_model_dir fp32_model_dir dataset_path
)
py_test
(
${
target
}
SRCS
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/quant2_int8_image_classification_comparison.py"
py_test
(
${
target
}
SRCS
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/quant2_int8_image_classification_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
ENVS FLAGS_OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
...
@@ -67,12 +67,11 @@ function(inference_quant2_int8_image_classification_test target quant_model_dir
...
@@ -67,12 +67,11 @@ function(inference_quant2_int8_image_classification_test target quant_model_dir
--infer_data
${
dataset_path
}
--infer_data
${
dataset_path
}
--batch_size 10
--batch_size 10
--batch_num 2
--batch_num 2
--acc_diff_threshold 0.1
--acc_diff_threshold 0.1
)
--ops_to_quantize
${
ops_to_quantize
}
)
endfunction
()
endfunction
()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
function
(
inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir dataset_path labels_path
)
function
(
inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir dataset_path labels_path
ops_to_quantize
)
py_test
(
${
target
}
SRCS
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/quant2_int8_nlp_comparison.py"
py_test
(
${
target
}
SRCS
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/quant2_int8_nlp_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
ENVS FLAGS_OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
OMP_NUM_THREADS=
${
CPU_NUM_THREADS_ON_CI
}
...
@@ -83,7 +82,8 @@ function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir da
...
@@ -83,7 +82,8 @@ function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir da
--labels
${
labels_path
}
--labels
${
labels_path
}
--batch_size 10
--batch_size 10
--batch_num 2
--batch_num 2
--acc_diff_threshold 0.1
)
--acc_diff_threshold 0.1
--ops_to_quantize
${
ops_to_quantize
}
)
endfunction
()
endfunction
()
function
(
download_quant_data install_dir data_file
)
function
(
download_quant_data install_dir data_file
)
...
@@ -98,20 +98,20 @@ function(download_quant_model install_dir data_file)
...
@@ -98,20 +98,20 @@ function(download_quant_model install_dir data_file)
endif
()
endif
()
endfunction
()
endfunction
()
function
(
save_quant_ic_model_test target quant_model_dir fp32_model_save_path int8_model_save_path
ops_to_quantize
)
function
(
save_quant_ic_model_test target quant_model_dir fp32_model_save_path int8_model_save_path
)
py_test
(
${
target
}
SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/save_quant_model.py
py_test
(
${
target
}
SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/save_quant_model.py
ARGS --quant_model_path
${
quant_model_dir
}
ARGS --quant_model_path
${
quant_model_dir
}
--fp32_model_save_path
${
fp32_model_save_path
}
--fp32_model_save_path
${
fp32_model_save_path
}
--int8_model_save_path
${
int8_model_save_path
}
--int8_model_save_path
${
int8_model_save_path
}
--ops_to_quantize
${
ops_to_quantize
}
--debug
)
--debug
)
endfunction
()
endfunction
()
function
(
save_quant_nlp_model_test target quant_model_dir fp32_model_save_path int8_model_save_path
)
function
(
save_quant_nlp_model_test target quant_model_dir fp32_model_save_path int8_model_save_path
ops_to_quantize
)
py_test
(
${
target
}
SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/save_quant_model.py
py_test
(
${
target
}
SRCS
${
CMAKE_CURRENT_SOURCE_DIR
}
/save_quant_model.py
ARGS --quant_model_path
${
quant_model_dir
}
ARGS --quant_model_path
${
quant_model_dir
}
--fp32_model_save_path
${
fp32_model_save_path
}
--fp32_model_save_path
${
fp32_model_save_path
}
--int8_model_save_path
${
int8_model_save_path
}
)
--int8_model_save_path
${
int8_model_save_path
}
--ops_to_quantize
${
ops_to_quantize
}
)
endfunction
()
endfunction
()
function
(
convert_model2dot_test target model_path save_graph_dir save_graph_name
)
function
(
convert_model2dot_test target model_path save_graph_dir save_graph_name
)
...
@@ -224,36 +224,34 @@ if(LINUX AND WITH_MKLDNN)
...
@@ -224,36 +224,34 @@ if(LINUX AND WITH_MKLDNN)
### Quant2 for image classification
### Quant2 for image classification
set
(
QUANT2_IC_OPS_TO_QUANTIZE
"conv2d,pool2d"
)
# Quant2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
# Quant2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
# with weight scales in `fake_dequantize_max_abs` operators
# with weight scales in `fake_dequantize_max_abs` operators
set
(
QUANT2_RESNET50_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2"
)
set
(
QUANT2_RESNET50_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2"
)
set
(
QUANT2_RESNET50_MODEL_ARCHIVE
"ResNet50_qat_perf.tar.gz"
)
set
(
QUANT2_RESNET50_MODEL_ARCHIVE
"ResNet50_qat_perf.tar.gz"
)
download_quant_model
(
${
QUANT2_RESNET50_MODEL_DIR
}
${
QUANT2_RESNET50_MODEL_ARCHIVE
}
)
download_quant_model
(
${
QUANT2_RESNET50_MODEL_DIR
}
${
QUANT2_RESNET50_MODEL_ARCHIVE
}
)
set
(
FP32_RESNET50_MODEL_DIR
"
${
INT8_INSTALL_DIR
}
/resnet50"
)
set
(
FP32_RESNET50_MODEL_DIR
"
${
INT8_INSTALL_DIR
}
/resnet50"
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_resnet50_mkldnn
${
QUANT2_RESNET50_MODEL_DIR
}
/ResNet50_qat_perf/float
${
FP32_RESNET50_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
${
QUANT2_IC_OPS_TO_QUANTIZE
}
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_resnet50_mkldnn
${
QUANT2_RESNET50_MODEL_DIR
}
/ResNet50_qat_perf/float
${
FP32_RESNET50_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
)
# Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_dequantize_max_abs` operators
# with weight scales in `fake_dequantize_max_abs` operators
set
(
QUANT2_RESNET50_RANGE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_range"
)
set
(
QUANT2_RESNET50_RANGE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_range"
)
set
(
QUANT2_RESNET50_RANGE_MODEL_ARCHIVE
"ResNet50_qat_range.tar.gz"
)
set
(
QUANT2_RESNET50_RANGE_MODEL_ARCHIVE
"ResNet50_qat_range.tar.gz"
)
download_quant_model
(
${
QUANT2_RESNET50_RANGE_MODEL_DIR
}
${
QUANT2_RESNET50_RANGE_MODEL_ARCHIVE
}
)
download_quant_model
(
${
QUANT2_RESNET50_RANGE_MODEL_DIR
}
${
QUANT2_RESNET50_RANGE_MODEL_ARCHIVE
}
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_resnet50_range_mkldnn
${
QUANT2_RESNET50_RANGE_MODEL_DIR
}
/ResNet50_qat_range
${
FP32_RESNET50_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
${
QUANT2_IC_OPS_TO_QUANTIZE
}
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_resnet50_range_mkldnn
${
QUANT2_RESNET50_RANGE_MODEL_DIR
}
/ResNet50_qat_range
${
FP32_RESNET50_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
)
# Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_channel_wise_dequantize_max_abs` operators
# with weight scales in `fake_channel_wise_dequantize_max_abs` operators
set
(
QUANT2_RESNET50_CHANNELWISE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_channelwise"
)
set
(
QUANT2_RESNET50_CHANNELWISE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_channelwise"
)
set
(
QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE
"ResNet50_qat_channelwise.tar.gz"
)
set
(
QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE
"ResNet50_qat_channelwise.tar.gz"
)
download_quant_model
(
${
QUANT2_RESNET50_CHANNELWISE_MODEL_DIR
}
${
QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE
}
)
download_quant_model
(
${
QUANT2_RESNET50_CHANNELWISE_MODEL_DIR
}
${
QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE
}
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_resnet50_channelwise_mkldnn
${
QUANT2_RESNET50_CHANNELWISE_MODEL_DIR
}
/ResNet50_qat_channelwise
${
FP32_RESNET50_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
${
QUANT2_IC_OPS_TO_QUANTIZE
}
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_resnet50_channelwise_mkldnn
${
QUANT2_RESNET50_CHANNELWISE_MODEL_DIR
}
/ResNet50_qat_channelwise
${
FP32_RESNET50_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
)
# Quant2 MobileNetV1
# Quant2 MobileNetV1
set
(
QUANT2_MOBILENETV1_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/MobileNetV1_quant2"
)
set
(
QUANT2_MOBILENETV1_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/MobileNetV1_quant2"
)
set
(
QUANT2_MOBILENETV1_MODEL_ARCHIVE
"MobileNet_qat_perf.tar.gz"
)
set
(
QUANT2_MOBILENETV1_MODEL_ARCHIVE
"MobileNet_qat_perf.tar.gz"
)
download_quant_model
(
${
QUANT2_MOBILENETV1_MODEL_DIR
}
${
QUANT2_MOBILENETV1_MODEL_ARCHIVE
}
)
download_quant_model
(
${
QUANT2_MOBILENETV1_MODEL_DIR
}
${
QUANT2_MOBILENETV1_MODEL_ARCHIVE
}
)
set
(
FP32_MOBILENETV1_MODEL_DIR
"
${
INT8_INSTALL_DIR
}
/mobilenetv1"
)
set
(
FP32_MOBILENETV1_MODEL_DIR
"
${
INT8_INSTALL_DIR
}
/mobilenetv1"
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_mobilenetv1_mkldnn
${
QUANT2_MOBILENETV1_MODEL_DIR
}
/MobileNet_qat_perf/float
${
FP32_MOBILENETV1_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
${
QUANT2_IC_OPS_TO_QUANTIZE
}
)
inference_quant2_int8_image_classification_test
(
test_quant2_int8_mobilenetv1_mkldnn
${
QUANT2_MOBILENETV1_MODEL_DIR
}
/MobileNet_qat_perf/float
${
FP32_MOBILENETV1_MODEL_DIR
}
/model
${
IMAGENET_DATA_PATH
}
)
### Quant2 for NLP
### Quant2 for NLP
...
@@ -263,6 +261,8 @@ if(LINUX AND WITH_MKLDNN)
...
@@ -263,6 +261,8 @@ if(LINUX AND WITH_MKLDNN)
set
(
NLP_LABLES_PATH
"
${
NLP_DATA_DIR
}
/Ernie_dataset/label.xnli.dev"
)
set
(
NLP_LABLES_PATH
"
${
NLP_DATA_DIR
}
/Ernie_dataset/label.xnli.dev"
)
download_quant_data
(
${
NLP_DATA_DIR
}
${
NLP_DATA_ARCHIVE
}
)
download_quant_data
(
${
NLP_DATA_DIR
}
${
NLP_DATA_ARCHIVE
}
)
set
(
QUANT2_NLP_OPS_TO_QUANTIZE
"fc,reshape2,transpose2,matmul,elementwise_add"
)
# Quant2 Ernie
# Quant2 Ernie
set
(
QUANT2_ERNIE_MODEL_ARCHIVE
"ernie_qat.tar.gz"
)
set
(
QUANT2_ERNIE_MODEL_ARCHIVE
"ernie_qat.tar.gz"
)
set
(
QUANT2_ERNIE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2"
)
set
(
QUANT2_ERNIE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2"
)
...
@@ -270,17 +270,17 @@ if(LINUX AND WITH_MKLDNN)
...
@@ -270,17 +270,17 @@ if(LINUX AND WITH_MKLDNN)
set
(
FP32_ERNIE_MODEL_ARCHIVE
"ernie_fp32_model.tar.gz"
)
set
(
FP32_ERNIE_MODEL_ARCHIVE
"ernie_fp32_model.tar.gz"
)
set
(
FP32_ERNIE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/Ernie_float"
)
set
(
FP32_ERNIE_MODEL_DIR
"
${
QUANT_INSTALL_DIR
}
/Ernie_float"
)
download_quant_fp32_model
(
${
FP32_ERNIE_MODEL_DIR
}
${
FP32_ERNIE_MODEL_ARCHIVE
}
)
download_quant_fp32_model
(
${
FP32_ERNIE_MODEL_DIR
}
${
FP32_ERNIE_MODEL_ARCHIVE
}
)
inference_quant2_int8_nlp_test
(
test_quant2_int8_ernie_mkldnn
${
QUANT2_ERNIE_MODEL_DIR
}
/Ernie_qat/float
${
FP32_ERNIE_MODEL_DIR
}
/ernie_fp32_model
${
NLP_DATA_PATH
}
${
NLP_LABLES_PATH
}
)
inference_quant2_int8_nlp_test
(
test_quant2_int8_ernie_mkldnn
${
QUANT2_ERNIE_MODEL_DIR
}
/Ernie_qat/float
${
FP32_ERNIE_MODEL_DIR
}
/ernie_fp32_model
${
NLP_DATA_PATH
}
${
NLP_LABLES_PATH
}
${
QUANT2_NLP_OPS_TO_QUANTIZE
}
)
### Save FP32 model or INT8 model from Quant model
### Save FP32 model or INT8 model from Quant model
set
(
QUANT2_INT8_RESNET50_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_int8"
)
set
(
QUANT2_INT8_RESNET50_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_int8"
)
set
(
QUANT2_FP32_RESNET50_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_fp32"
)
set
(
QUANT2_FP32_RESNET50_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/ResNet50_quant2_fp32"
)
save_quant_ic_model_test
(
save_quant2_model_resnet50
${
QUANT2_RESNET50_MODEL_DIR
}
/ResNet50_qat_perf/float
${
QUANT2_FP32_RESNET50_SAVE_PATH
}
${
QUANT2_INT8_RESNET50_SAVE_PATH
}
${
QUANT2_IC_OPS_TO_QUANTIZE
}
)
save_quant_ic_model_test
(
save_quant2_model_resnet50
${
QUANT2_RESNET50_MODEL_DIR
}
/ResNet50_qat_perf/float
${
QUANT2_FP32_RESNET50_SAVE_PATH
}
${
QUANT2_INT8_RESNET50_SAVE_PATH
}
)
set
(
QUANT2_INT8_ERNIE_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2_int8"
)
set
(
QUANT2_INT8_ERNIE_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2_int8"
)
set
(
QUANT2_FP32_ERNIE_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2_fp32"
)
set
(
QUANT2_FP32_ERNIE_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2_fp32"
)
save_quant_nlp_model_test
(
save_quant2_model_ernie
${
QUANT2_ERNIE_MODEL_DIR
}
/Ernie_qat/float
${
QUANT2_FP32_ERNIE_SAVE_PATH
}
${
QUANT2_INT8_ERNIE_SAVE_PATH
}
)
save_quant_nlp_model_test
(
save_quant2_model_ernie
${
QUANT2_ERNIE_MODEL_DIR
}
/Ernie_qat/float
${
QUANT2_FP32_ERNIE_SAVE_PATH
}
${
QUANT2_INT8_ERNIE_SAVE_PATH
}
${
QUANT2_NLP_OPS_TO_QUANTIZE
}
)
# Convert Quant2 model to dot and pdf files
# Convert Quant2 model to dot and pdf files
set
(
QUANT2_INT8_ERNIE_DOT_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2_int8_dot_file"
)
set
(
QUANT2_INT8_ERNIE_DOT_SAVE_PATH
"
${
QUANT_INSTALL_DIR
}
/Ernie_quant2_int8_dot_file"
)
...
...
python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py
浏览文件 @
d0a921ba
...
@@ -167,7 +167,8 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -167,7 +167,8 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
batch_size
=
1
,
batch_size
=
1
,
batch_num
=
1
,
batch_num
=
1
,
skip_batch_num
=
0
,
skip_batch_num
=
0
,
transform_to_int8
=
False
):
target
=
'quant'
):
assert
target
in
[
'quant'
,
'int8'
,
'fp32'
]
place
=
fluid
.
CPUPlace
()
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
inference_scope
=
fluid
.
executor
.
global_scope
()
inference_scope
=
fluid
.
executor
.
global_scope
()
...
@@ -183,17 +184,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -183,17 +184,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
graph
=
IrGraph
(
core
.
Graph
(
inference_program
.
desc
),
for_test
=
True
)
graph
=
IrGraph
(
core
.
Graph
(
inference_program
.
desc
),
for_test
=
True
)
if
(
self
.
_debug
):
if
(
self
.
_debug
):
graph
.
draw
(
'.'
,
'quant_orig'
,
graph
.
all_op_nodes
())
graph
.
draw
(
'.'
,
'quant_orig'
,
graph
.
all_op_nodes
())
if
(
transform_to_int8
):
quant_transform_pass
=
Quant2Int8MkldnnPass
(
transform_to_mkldnn_int8_pass
=
Quant2Int8MkldnnPass
(
self
.
_quantized_ops
,
self
.
_quantized_ops
,
_op_ids_to_skip
=
self
.
_op_ids_to_skip
,
_op_ids_to_skip
=
self
.
_op_ids_to_skip
,
_scope
=
inference_scope
,
_scope
=
inference_scope
,
_place
=
place
,
_place
=
place
,
_core
=
core
,
_core
=
core
,
_debug
=
self
.
_debug
)
_debug
=
self
.
_debug
)
graph
=
transform_to_mkldnn_int8_pass
.
apply
(
graph
)
if
(
target
==
'quant'
):
else
:
graph
=
self
.
_prepare_for_fp32_mkldnn
(
graph
)
graph
=
self
.
_prepare_for_fp32_mkldnn
(
graph
)
elif
(
target
==
'int8'
):
graph
=
quant_transform_pass
.
apply
(
graph
)
else
:
# target == fp32
graph
=
quant_transform_pass
.
prepare_and_optimize_fp32
(
graph
)
inference_program
=
graph
.
to_program
()
inference_program
=
graph
.
to_program
()
...
@@ -222,18 +225,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -222,18 +225,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
images
=
np
.
array
(
images
).
astype
(
'float32'
)
images
=
np
.
array
(
images
).
astype
(
'float32'
)
labels
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
)
labels
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
)
if
(
transform_to_int8
==
True
):
if
(
target
==
'fp32'
):
# INT8 models obtained from Quant models do not have accuracy measuring layers
start
=
time
.
time
()
out
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
images
},
fetch_list
=
fetch_targets
)
batch_time
=
(
time
.
time
()
-
start
)
*
1000
# in miliseconds
outputs
.
append
(
out
[
0
])
# Calculate accuracy result
batch_acc1
,
batch_acc5
=
self
.
_get_batch_accuracy
(
out
[
0
],
labels
)
else
:
# FP32 models have accuracy measuring layers
# FP32 models have accuracy measuring layers
labels
=
labels
.
reshape
([
-
1
,
1
])
labels
=
labels
.
reshape
([
-
1
,
1
])
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -246,6 +238,18 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -246,6 +238,18 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
batch_time
=
(
time
.
time
()
-
start
)
*
1000
# in miliseconds
batch_time
=
(
time
.
time
()
-
start
)
*
1000
# in miliseconds
batch_acc1
,
batch_acc5
=
out
[
1
][
0
],
out
[
2
][
0
]
batch_acc1
,
batch_acc5
=
out
[
1
][
0
],
out
[
2
][
0
]
outputs
.
append
(
batch_acc1
)
outputs
.
append
(
batch_acc1
)
else
:
# Quant INT8 models do not have accuracy measuring layers
start
=
time
.
time
()
out
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
images
},
fetch_list
=
fetch_targets
)
batch_time
=
(
time
.
time
()
-
start
)
*
1000
# in miliseconds
outputs
.
append
(
out
[
0
])
# Calculate accuracy result
batch_acc1
,
batch_acc5
=
self
.
_get_batch_accuracy
(
out
[
0
],
labels
)
infer_accs1
.
append
(
batch_acc1
)
infer_accs1
.
append
(
batch_acc1
)
infer_accs5
.
append
(
batch_acc5
)
infer_accs5
.
append
(
batch_acc5
)
samples
=
len
(
data
)
samples
=
len
(
data
)
...
@@ -274,28 +278,37 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -274,28 +278,37 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
return
outputs
,
acc1_avg
,
acc5_avg
,
fps_avg
,
latency_avg
return
outputs
,
acc1_avg
,
acc5_avg
,
fps_avg
,
latency_avg
def
_summarize_performance
(
self
,
fp32_fps
,
fp32_lat
,
int8_fps
,
int8_lat
):
def
_print_performance
(
self
,
title
,
fps
,
lat
):
_logger
.
info
(
'{0}: avg fps: {1:.2f}, avg latency: {2:.4f} ms'
.
format
(
title
,
fps
,
lat
))
def
_print_accuracy
(
self
,
title
,
acc1
,
acc5
):
_logger
.
info
(
'{0}: avg top1 accuracy: {1:.4f}, avg top5 accuracy: {2:.4f}'
.
format
(
title
,
acc1
,
acc5
))
def
_summarize_performance
(
self
,
int8_fps
,
int8_lat
,
fp32_fps
,
fp32_lat
):
_logger
.
info
(
'--- Performance summary ---'
)
_logger
.
info
(
'--- Performance summary ---'
)
_logger
.
info
(
'FP32: avg fps: {0:.2f}, avg latency: {1:.4f} ms'
.
format
(
self
.
_print_performance
(
'INT8'
,
int8_fps
,
int8_lat
)
fp32_fps
,
fp32_lat
))
if
fp32_lat
>=
0
:
_logger
.
info
(
'INT8: avg fps: {0:.2f}, avg latency: {1:.4f} ms'
.
format
(
self
.
_print_performance
(
'FP32'
,
fp32_fps
,
fp32_lat
)
int8_fps
,
int8_lat
))
def
_
compare_accuracy
(
self
,
fp32_acc1
,
fp32
_acc5
,
int8_acc1
,
int8_acc5
,
def
_
summarize_accuracy
(
self
,
quant_acc1
,
quant
_acc5
,
int8_acc1
,
int8_acc5
,
threshold
):
fp32_acc1
,
fp32_acc5
):
_logger
.
info
(
'--- Accuracy summary ---'
)
_logger
.
info
(
'--- Accuracy summary ---'
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc1
,
quant_acc5
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc1
,
int8_acc5
)
if
fp32_acc1
>=
0
:
self
.
_print_accuracy
(
'FP32'
,
fp32_acc1
,
fp32_acc5
)
def
_compare_accuracy
(
self
,
threshold
,
quant_acc1
,
int8_acc1
):
_logger
.
info
(
_logger
.
info
(
'Accepted top1 accuracy drop threshold: {0}. (condition: (
FP32_top1_acc - IN8_top1_acc) <= threshold
)'
'Accepted top1 accuracy drop threshold: {0}. (condition: (
Quant_top1_acc - IN8_top1_acc) <= threshold && Quant_top1_acc > 0.5 && INT8_top1_acc > 0.5
)'
.
format
(
threshold
))
.
format
(
threshold
))
_logger
.
info
(
# We assume valid accuracy to be at least 0.5
'FP32: avg top1 accuracy: {0:.4f}, avg top5 accuracy: {1:.4f}'
.
assert
quant_acc1
>
0.5
format
(
fp32_acc1
,
fp32_acc5
))
assert
int8_acc1
>
0.5
_logger
.
info
(
assert
quant_acc1
-
int8_acc1
<=
threshold
'INT8: avg top1 accuracy: {0:.4f}, avg top5 accuracy: {1:.4f}'
.
format
(
int8_acc1
,
int8_acc5
))
assert
fp32_acc1
>
0.0
assert
int8_acc1
>
0.0
assert
fp32_acc1
-
int8_acc1
<=
threshold
def
test_graph_transformation
(
self
):
def
test_graph_transformation
(
self
):
if
not
fluid
.
core
.
is_compiled_with_mkldnn
():
if
not
fluid
.
core
.
is_compiled_with_mkldnn
():
...
@@ -303,10 +316,9 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -303,10 +316,9 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
quant_model_path
=
test_case_args
.
quant_model
quant_model_path
=
test_case_args
.
quant_model
assert
quant_model_path
,
'The Quant model path cannot be empty. Please, use the --quant_model option.'
assert
quant_model_path
,
'The Quant model path cannot be empty. Please, use the --quant_model option.'
fp32_model_path
=
test_case_args
.
fp32_model
assert
fp32_model_path
,
'The FP32 model path cannot be empty. Please, use the --fp32_model option.'
data_path
=
test_case_args
.
infer_data
data_path
=
test_case_args
.
infer_data
assert
data_path
,
'The dataset path cannot be empty. Please, use the --infer_data option.'
assert
data_path
,
'The dataset path cannot be empty. Please, use the --infer_data option.'
fp32_model_path
=
test_case_args
.
fp32_model
batch_size
=
test_case_args
.
batch_size
batch_size
=
test_case_args
.
batch_size
batch_num
=
test_case_args
.
batch_num
batch_num
=
test_case_args
.
batch_num
skip_batch_num
=
test_case_args
.
skip_batch_num
skip_batch_num
=
test_case_args
.
skip_batch_num
...
@@ -323,8 +335,9 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -323,8 +335,9 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
self
.
_op_ids_to_skip
=
set
(
self
.
_op_ids_to_skip
=
set
(
map
(
int
,
test_case_args
.
op_ids_to_skip
.
split
(
','
)))
map
(
int
,
test_case_args
.
op_ids_to_skip
.
split
(
','
)))
_logger
.
info
(
'
FP32 & Quant
INT8 prediction run.'
)
_logger
.
info
(
'
Quant &
INT8 prediction run.'
)
_logger
.
info
(
'Quant model: {}'
.
format
(
quant_model_path
))
_logger
.
info
(
'Quant model: {}'
.
format
(
quant_model_path
))
if
fp32_model_path
:
_logger
.
info
(
'FP32 model: {}'
.
format
(
fp32_model_path
))
_logger
.
info
(
'FP32 model: {}'
.
format
(
fp32_model_path
))
_logger
.
info
(
'Dataset: {}'
.
format
(
data_path
))
_logger
.
info
(
'Dataset: {}'
.
format
(
data_path
))
_logger
.
info
(
'Batch size: {}'
.
format
(
batch_size
))
_logger
.
info
(
'Batch size: {}'
.
format
(
batch_size
))
...
@@ -336,17 +349,20 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -336,17 +349,20 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
map
(
str
,
self
.
_op_ids_to_skip
))
if
test_case_args
.
op_ids_to_skip
map
(
str
,
self
.
_op_ids_to_skip
))
if
test_case_args
.
op_ids_to_skip
else
'none'
))
else
'none'
))
_logger
.
info
(
'---
FP32
prediction start ---'
)
_logger
.
info
(
'---
Quant
prediction start ---'
)
val_reader
=
paddle
.
batch
(
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
fp32_output
,
fp32_acc1
,
fp32_acc5
,
fp32_fps
,
fp32
_lat
=
self
.
_predict
(
quant_output
,
quant_acc1
,
quant_acc5
,
quant_fps
,
quant
_lat
=
self
.
_predict
(
val_reader
,
val_reader
,
fp32
_model_path
,
quant
_model_path
,
batch_size
,
batch_size
,
batch_num
,
batch_num
,
skip_batch_num
,
skip_batch_num
,
transform_to_int8
=
False
)
target
=
'quant'
)
_logger
.
info
(
'--- Quant INT8 prediction start ---'
)
self
.
_print_performance
(
'Quant'
,
quant_fps
,
quant_lat
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc1
,
quant_acc5
)
_logger
.
info
(
'--- INT8 prediction start ---'
)
val_reader
=
paddle
.
batch
(
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
int8_output
,
int8_acc1
,
int8_acc5
,
int8_fps
,
int8_lat
=
self
.
_predict
(
int8_output
,
int8_acc1
,
int8_acc5
,
int8_fps
,
int8_lat
=
self
.
_predict
(
...
@@ -355,11 +371,29 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
...
@@ -355,11 +371,29 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
batch_size
,
batch_size
,
batch_num
,
batch_num
,
skip_batch_num
,
skip_batch_num
,
transform_to_int8
=
True
)
target
=
'int8'
)
self
.
_print_performance
(
'INT8'
,
int8_fps
,
int8_lat
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc1
,
int8_acc5
)
self
.
_summarize_performance
(
fp32_fps
,
fp32_lat
,
int8_fps
,
int8_lat
)
fp32_acc1
=
fp32_acc5
=
fp32_fps
=
fp32_lat
=
-
1
self
.
_compare_accuracy
(
fp32_acc1
,
fp32_acc5
,
int8_acc1
,
int8_acc5
,
if
fp32_model_path
:
acc_diff_threshold
)
_logger
.
info
(
'--- FP32 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
),
batch_size
=
batch_size
)
fp32_output
,
fp32_acc1
,
fp32_acc5
,
fp32_fps
,
fp32_lat
=
self
.
_predict
(
val_reader
,
fp32_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'fp32'
)
self
.
_print_performance
(
'FP32'
,
fp32_fps
,
fp32_lat
)
self
.
_print_accuracy
(
'FP32'
,
fp32_acc1
,
fp32_acc5
)
self
.
_summarize_performance
(
int8_fps
,
int8_lat
,
fp32_fps
,
fp32_lat
)
self
.
_summarize_accuracy
(
quant_acc1
,
quant_acc5
,
int8_acc1
,
int8_acc5
,
fp32_acc1
,
fp32_acc5
)
self
.
_compare_accuracy
(
acc_diff_threshold
,
quant_acc1
,
int8_acc1
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py
浏览文件 @
d0a921ba
...
@@ -17,8 +17,6 @@ import os
...
@@ -17,8 +17,6 @@ import os
import
sys
import
sys
import
argparse
import
argparse
import
logging
import
logging
import
struct
import
six
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
paddle
import
paddle
...
@@ -143,7 +141,8 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
...
@@ -143,7 +141,8 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
batch_size
=
1
,
batch_size
=
1
,
batch_num
=
1
,
batch_num
=
1
,
skip_batch_num
=
0
,
skip_batch_num
=
0
,
transform_to_int8
=
False
):
target
=
'quant'
):
assert
target
in
[
'quant'
,
'int8'
,
'fp32'
]
place
=
fluid
.
CPUPlace
()
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
inference_scope
=
fluid
.
executor
.
global_scope
()
inference_scope
=
fluid
.
executor
.
global_scope
()
...
@@ -159,15 +158,19 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
...
@@ -159,15 +158,19 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
graph
=
IrGraph
(
core
.
Graph
(
inference_program
.
desc
),
for_test
=
True
)
graph
=
IrGraph
(
core
.
Graph
(
inference_program
.
desc
),
for_test
=
True
)
if
(
self
.
_debug
):
if
(
self
.
_debug
):
graph
.
draw
(
'.'
,
'quant_orig'
,
graph
.
all_op_nodes
())
graph
.
draw
(
'.'
,
'quant_orig'
,
graph
.
all_op_nodes
())
if
(
t
ransform_to_int8
):
if
(
t
arget
!=
'quant'
):
transform_to_mkldnn_int8
_pass
=
Quant2Int8MkldnnPass
(
quant_transform
_pass
=
Quant2Int8MkldnnPass
(
self
.
_quantized_ops
,
self
.
_quantized_ops
,
_op_ids_to_skip
=
self
.
_op_ids_to_skip
,
_op_ids_to_skip
=
self
.
_op_ids_to_skip
,
_scope
=
inference_scope
,
_scope
=
inference_scope
,
_place
=
place
,
_place
=
place
,
_core
=
core
,
_core
=
core
,
_debug
=
self
.
_debug
)
_debug
=
self
.
_debug
)
graph
=
transform_to_mkldnn_int8_pass
.
apply
(
graph
)
if
(
target
==
'int8'
):
graph
=
quant_transform_pass
.
apply
(
graph
)
else
:
# target == fp32
graph
=
quant_transform_pass
.
prepare_and_optimize_fp32
(
graph
)
inference_program
=
graph
.
to_program
()
inference_program
=
graph
.
to_program
()
...
@@ -223,26 +226,35 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
...
@@ -223,26 +226,35 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
return
acc_avg
,
pps_avg
,
latency_avg
return
acc_avg
,
pps_avg
,
latency_avg
def
_summarize_performance
(
self
,
fp32_pps
,
fp32_lat
,
int8_pps
,
int8_lat
):
def
_print_performance
(
self
,
title
,
pps
,
lat
):
_logger
.
info
(
'--- Performance summary ---'
)
_logger
.
info
(
'FP32: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'
.
format
(
fp32_pps
,
fp32_lat
))
_logger
.
info
(
_logger
.
info
(
'INT8: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'
.
'{0}: avg predictions per sec: {1:.2f}, avg latency: {2:.4f} ms'
.
format
(
int8_pps
,
int8_lat
))
format
(
title
,
pps
,
lat
))
def
_print_accuracy
(
self
,
title
,
acc
):
_logger
.
info
(
'{0}: avg accuracy: {1:.6f}'
.
format
(
title
,
acc
))
def
_summarize_performance
(
self
,
int8_pps
,
int8_lat
,
fp32_pps
,
fp32_lat
):
_logger
.
info
(
'--- Performance summary ---'
)
self
.
_print_performance
(
'INT8'
,
int8_pps
,
int8_lat
)
if
fp32_lat
>=
0
:
self
.
_print_performance
(
'FP32'
,
fp32_pps
,
fp32_lat
)
def
_
compare_accuracy
(
self
,
fp32_acc
,
int8_acc
,
threshold
):
def
_
summarize_accuracy
(
self
,
quant_acc
,
int8_acc
,
fp32_acc
):
_logger
.
info
(
'--- Accuracy summary ---'
)
_logger
.
info
(
'--- Accuracy summary ---'
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc
)
if
fp32_acc
>=
0
:
self
.
_print_accuracy
(
'FP32'
,
fp32_acc
)
def
_compare_accuracy
(
self
,
threshold
,
quant_acc
,
int8_acc
):
_logger
.
info
(
_logger
.
info
(
'Accepted accuracy drop threshold: {0}. (condition: (
FP32
_acc - INT8_acc) <= threshold)'
'Accepted accuracy drop threshold: {0}. (condition: (
Quant
_acc - INT8_acc) <= threshold)'
.
format
(
threshold
))
.
format
(
threshold
))
_logger
.
info
(
'FP32: avg accuracy: {0:.6f}'
.
format
(
fp32_acc
))
_logger
.
info
(
'INT8: avg accuracy: {0:.6f}'
.
format
(
int8_acc
))
# Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5
# Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5
assert
fp32
_acc
>
0.5
assert
quant
_acc
>
0.5
assert
int8_acc
>
0.5
assert
int8_acc
>
0.5
assert
fp32
_acc
-
int8_acc
<=
threshold
assert
quant
_acc
-
int8_acc
<=
threshold
def
test_graph_transformation
(
self
):
def
test_graph_transformation
(
self
):
if
not
fluid
.
core
.
is_compiled_with_mkldnn
():
if
not
fluid
.
core
.
is_compiled_with_mkldnn
():
...
@@ -250,9 +262,9 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
...
@@ -250,9 +262,9 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
quant_model_path
=
test_case_args
.
quant_model
quant_model_path
=
test_case_args
.
quant_model
assert
quant_model_path
,
'The Quant model path cannot be empty. Please, use the --quant_model option.'
assert
quant_model_path
,
'The Quant model path cannot be empty. Please, use the --quant_model option.'
fp32_model_path
=
test_case_args
.
fp32_model
if
test_case_args
.
fp32_model
else
quant_model_path
data_path
=
test_case_args
.
infer_data
data_path
=
test_case_args
.
infer_data
assert
data_path
,
'The dataset path cannot be empty. Please, use the --infer_data option.'
assert
data_path
,
'The dataset path cannot be empty. Please, use the --infer_data option.'
fp32_model_path
=
test_case_args
.
fp32_model
labels_path
=
test_case_args
.
labels
labels_path
=
test_case_args
.
labels
batch_size
=
test_case_args
.
batch_size
batch_size
=
test_case_args
.
batch_size
batch_num
=
test_case_args
.
batch_num
batch_num
=
test_case_args
.
batch_num
...
@@ -270,8 +282,9 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
...
@@ -270,8 +282,9 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
self
.
_op_ids_to_skip
=
set
(
self
.
_op_ids_to_skip
=
set
(
map
(
int
,
test_case_args
.
op_ids_to_skip
.
split
(
','
)))
map
(
int
,
test_case_args
.
op_ids_to_skip
.
split
(
','
)))
_logger
.
info
(
'
FP32 & Quant
INT8 prediction run.'
)
_logger
.
info
(
'
Quant &
INT8 prediction run.'
)
_logger
.
info
(
'Quant model: {}'
.
format
(
quant_model_path
))
_logger
.
info
(
'Quant model: {}'
.
format
(
quant_model_path
))
if
fp32_model_path
:
_logger
.
info
(
'FP32 model: {}'
.
format
(
fp32_model_path
))
_logger
.
info
(
'FP32 model: {}'
.
format
(
fp32_model_path
))
_logger
.
info
(
'Dataset: {}'
.
format
(
data_path
))
_logger
.
info
(
'Dataset: {}'
.
format
(
data_path
))
_logger
.
info
(
'Labels: {}'
.
format
(
labels_path
))
_logger
.
info
(
'Labels: {}'
.
format
(
labels_path
))
...
@@ -284,18 +297,20 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
...
@@ -284,18 +297,20 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
map
(
str
,
self
.
_op_ids_to_skip
))
if
test_case_args
.
op_ids_to_skip
map
(
str
,
self
.
_op_ids_to_skip
))
if
test_case_args
.
op_ids_to_skip
else
'none'
))
else
'none'
))
_logger
.
info
(
'---
FP32
prediction start ---'
)
_logger
.
info
(
'---
Quant
prediction start ---'
)
val_reader
=
paddle
.
batch
(
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
fp32_acc
,
fp32_pps
,
fp32
_lat
=
self
.
_predict
(
quant_acc
,
quant_pps
,
quant
_lat
=
self
.
_predict
(
val_reader
,
val_reader
,
fp32
_model_path
,
quant
_model_path
,
batch_size
,
batch_size
,
batch_num
,
batch_num
,
skip_batch_num
,
skip_batch_num
,
transform_to_int8
=
False
)
target
=
'quant'
)
_logger
.
info
(
'FP32: avg accuracy: {0:.6f}'
.
format
(
fp32_acc
))
self
.
_print_performance
(
'Quant'
,
quant_pps
,
quant_lat
)
_logger
.
info
(
'--- Quant INT8 prediction start ---'
)
self
.
_print_accuracy
(
'Quant'
,
quant_acc
)
_logger
.
info
(
'--- INT8 prediction start ---'
)
val_reader
=
paddle
.
batch
(
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
int8_acc
,
int8_pps
,
int8_lat
=
self
.
_predict
(
int8_acc
,
int8_pps
,
int8_lat
=
self
.
_predict
(
...
@@ -304,11 +319,29 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
...
@@ -304,11 +319,29 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
batch_size
,
batch_size
,
batch_num
,
batch_num
,
skip_batch_num
,
skip_batch_num
,
transform_to_int8
=
True
)
target
=
'int8'
)
_logger
.
info
(
'INT8: avg accuracy: {0:.6f}'
.
format
(
int8_acc
))
self
.
_print_performance
(
'INT8'
,
int8_pps
,
int8_lat
)
self
.
_print_accuracy
(
'INT8'
,
int8_acc
)
fp32_acc
=
fp32_pps
=
fp32_lat
=
-
1
if
fp32_model_path
:
_logger
.
info
(
'--- FP32 prediction start ---'
)
val_reader
=
paddle
.
batch
(
self
.
_reader_creator
(
data_path
,
labels_path
),
batch_size
=
batch_size
)
fp32_acc
,
fp32_pps
,
fp32_lat
=
self
.
_predict
(
val_reader
,
fp32_model_path
,
batch_size
,
batch_num
,
skip_batch_num
,
target
=
'fp32'
)
self
.
_print_performance
(
'FP32'
,
fp32_pps
,
fp32_lat
)
self
.
_print_accuracy
(
'FP32'
,
fp32_acc
)
self
.
_summarize_performance
(
fp32_pps
,
fp32_lat
,
int8_pps
,
int8_lat
)
self
.
_summarize_performance
(
int8_pps
,
int8_lat
,
fp32_pps
,
fp32_lat
)
self
.
_compare_accuracy
(
fp32_acc
,
int8_acc
,
acc_diff_threshold
)
self
.
_summarize_accuracy
(
quant_acc
,
int8_acc
,
fp32_acc
)
self
.
_compare_accuracy
(
acc_diff_threshold
,
quant_acc
,
int8_acc
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/contrib/slim/tests/save_quant_model.py
浏览文件 @
d0a921ba
...
@@ -35,11 +35,6 @@ def parse_args():
...
@@ -35,11 +35,6 @@ def parse_args():
type
=
str
,
type
=
str
,
default
=
''
,
default
=
''
,
help
=
'A path to a Quant model.'
)
help
=
'A path to a Quant model.'
)
parser
.
add_argument
(
'--fp32_model_save_path'
,
type
=
str
,
default
=
''
,
help
=
'Saved optimized fp32 model'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--int8_model_save_path'
,
'--int8_model_save_path'
,
type
=
str
,
type
=
str
,
...
@@ -65,7 +60,7 @@ def parse_args():
...
@@ -65,7 +60,7 @@ def parse_args():
return
test_args
,
sys
.
argv
[:
1
]
+
args
return
test_args
,
sys
.
argv
[:
1
]
+
args
def
transform_and_save_
model
(
original_path
,
save_path
,
save_type
):
def
transform_and_save_
int8_model
(
original_path
,
save_path
):
place
=
fluid
.
CPUPlace
()
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
inference_scope
=
fluid
.
executor
.
global_scope
()
inference_scope
=
fluid
.
executor
.
global_scope
()
...
@@ -96,26 +91,18 @@ def transform_and_save_model(original_path, save_path, save_type):
...
@@ -96,26 +91,18 @@ def transform_and_save_model(original_path, save_path, save_type):
_place
=
place
,
_place
=
place
,
_core
=
core
,
_core
=
core
,
_debug
=
test_args
.
debug
)
_debug
=
test_args
.
debug
)
graph
=
IrGraph
(
core
.
Graph
(
inference_program
.
desc
),
for_test
=
True
)
if
save_type
==
'FP32'
:
graph
=
transform_to_mkldnn_int8_pass
.
apply_fp32
(
graph
)
elif
save_type
==
'INT8'
:
graph
=
transform_to_mkldnn_int8_pass
.
apply
(
graph
)
graph
=
transform_to_mkldnn_int8_pass
.
apply
(
graph
)
inference_program
=
graph
.
to_program
()
inference_program
=
graph
.
to_program
()
with
fluid
.
scope_guard
(
inference_scope
):
with
fluid
.
scope_guard
(
inference_scope
):
fluid
.
io
.
save_inference_model
(
save_path
,
feed_target_names
,
fluid
.
io
.
save_inference_model
(
save_path
,
feed_target_names
,
fetch_targets
,
exe
,
inference_program
)
fetch_targets
,
exe
,
inference_program
)
print
(
"Success! Transformed Quant_{0} model can be found at {1}
\n
"
.
print
(
format
(
save_type
,
save_path
))
"Success! INT8 model obtained from the Quant model can be found at {}
\n
"
.
format
(
save_path
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
global
test_args
global
test_args
test_args
,
remaining_args
=
parse_args
()
test_args
,
remaining_args
=
parse_args
()
if
test_args
.
fp32_model_save_path
:
transform_and_save_int8_model
(
test_args
.
quant_model_path
,
transform_and_save_model
(
test_args
.
quant_model_path
,
test_args
.
int8_model_save_path
)
test_args
.
fp32_model_save_path
,
'FP32'
)
if
test_args
.
int8_model_save_path
:
transform_and_save_model
(
test_args
.
quant_model_path
,
test_args
.
int8_model_save_path
,
'INT8'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录