Unverified commit 4cddb43c authored by Wojciech Uss, committed by GitHub

Add support for Ernie NLP model to the Slim QAT (#22506)

* a test for Ernie QAT INT8 accuracy check

test=develop

* Remove NLP comparison test to split PRs

test=develop

* Fix typo and tabs, delete commented lines

test=develop

* re-combine the 2 PRs, test=develop
Co-authored-by: Michał Gallus <sand3r@interia.eu>
Co-authored-by: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com>
Parent 5a1a9a1e
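
Note: this PR renames the Slim QAT MKL-DNN passes (FakeQAT2MkldnnINT8KernelPass becomes QatInt8MkldnnPass, FakeQAT2MkldnnINT8PerfPass becomes Qat2Int8MkldnnPass) and makes the set of quantized op types a required first argument of Qat2Int8MkldnnPass. A minimal sketch of applying the renamed pass, based on the usage in the tests below; the model path is a placeholder:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

    # Op types to quantize; these are the values used for the Ernie NLP tests in this PR.
    quantized_ops = {'fc', 'reshape2', 'transpose2'}
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    scope = fluid.executor.global_scope()
    with fluid.scope_guard(scope):
        # 'path/to/Ernie_qat/float' is a placeholder for a QAT model directory.
        [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
            'path/to/Ernie_qat/float', exe, 'model', 'params')
        graph = IrGraph(core.Graph(program.desc), for_test=True)
        mkldnn_int8_pass = Qat2Int8MkldnnPass(
            quantized_ops, _scope=scope, _place=place, _core=core)
        graph = mkldnn_int8_pass.apply(graph)
        program = graph.to_program()
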
......@@ -226,15 +226,18 @@ if(WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
### Image classification tests
set(IMAGENET_DATA_PATH "${INT8_DATA_DIR}/data.bin")
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
## Image classification models
# download dataset if necessary
download_int8_data(${INT8_DATA_DIR} "imagenet_val_100_tail.tar.gz")
# ImageNet small dataset
# May be already downloaded for INT8 QAT unit tests
set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz")
set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet")
set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin")
download_int8_data(${IMAGENET_DATA_DIR} ${IMAGENET_DATA_ARCHIVE})
# build test binary to be used in subsequent tests
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
inference_analysis_api_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC})
# resnet50 int8
......@@ -296,7 +299,7 @@ if(WITH_MKLDNN)
### optimized FP32 vs. QAT INT8 tests
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")
set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification")
set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_qat_image_classification_tester.cc")
......@@ -304,8 +307,8 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC})
# MobileNet FP32 vs. QAT INT8
# The FP32 model should already be downloaded for slim QAT unit tests
set(QAT2_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
download_qat_data(${QAT2_MobileNet_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
set(QAT2_INT8_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf_int8")
download_qat_data(${QAT2_INT8_MobileNet_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_MobileNet_MODEL_DIR}/MobileNet_qat_perf/float ${QAT2_INT8_MobileNet_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
......
......@@ -479,7 +479,7 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(),
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(weights->dims()), ctx.OutputName("Out"));
auto prim_creator =
......
......@@ -17,10 +17,10 @@ from .... import core
from ....framework import IrGraph
from ....framework import IrNode
__all__ = ['FakeQAT2MkldnnINT8KernelPass', 'FakeQAT2MkldnnINT8PerfPass']
__all__ = ['QatInt8MkldnnPass', 'Qat2Int8MkldnnPass']
class FakeQAT2MkldnnINT8KernelPass(object):
class QatInt8MkldnnPass(object):
"""
Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8
IrGraph. The following transformations are done in this pass:
......@@ -48,13 +48,13 @@ class FakeQAT2MkldnnINT8KernelPass(object):
# The original graph will be rewritten.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import FakeQAT2MkldnnINT8KernelPass
import QatInt8MkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(),
mkldnn_pass = QatInt8MkldnnPass(fluid.global_scope(),
place)
mkldnn_pass.apply(graph)
"""
......@@ -276,7 +276,7 @@ class FakeQAT2MkldnnINT8KernelPass(object):
graph.safe_remove_nodes(all_unused_vars)
class FakeQAT2MkldnnINT8PerfPass(object):
class Qat2Int8MkldnnPass(object):
"""
Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph.
The pass consists of the following transformations:
......@@ -290,7 +290,12 @@ class FakeQAT2MkldnnINT8PerfPass(object):
passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
"""
def __init__(self, _scope=None, _place=None, _core=None, _debug=False):
def __init__(self,
_quantized_ops,
_scope=None,
_place=None,
_core=None,
_debug=False):
self._scope = _scope
self._place = _place
self._core = _core
......@@ -305,6 +310,10 @@ class FakeQAT2MkldnnINT8PerfPass(object):
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_dequantize_types = ['fake_dequantize_max_abs']
self._quantized_ops = _quantized_ops
self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'scale'
]
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
......@@ -324,8 +333,9 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._compute_weight_scales(graph)
graph = self._update_conv_relu_scales(graph)
graph = self._update_pooling_scales(graph)
graph = self._update_relu_output_scales(graph)
graph = self._propagate_scales(graph)
graph = self._set_dummy_fc_out_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
return graph
......@@ -346,6 +356,12 @@ class FakeQAT2MkldnnINT8PerfPass(object):
tensor.set(scale, core.CPUPlace())
return tensor
def _is_conv_quantized(self):
return any(op_type in self._quantized_ops for op_type in self._conv_ops)
def _is_fc_quantized(self):
return 'fc' in self._quantized_ops
def _gather_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._quantize_types:
......@@ -371,34 +387,94 @@ class FakeQAT2MkldnnINT8PerfPass(object):
self._weight_scales[input_name] = _max_range
return graph
def _update_pooling_scales(self, graph):
def _propagate_scales(self, graph):
def _update_scale_op_in_scale(op, input, output):
unsigned, tensor = self._var_quant_scales[output]
scale = np.array(tensor) * op.op().attr("scale")
new_tensor = self._convert_scale2tensor(scale.astype(np.float64))
self._var_quant_scales[input] = (unsigned, new_tensor)
def _update_scales(graph):
waiting_for_scale = set()
for op in graph.all_op_nodes():
if op.name() in self._scale_immutable_ops:
input_name = op.input("X")[0]
output_name = op.output("Out")[0]
tensor_names = [input_name, output_name]
# The scale op itself is not quantized, so if it has no scale
# to propagate, its tensors are not added to the waiting list.
if all(name not in self._var_quant_scales for name in tensor_names) \
and op.name() != 'scale':
waiting_for_scale.update(tensor_names)
continue
if input_name in self._var_quant_scales:
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
elif output_name in self._var_quant_scales:
if op.name() == 'scale':
_update_scale_op_in_scale(op, input_name,
output_name)
else:
self._var_quant_scales[
input_name] = self._var_quant_scales[
output_name]
return waiting_for_scale
waiting_for_scale = _update_scales(graph)
while len(waiting_for_scale) != 0:
waiting_for_scale = _update_scales(graph)
return graph
def _set_dummy_fc_out_scales(self, graph):
'''
For the output tensors of FC that do not have an assigned scale,
assign a dummy scale (same scale as input), so that the quantize pass
won't fail. In the end these scales aren't used, since FCs that
have an unassigned output scale will have a force_fp32_output attr
set to True.
'''
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
input_name = op.input("X")[0]
if op.name() in self._fc_ops:
input_name = op.input("Input")[0]
output_name = op.output("Out")[0]
if input_name in self._var_quant_scales:
if input_name in self._var_quant_scales and \
output_name not in self._var_quant_scales:
# use input scale as a "dummy" scale
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
return graph
def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
def _remove_fake_ops(self, graph):
'''
When FC isn't quantized:
Remove fake (de)quantize ops that do not surround mul.
When FC is quantized:
Remove all fake (de)quantize ops.
'''
is_fc_quantized = self._is_fc_quantized()
for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types:
op_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
next_op = op_out.outputs[0]
if next_op.name() not in self._mul_ops:
if next_op.name() not in self._mul_ops or is_fc_quantized:
self._remove_fake_quantize(graph, op)
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
prev_op = op_in.inputs[0]
if prev_op.name() not in self._mul_ops:
if prev_op.name() not in self._mul_ops or is_fc_quantized:
self._remove_fake_dequantize(graph, op)
return graph
def _remove_fake_quantize(self, graph, op):
......@@ -444,6 +520,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
self._dequantize_conv_weights(graph, op)
elif self._is_fc_quantized() and op.name() in self._mul_ops:
self._dequantize_mul_weights(graph, op)
return graph
def _dequantize_conv_weights(self, graph, op_node):
......@@ -472,13 +550,20 @@ class FakeQAT2MkldnnINT8PerfPass(object):
def _optimize_fp32_graph(self, graph):
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
if self._is_conv_quantized():
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph,
'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
if self._is_fc_quantized():
graph = self._apply_pass(graph, 'fc_fuse_pass',
['use_gpu', 'use_fc_padding'],
[False, False])
graph = self._apply_pass(graph, 'fc_mkldnn_pass')
return graph
def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
......@@ -528,6 +613,7 @@ class FakeQAT2MkldnnINT8PerfPass(object):
np.abs(weights.reshape(weights.shape[0], -1)).astype(
np.float64),
axis=axis)
scales[scales == np.Inf] = 0.0
lod_tensor = self._convert_scale2tensor(scales)
use_unsigned_int = False
......@@ -546,46 +632,41 @@ class FakeQAT2MkldnnINT8PerfPass(object):
ids.append(op.id())
return set(ids) if len(ids) else set([-1])
def _transform_to_quantize_mkldnn(self, graph, op_node):
"""
Transform fake_quantize_xx op to quantize mkldnn op in the graph.
"""
input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("X")[0])
output_var_node = graph._find_node_by_name(op_node.outputs,
op_node.output("Out")[0])
scale_in = self._s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node(
op_type='quantize',
attrs={
'data_format': 'MKLDNNLAYOUT',
'use_mkldnn': 1,
'Scale': scale_in,
'is_negative_input': 1
},
inputs={'Input': input_var_node},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, quant_op_node)
graph.link_to(quant_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
return quant_op_node
def _update_relu_output_scales(self, graph):
def _update_scale(graph, ops, op_out_name, predicate):
'''
Marks the output scale of the given op types as 'unsigned int8' if the
predicate applied to the op passes. Typically, the predicate checks whether
the op's activation is set to relu.
'''
for op in graph.all_op_nodes():
if op.name() in ops:
out_name = op.output(op_out_name)[0]
if out_name in self._var_quant_scales and predicate(op.op(
)):
_, tensor = self._var_quant_scales[out_name]
self._var_quant_scales[out_name] = (True, tensor)
return graph
if self._is_conv_quantized():
conv_predicate = lambda op: op.attr("fuse_activation") == 'relu' and \
op.attr("fuse_residual_connection") == False
graph = _update_scale(graph, self._conv_ops, "Output",
conv_predicate)
if self._is_fc_quantized():
fc_predicate = lambda op: op.attr("activation_type") == 'relu'
graph = _update_scale(graph, self._fc_ops, "Out", fc_predicate)
def _update_conv_relu_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
out_name = op.output("Output")[0]
if out_name in self._var_quant_scales and \
op.op().attr("fuse_activation") == 'relu' and \
op.op().attr("fuse_residual_connection") == False:
_, tensor = self._var_quant_scales[out_name]
self._var_quant_scales[out_name] = (True, tensor)
return graph
def _get_data_layout(self):
return 'NHWC' if self._is_conv_quantized() else 'NCHW'
def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
cpp_graph = graph.graph
ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})
ir_pass.set('quantize_enabled_op_types', self._quantized_ops)
ir_pass.set('quantize_excluded_op_ids',
self._find_avg_pooling_ids(graph))
ir_pass.apply(cpp_graph)
......@@ -593,8 +674,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes())
graph = self._apply_pass(graph, 'cpu_quantize_pass',
['quant_var_scales'],
[self._var_quant_scales])
graph = self._apply_pass(
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, self._get_data_layout()])
graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
return graph
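
A note on the scale propagation added above: for transpose2, reshape2 and pool2d the input and output tensors simply share one quantization scale, while for a scale op a known output scale is multiplied by the op's "scale" attribute before being assigned to the input (see _update_scale_op_in_scale). A minimal numeric sketch of that rule, with illustrative values only:

    import numpy as np

    out_scale = np.array([127.0], dtype=np.float64)  # assumed quantization scale of the scale op's output
    scale_attr = 0.5                                 # hypothetical value of the op's "scale" attribute
    in_scale = out_scale * scale_attr                # scale propagated back to the op's input, as in _update_scale_op_in_scale
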
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn)
function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn)
py_test(${target} SRCS ${filename}
ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --infer_model ${model_dir}/model
--infer_data ${data_dir}/data.bin
--infer_data ${data_path}
--int8_model_save_path int8_models/${target}
--warmup_batch_size ${WARMUP_BATCH_SIZE}
--batch_size 50)
endfunction()
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False)
endfunction()
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} True)
function(inference_analysis_python_api_int8_test target model_dir data_path filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False)
endfunction()
function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size)
......@@ -25,13 +21,29 @@ function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target
inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename})
endfunction()
function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
endfunction()
function(download_qat_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
endif()
endfunction()
function(download_qat_model install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endif()
endfunction()
function(inference_qat_int8_image_classification_test target model_dir dataset_path)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
FLAGS_use_mkldnn=true
ARGS --qat_model ${model_dir}/model
--infer_data ${data_dir}/data.bin
--infer_data ${dataset_path}
--batch_size 25
--batch_num 2
--acc_diff_threshold 0.1)
......@@ -39,24 +51,53 @@ endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
function(inference_qat2_int8_image_classification_test target model_dir data_path quantized_ops)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
FLAGS_use_mkldnn=true
ARGS --qat_model ${model_dir}/float
--infer_data ${data_dir}/data.bin
--infer_data ${data_path}
--batch_size 10
--batch_num 2
--acc_diff_threshold 0.1
--qat2)
--quantized_ops ${quantized_ops}
--qat2)
endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
function(inference_qat2_int8_nlp_test target model_dir data_path labels_path quantized_ops)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_nlp_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true
ARGS --qat_model ${model_dir}/float
--infer_data ${data_path}
--labels ${labels_path}
--batch_size 10
--batch_num 2
--quantized_ops ${quantized_ops}
--acc_diff_threshold 0.1)
endfunction()
function(download_qat_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
endif()
endfunction()
function(download_qat_model install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endif()
endfunction()
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path test_script)
py_test(${target} SRCS ${test_script}
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
ARGS --qat_model_path ${qat_model_dir}
--fp32_model_save_path ${fp32_model_save_path}
--int8_model_save_path ${int8_model_save_path})
--fp32_model_save_path ${fp32_model_save_path}
--int8_model_save_path ${int8_model_save_path}
--quantized_ops ${quantized_ops})
endfunction()
if(WIN32)
......@@ -66,137 +107,151 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
endif()
# int8 image classification python api test
if(LINUX AND WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(MKLDNN_INT8_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
set(MKLDNN_INT8_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_INT8_TEST_FILE}")
# googlenet int8
set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH} 10)
# mobilenet int8
set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
# temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally,
# since the following UTs cost too much time on CI test.
if (WITH_SLIM_MKLDNN_FULL_TEST)
# resnet50 int8
set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
# mobilenetv2 int8
set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
# resnet101 int8
set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
# vgg16 int8
set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
# vgg19 int8
set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
endif()
endif()
# Since test_mkldnn_int8_quantization_strategy only supports testing on Linux
# with MKL-DNN, we remove it here for not repeating test, or not testing on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)
#### Image classification dataset: ImageNet (small)
# The dataset should already be downloaded for INT8v2 unit tests
set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin")
# QAT FP32 & INT8 comparison python api tests
if(LINUX AND WITH_MKLDNN)
set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models")
set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py")
set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}")
# ImageNet small dataset
# May be already downloaded for INT8v2 unit tests
if (NOT EXISTS ${DATASET_DIR})
inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
#### INT8 image classification python api test
# Models should be already downloaded for INT8v2 unit tests
set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(INT8_IC_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
set(INT8_IC_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${INT8_IC_TEST_FILE}")
# googlenet int8
set(INT8_GOOGLENET_MODEL_DIR "${INT8_INSTALL_DIR}/googlenet")
inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH} 10)
# mobilenet int8
set(INT8_MOBILENET_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally,
# since the following UTs cost too much time on CI test.
if (WITH_SLIM_MKLDNN_FULL_TEST)
# resnet50 int8
set(INT8_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# mobilenetv2 int8
set(INT8_MOBILENETV2_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv2")
inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# resnet101 int8
set(INT8_RESNET101_MODEL_DIR "${INT8_INSTALL_DIR}/resnet101")
inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# vgg16 int8
set(INT8_VGG16_MODEL_DIR "${INT8_INSTALL_DIR}/vgg16")
inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# vgg19 int8
set(INT8_VGG19_MODEL_DIR "${INT8_INSTALL_DIR}/vgg19")
inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
endif()
#### QAT FP32 & INT8 comparison python api tests
set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")
### QATv1 for image classification
# QAT ResNet50
set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT")
if (NOT EXISTS ${QAT_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT")
set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH})
# QAT ResNet101
set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT")
if (NOT EXISTS ${QAT_RESNET101_MODEL_DIR})
inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz" )
endif()
# inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT")
set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE})
# inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH})
# QAT GoogleNet
set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT")
if (NOT EXISTS ${QAT_GOOGLENET_MODEL_DIR})
inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT")
set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH})
# QAT MobileNetV1
set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT")
if (NOT EXISTS ${QAT_MOBILENETV1_MODEL_DIR})
inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT")
set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${IMAGENET_DATA_PATH})
# QAT MobileNetV2
set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT")
if (NOT EXISTS ${QAT_MOBILENETV2_MODEL_DIR})
inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT")
set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH})
# QAT VGG16
set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT")
if (NOT EXISTS ${QAT_VGG16_MODEL_DIR})
inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz" )
endif()
# inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT")
set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE})
# inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH})
# QAT VGG19
set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT")
if (NOT EXISTS ${QAT_VGG19_MODEL_DIR})
inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
endif()
# inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
if (NOT EXISTS ${QAT2_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
if (NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR})
inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# Save qat2 fp32 model or qat2 int8 model
set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT")
set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE})
# inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH})
### QATv2 for image classification
set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")
# QAT2 ResNet50
set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf")
set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
# QAT2 MobileNetV1
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
### QATv2 for NLP
set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2")
set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})
# QAT2 Ernie
set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat")
download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE})
inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QAT2_NLP_QUANTIZED_OPS})
# Save QAT2 FP32 model or QAT2 INT8 model
set(QAT2_INT8_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_int8")
set(QAT2_FP32_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_fp32")
set(SAVE_QAT2_MODEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py")
save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_SAVE_PATH} ${QAT2_INT8_SAVE_PATH} ${SAVE_QAT2_MODEL_SCRIPT} true)
set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS})
endif()
# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
# Since the tests for QAT FP32 & INT8 comparison support only testing on Linux
# with MKL-DNN, we remove it here to not test it on other systems.
list(REMOVE_ITEM TEST_OPS qat_int8_comparison.py)
list(REMOVE_ITEM TEST_OPS
test_mkldnn_int8_quantization_strategy
qat_int8_image_classification_comparison
qat_int8_nlp_comparison)
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
......
......@@ -24,8 +24,8 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
......@@ -53,10 +53,6 @@ def parse_args():
action='store_true',
help='If used, the QAT model is treated as a second generation model for performance optimization.'
)
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
......@@ -68,15 +64,20 @@ def parse_args():
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class TestQatInt8Comparison(unittest.TestCase):
class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 inference.
Test for accuracy comparison of QAT FP32 and INT8 Image Classification inference.
"""
def _reader_creator(self, data_file='data.bin'):
......@@ -182,14 +183,15 @@ class TestQatInt8Comparison(unittest.TestCase):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
if (test_case_args.qat2):
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
else:
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
mkldnn_int8_pass = QatInt8MkldnnPass(
_scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph)
......@@ -256,12 +258,6 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
if test_case_args.save_model:
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
'transformed_qat_int8_model', feed_target_names,
fetch_targets, exe, inference_program)
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
......@@ -298,6 +294,7 @@ class TestQatInt8Comparison(unittest.TestCase):
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
......@@ -305,6 +302,7 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
......
# copyright (c) 2020 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import struct
import six
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--labels', type=str, default='', help='File with labels.')
parser.add_argument(
'--batch_num',
type=int,
default=1,
help='Number of batches to process. 0 or less means all.')
parser.add_argument(
'--acc_diff_threshold',
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class QatInt8NLPComparisonTest(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 NLP inference.
"""
def _reader_creator(self, data_file=None, labels_file=None):
assert data_file, "The dataset file is missing."
assert labels_file, "The labels file is missing."
def reader():
with open(data_file, 'r') as df:
with open(labels_file, 'r') as lf:
data_lines = df.readlines()
labels_lines = lf.readlines()
assert len(data_lines) == len(
labels_lines
), "The number of labels does not match the length of the dataset."
for i in range(len(data_lines)):
data_fields = data_lines[i].split(';')
assert len(
data_fields
) >= 2, "The number of data fields in the dataset is less than 2"
buffers = []
shape = []
for j in range(2):
data = data_fields[j].split(':')
assert len(
data
) >= 2, "Size of data in the dataset is less than 2"
# Shape is stored under index 0, while data under 1
shape = data[0].split()
shape.pop(0)
shape_np = np.array(shape).astype("int64")
buffer_i = data[1].split()
buffer_np = np.array(buffer_i).astype("int64")
buffer_np.shape = tuple(shape_np)
buffers.append(buffer_np)
label = labels_lines[i]
yield buffers[0], buffers[1], int(label)
return reader
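# Illustrative note (assumed layout, not taken from a real dataset file): the reader above expects
# each line of the data file to contain at least two ';'-separated fields, each of the form
# "<leading token> <shape dims>:<space-separated int64 values>", where the leading token of the
# shape part is dropped by shape.pop(0). The labels file is expected to hold one integer label per line.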
def _get_batch_correct(self, batch_output=None, labels=None):
total = len(batch_output)
assert total > 0, "The batch output is empty."
correct = 0
for n, output in enumerate(batch_output[0]):
max_idx = np.where(output == output.max())
if max_idx == labels[n]:
correct += 1
return correct
def _predict(self,
test_reader=None,
model_path=None,
batch_size=1,
batch_num=1,
skip_batch_num=0,
transform_to_int8=False):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
inference_program = graph.to_program()
total_correct = 0
total_samples = 0
batch_times = []
ppses = [] # predictions per second
iters = 0
infer_start_time = time.time()
for data in test_reader():
if batch_num > 0 and iters >= batch_num:
break
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
input0 = np.array([x[0] for x in data]).astype('int64')
input1 = np.array([x[1] for x in data]).astype('int64')
labels = np.array([x[2] for x in data]).astype('int64')
start = time.time()
out = exe.run(inference_program,
feed={
feed_target_names[0]: input0,
feed_target_names[1]: input1
},
fetch_list=fetch_targets)
batch_time = (time.time() - start) * 1000 # in milliseconds
batch_times.append(batch_time)
batch_correct = self._get_batch_correct(out, labels)
batch_len = len(data)
total_samples += batch_len
total_correct += batch_correct
batch_acc = float(batch_correct) / float(batch_len)
pps = batch_len / batch_time * 1000
ppses.append(pps)
latency = batch_time / batch_len
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
_logger.info(
'batch {0}{4}, acc: {1:.4f}, latency: {2:.4f} ms, predictions per sec: {3:.2f}'
.format(iters, batch_acc, latency, pps, appx))
# Postprocess benchmark data
infer_total_time = time.time() - infer_start_time
batch_latencies = batch_times[skip_batch_num:]
batch_latency_avg = np.average(batch_latencies)
latency_avg = batch_latency_avg / batch_size
ppses = ppses[skip_batch_num:]
pps_avg = np.average(ppses)
acc_avg = float(np.sum(total_correct)) / float(total_samples)
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
return acc_avg, pps_avg, latency_avg
def _summarize_performance(self, fp32_pps, fp32_lat, int8_pps, int8_lat):
_logger.info('--- Performance summary ---')
_logger.info(
'FP32: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
format(fp32_pps, fp32_lat))
_logger.info(
'INT8: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
format(int8_pps, int8_lat))
def _compare_accuracy(self, fp32_acc, int8_acc, threshold):
_logger.info('--- Accuracy summary ---')
_logger.info(
'Accepted accuracy drop threshold: {0}. (condition: (FP32_acc - INT8_acc) <= threshold)'
.format(threshold))
_logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
_logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
# Random outputs give an accuracy of about 0.33; we assume a valid accuracy to be at least 0.5
assert fp32_acc > 0.5
assert int8_acc > 0.5
assert fp32_acc - int8_acc <= threshold
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
qat_model_path = test_case_args.qat_model
data_path = test_case_args.infer_data
labels_path = test_case_args.labels
batch_size = test_case_args.batch_size
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Labels: {0}'.format(labels_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
fp32_acc, fp32_pps, fp32_lat = self._predict(
val_reader,
qat_model_path,
batch_size,
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
int8_acc, int8_pps, int8_lat = self._predict(
val_reader,
qat_model_path,
batch_size,
batch_num,
skip_batch_num,
transform_to_int8=True)
self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
if __name__ == '__main__':
global test_case_args
test_case_args, remaining_args = parse_args()
unittest.main(argv=remaining_args)
......@@ -24,7 +24,7 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
......@@ -42,6 +42,11 @@ def parse_args():
type=str,
default='',
help='Saved optimized and quantized INT8 model')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
......@@ -60,8 +65,9 @@ def transform_and_save_model(original_path, save_path, save_type):
fetch_targets] = fluid.io.load_inference_model(original_path, exe,
'model', 'params')
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
_scope=inference_scope, _place=place, _core=core)
quantized_ops = set(test_args.quantized_ops.split(','))
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
quantized_ops, _scope=inference_scope, _place=place, _core=core)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if save_type == 'FP32':
......
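
The diff above does not show the entry point of save_qat_model.py; the following is a sketch of how the new --quantized_ops flag and the save paths presumably feed into transform_and_save_model. Argument names follow the parser shown above, but the wiring itself is an assumption:

    # Sketch only; the actual main block of save_qat_model.py is outside this diff.
    if __name__ == '__main__':
        test_args, remaining_args = parse_args()
        if test_args.fp32_model_save_path:
            transform_and_save_model(test_args.qat_model_path,
                                     test_args.fp32_model_save_path, 'FP32')
        if test_args.int8_model_save_path:
            transform_and_save_model(test_args.qat_model_path,
                                     test_args.int8_model_save_path, 'INT8')
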
......@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid import core
os.environ["CPU_NUM"] = "1"
......@@ -149,8 +149,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=scope, _place=place)
mkldnn_int8_pass = QatInt8MkldnnPass(_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_'
if not for_ci:
......