diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index bd43962e10daa5dc3a8c42c7a45a66195fc1586f..145cd250a025906ff2052811d6fbf4aeaeef8206 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -226,15 +226,18 @@ if(WITH_MKLDNN) set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") - ### Image classification tests - set(IMAGENET_DATA_PATH "${INT8_DATA_DIR}/data.bin") - set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification") - set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc") + ## Image classification models - # download dataset if necessary - download_int8_data(${INT8_DATA_DIR} "imagenet_val_100_tail.tar.gz") + # ImageNet small dataset + # May be already downloaded for INT8 QAT unit tests + set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz") + set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet") + set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin") + download_int8_data(${IMAGENET_DATA_DIR} ${IMAGENET_DATA_ARCHIVE}) # build test binary to be used in subsequent tests + set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification") + set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc") inference_analysis_api_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC}) # resnet50 int8 @@ -296,7 +299,7 @@ if(WITH_MKLDNN) ### optimized FP32 vs. QAT INT8 tests - set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") + set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat") set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification") set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_qat_image_classification_tester.cc") @@ -304,8 +307,8 @@ if(WITH_MKLDNN) inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC}) # MobileNet FP32 vs. QAT INT8 + # The FP32 model should already be downloaded for slim QAT unit tests set(QAT2_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf") - download_qat_data(${QAT2_MobileNet_MODEL_DIR} "MobileNet_qat_perf.tar.gz") set(QAT2_INT8_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf_int8") download_qat_data(${QAT2_INT8_MobileNet_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz") inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_MobileNet_MODEL_DIR}/MobileNet_qat_perf/float ${QAT2_INT8_MobileNet_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH}) diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index dcf0b996bde40e5ccf7df6f9e22f735cadc5b64a..d45567bd5200cf8a46b92dfc39f194b4bada2d4c 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -479,7 +479,7 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx, const Tensor* weights, const mkldnn::engine& mkldnn_engine) { const std::string key = platform::CreateKey( - platform::ThreadIDasStr(), input->format(), + platform::ThreadIDasStr(), input->format(), input->dims()[0], framework::vectorize(weights->dims()), ctx.OutputName("Out")); auto prim_creator = diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py index 0feaa62e2f6b207bfed1b6ac5c5dea6f75b4a509..02750f9e83a7a366e67485df346e5cb82e85ff38 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py @@ -17,10 +17,10 @@ from .... import core from ....framework import IrGraph from ....framework import IrNode -__all__ = ['FakeQAT2MkldnnINT8KernelPass', 'FakeQAT2MkldnnINT8PerfPass'] +__all__ = ['QatInt8MkldnnPass', 'Qat2Int8MkldnnPass'] -class FakeQAT2MkldnnINT8KernelPass(object): +class QatInt8MkldnnPass(object): """ Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8 IrGraph. Following transformations did in this pass: @@ -48,13 +48,13 @@ class FakeQAT2MkldnnINT8KernelPass(object): # The original graph will be rewrite. import paddle.fluid as fluid from paddle.fluid.contrib.slim.quantization \ - import FakeQAT2MkldnnINT8KernelPass + import QatInt8MkldnnPass from paddle.fluid.framework import IrGraph from paddle.fluid import core graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False) place = fluid.CPUPlace() - mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(), + mkldnn_pass = QatInt8MkldnnPass(fluid.global_scope(), place) mkldnn_pass.apply(graph) """ @@ -276,7 +276,7 @@ class FakeQAT2MkldnnINT8KernelPass(object): graph.safe_remove_nodes(all_unused_vars) -class FakeQAT2MkldnnINT8PerfPass(object): +class Qat2Int8MkldnnPass(object): """ Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph. The pass consists of the following transformations: @@ -290,7 +290,12 @@ class FakeQAT2MkldnnINT8PerfPass(object): passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`). """ - def __init__(self, _scope=None, _place=None, _core=None, _debug=False): + def __init__(self, + _quantized_ops, + _scope=None, + _place=None, + _core=None, + _debug=False): self._scope = _scope self._place = _place self._core = _core @@ -305,6 +310,10 @@ class FakeQAT2MkldnnINT8PerfPass(object): 'fake_quantize_dequantize_moving_average_abs_max' ] self._fake_dequantize_types = ['fake_dequantize_max_abs'] + self._quantized_ops = _quantized_ops + self._scale_immutable_ops = [ + 'transpose2', 'reshape2', 'pool2d', 'scale' + ] self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._pool_ops = ['pool2d'] self._mul_ops = ['mul'] @@ -324,8 +333,9 @@ class FakeQAT2MkldnnINT8PerfPass(object): graph = self._dequantize_weights(graph) graph = self._optimize_fp32_graph(graph) graph = self._compute_weight_scales(graph) - graph = self._update_conv_relu_scales(graph) - graph = self._update_pooling_scales(graph) + graph = self._update_relu_output_scales(graph) + graph = self._propagate_scales(graph) + graph = self._set_dummy_fc_out_scales(graph) graph = self._quantize_fp32_graph(graph) graph = self._remove_unused_var_nodes(graph) return graph @@ -346,6 +356,12 @@ class FakeQAT2MkldnnINT8PerfPass(object): tensor.set(scale, core.CPUPlace()) return tensor + def _is_conv_quantized(self): + return any(op_type in self._quantized_ops for op_type in self._conv_ops) + + def _is_fc_quantized(self): + return 'fc' in self._quantized_ops + def _gather_scales(self, graph): for op in graph.all_op_nodes(): if op.name() in self._quantize_types: @@ -371,34 +387,94 @@ class FakeQAT2MkldnnINT8PerfPass(object): self._weight_scales[input_name] = _max_range return graph - def _update_pooling_scales(self, graph): + def _propagate_scales(self, graph): + def _update_scale_op_in_scale(op, input, output): + unsigned, tensor = self._var_quant_scales[output] + scale = np.array(tensor) * op.op().attr("scale") + new_tensor = self._convert_scale2tensor(scale.astype(np.float64)) + self._var_quant_scales[input] = (unsigned, new_tensor) + + def _update_scales(graph): + waiting_for_scale = set() + for op in graph.all_op_nodes(): + if op.name() in self._scale_immutable_ops: + input_name = op.input("X")[0] + output_name = op.output("Out")[0] + tensor_names = [input_name, output_name] + + # Scale is not quantized, so if it doesn't have any scales + # to propagate, its tensors won't be added to the waiting list. + if all(name not in self._var_quant_scales for name in tensor_names) \ + and op.name() != 'scale': + waiting_for_scale.update(tensor_names) + continue + + if input_name in self._var_quant_scales: + self._var_quant_scales[ + output_name] = self._var_quant_scales[input_name] + elif output_name in self._var_quant_scales: + if op.name() == 'scale': + _update_scale_op_in_scale(op, input_name, + output_name) + else: + self._var_quant_scales[ + input_name] = self._var_quant_scales[ + output_name] + return waiting_for_scale + + waiting_for_scale = _update_scales(graph) + + while len(waiting_for_scale) != 0: + waiting_for_scale = _update_scales(graph) + + return graph + + def _set_dummy_fc_out_scales(self, graph): + ''' + For the output tensors of FC that do not have an assigned scale, + assign a dummy scale (same scale as input), so that the quantize pass + won't fail. In the end these scales aren't used, since FCs that + have an unassigend output scale will have a force_fp32_output attr + set to True. + ''' for op in graph.all_op_nodes(): - if op.name() in self._pool_ops: - input_name = op.input("X")[0] + if op.name() in self._fc_ops: + input_name = op.input("Input")[0] output_name = op.output("Out")[0] - if input_name in self._var_quant_scales: + if input_name in self._var_quant_scales and \ + output_name not in self._var_quant_scales: + # use input scale as a "dummy" scale self._var_quant_scales[ output_name] = self._var_quant_scales[input_name] + return graph def _load_param(self, scope, param_name): return np.array(scope.find_var(param_name).get_tensor()) def _remove_fake_ops(self, graph): + ''' + When FC isn't quantized: + Remove fake (de)quantize ops that do not surround mul. + When FC is quantized: + Remove all fake (de)quantize ops. + ''' + is_fc_quantized = self._is_fc_quantized() for op in graph.all_op_nodes(): if op.name() in self._fake_quantize_types: op_out = graph._find_node_by_name(op.outputs, op.output("Out")[0]) next_op = op_out.outputs[0] - if next_op.name() not in self._mul_ops: + if next_op.name() not in self._mul_ops or is_fc_quantized: self._remove_fake_quantize(graph, op) for op in graph.all_op_nodes(): if op.name() in self._fake_dequantize_types: op_in = graph._find_node_by_name(op.inputs, op.input("X")[0]) prev_op = op_in.inputs[0] - if prev_op.name() not in self._mul_ops: + if prev_op.name() not in self._mul_ops or is_fc_quantized: self._remove_fake_dequantize(graph, op) + return graph def _remove_fake_quantize(self, graph, op): @@ -444,6 +520,8 @@ class FakeQAT2MkldnnINT8PerfPass(object): for op in graph.all_op_nodes(): if op.name() in self._conv_ops: self._dequantize_conv_weights(graph, op) + elif self._is_fc_quantized() and op.name() in self._mul_ops: + self._dequantize_mul_weights(graph, op) return graph def _dequantize_conv_weights(self, graph, op_node): @@ -472,13 +550,20 @@ class FakeQAT2MkldnnINT8PerfPass(object): def _optimize_fp32_graph(self, graph): graph = self._apply_pass(graph, 'mkldnn_placement_pass', ['mkldnn_enabled_op_types'], [set()]) - graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass') - graph = self._apply_pass(graph, 'conv_bn_fuse_pass') - graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass') - graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass') + if self._is_conv_quantized(): + graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass') + graph = self._apply_pass(graph, 'conv_bn_fuse_pass') + graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass') + graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass') + graph = self._apply_pass(graph, + 'conv_elementwise_add_mkldnn_fuse_pass') + graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass') + graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass') + if self._is_fc_quantized(): + graph = self._apply_pass(graph, 'fc_fuse_pass', + ['use_gpu', 'use_fc_padding'], + [False, False]) + graph = self._apply_pass(graph, 'fc_mkldnn_pass') return graph def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None): @@ -528,6 +613,7 @@ class FakeQAT2MkldnnINT8PerfPass(object): np.abs(weights.reshape(weights.shape[0], -1)).astype( np.float64), axis=axis) + scales[scales == np.Inf] = 0.0 lod_tensor = self._convert_scale2tensor(scales) use_unsigned_int = False @@ -546,46 +632,41 @@ class FakeQAT2MkldnnINT8PerfPass(object): ids.append(op.id()) return set(ids) if len(ids) else set([-1]) - def _transform_to_quantize_mkldnn(self, graph, op_node): - """ - Transform fake_quantize_xx op to quantize mkldnn op in the graph. - """ - input_var_node = graph._find_node_by_name(op_node.inputs, - op_node.input("X")[0]) - output_var_node = graph._find_node_by_name(op_node.outputs, - op_node.output("Out")[0]) - scale_in = self._s8_max / self._load_param( - self._scope, op_node.input("InScale")[0])[0] - quant_op_node = graph.create_op_node( - op_type='quantize', - attrs={ - 'data_format': 'MKLDNNLAYOUT', - 'use_mkldnn': 1, - 'Scale': scale_in, - 'is_negative_input': 1 - }, - inputs={'Input': input_var_node}, - outputs={'Output': output_var_node}) - graph.link_to(input_var_node, quant_op_node) - graph.link_to(quant_op_node, output_var_node) - graph.safe_remove_nodes(op_node) - return quant_op_node + def _update_relu_output_scales(self, graph): + def _update_scale(graph, ops, op_out_name, predicate): + ''' + Sets the type of an output scale of a passed op type(s) to 'unsigned int8' if the + predicate applied on op passes. Typically, the predicate checks if op's + activation is set to relu. + ''' + for op in graph.all_op_nodes(): + if op.name() in ops: + out_name = op.output(op_out_name)[0] + if out_name in self._var_quant_scales and predicate(op.op( + )): + _, tensor = self._var_quant_scales[out_name] + self._var_quant_scales[out_name] = (True, tensor) + return graph + + if self._is_conv_quantized(): + conv_predicate = lambda op: op.attr("fuse_activation") == 'relu' and \ + op.attr("fuse_residual_connection") == False + graph = _update_scale(graph, self._conv_ops, "Output", + conv_predicate) + + if self._is_fc_quantized(): + fc_predicate = lambda op: op.attr("activation_type") == 'relu' + graph = _update_scale(graph, self._fc_ops, "Out", fc_predicate) - def _update_conv_relu_scales(self, graph): - for op in graph.all_op_nodes(): - if op.name() in self._conv_ops: - out_name = op.output("Output")[0] - if out_name in self._var_quant_scales and \ - op.op().attr("fuse_activation") == 'relu' and \ - op.op().attr("fuse_residual_connection") == False: - _, tensor = self._var_quant_scales[out_name] - self._var_quant_scales[out_name] = (True, tensor) return graph + def _get_data_layout(self): + return 'NHWC' if self._is_conv_quantized() else 'NCHW' + def _quantize_fp32_graph(self, graph): ir_pass = self._core.get_pass('cpu_quantize_placement_pass') cpp_graph = graph.graph - ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'}) + ir_pass.set('quantize_enabled_op_types', self._quantized_ops) ir_pass.set('quantize_excluded_op_ids', self._find_avg_pooling_ids(graph)) ir_pass.apply(cpp_graph) @@ -593,8 +674,8 @@ class FakeQAT2MkldnnINT8PerfPass(object): graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()), graph.all_op_nodes()) - graph = self._apply_pass(graph, 'cpu_quantize_pass', - ['quant_var_scales'], - [self._var_quant_scales]) + graph = self._apply_pass( + graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'], + [self._var_quant_scales, self._get_data_layout()]) graph = self._apply_pass(graph, 'cpu_quantize_squash_pass') return graph diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 8b80cac9018b71b3bffb5184e308f641dca08f18..c79d924fe09e39bda461c9064308347ff5700133 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -1,23 +1,19 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn) +function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn) py_test(${target} SRCS ${filename} ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} FLAGS_use_mkldnn=${use_mkldnn} ARGS --infer_model ${model_dir}/model - --infer_data ${data_dir}/data.bin + --infer_data ${data_path} --int8_model_save_path int8_models/${target} --warmup_batch_size ${WARMUP_BATCH_SIZE} --batch_size 50) endfunction() -function(inference_analysis_python_api_int8_test target model_dir data_dir filename) - _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False) -endfunction() - -function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename) - _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} True) +function(inference_analysis_python_api_int8_test target model_dir data_path filename) + _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False) endfunction() function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size) @@ -25,13 +21,29 @@ function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename}) endfunction() -function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn) - py_test(${target} SRCS ${test_script} +function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename) + _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True) +endfunction() + +function(download_qat_data install_dir data_file) + if (NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file}) + endif() +endfunction() + +function(download_qat_model install_dir data_file) + if (NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file}) + endif() +endfunction() + +function(inference_qat_int8_image_classification_test target model_dir dataset_path) + py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py" ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - FLAGS_use_mkldnn=${use_mkldnn} + FLAGS_use_mkldnn=true ARGS --qat_model ${model_dir}/model - --infer_data ${data_dir}/data.bin + --infer_data ${dataset_path} --batch_size 25 --batch_num 2 --acc_diff_threshold 0.1) @@ -39,24 +51,53 @@ endfunction() # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25 -function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn) - py_test(${target} SRCS ${test_script} +function(inference_qat2_int8_image_classification_test target model_dir data_path quantized_ops) + py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py" ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} - FLAGS_use_mkldnn=${use_mkldnn} + FLAGS_use_mkldnn=true ARGS --qat_model ${model_dir}/float - --infer_data ${data_dir}/data.bin + --infer_data ${data_path} --batch_size 10 --batch_num 2 --acc_diff_threshold 0.1 - --qat2) + --quantized_ops ${quantized_ops} + --qat2) +endfunction() + +# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20 +function(inference_qat2_int8_nlp_test target model_dir data_path labels_path quantized_ops) + py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_nlp_comparison.py" + ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + FLAGS_use_mkldnn=true + ARGS --qat_model ${model_dir}/float + --infer_data ${data_path} + --labels ${labels_path} + --batch_size 10 + --batch_num 2 + --quantized_ops ${quantized_ops} + --acc_diff_threshold 0.1) +endfunction() + +function(download_qat_data install_dir data_file) + if (NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file}) + endif() +endfunction() + +function(download_qat_model install_dir data_file) + if (NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file}) + endif() endfunction() -function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path test_script) - py_test(${target} SRCS ${test_script} +function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops) + py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py ARGS --qat_model_path ${qat_model_dir} - --fp32_model_save_path ${fp32_model_save_path} - --int8_model_save_path ${int8_model_save_path}) + --fp32_model_save_path ${fp32_model_save_path} + --int8_model_save_path ${int8_model_save_path} + --quantized_ops ${quantized_ops}) endfunction() if(WIN32) @@ -66,137 +107,151 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) endif() -# int8 image classification python api test if(LINUX AND WITH_MKLDNN) - set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") - set(MKLDNN_INT8_TEST_FILE "test_mkldnn_int8_quantization_strategy.py") - set(MKLDNN_INT8_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_INT8_TEST_FILE}") - - # googlenet int8 - set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet") - inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH} 10) - - # mobilenet int8 - set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1") - inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH}) - inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH}) - - # temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally, - # since the following UTs cost too much time on CI test. - if (WITH_SLIM_MKLDNN_FULL_TEST) - # resnet50 int8 - set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50") - inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH}) - - # mobilenetv2 int8 - set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2") - inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH}) - - # resnet101 int8 - set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101") - inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH}) - - # vgg16 int8 - set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16") - inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH}) - - # vgg19 int8 - set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19") - inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH}) - endif() -endif() -# Since test_mkldnn_int8_quantization_strategy only supports testing on Linux -# with MKL-DNN, we remove it here for not repeating test, or not testing on other systems. -list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy) + #### Image classification dataset: ImageNet (small) + # The dataset should already be downloaded for INT8v2 unit tests + set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin") -# QAT FP32 & INT8 comparison python api tests -if(LINUX AND WITH_MKLDNN) - set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") - set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") - set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models") - set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py") - set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}") - - # ImageNet small dataset - # May be already downloaded for INT8v2 unit tests - if (NOT EXISTS ${DATASET_DIR}) - inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz") + #### INT8 image classification python api test + # Models should be already downloaded for INT8v2 unit tests + + set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") + set(INT8_IC_TEST_FILE "test_mkldnn_int8_quantization_strategy.py") + set(INT8_IC_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${INT8_IC_TEST_FILE}") + + # googlenet int8 + set(INT8_GOOGLENET_MODEL_DIR "${INT8_INSTALL_DIR}/googlenet") + inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH} 10) + + # mobilenet int8 + set(INT8_MOBILENET_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1") + inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) + inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) + + # temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally, + # since the following UTs cost too much time on CI test. + if (WITH_SLIM_MKLDNN_FULL_TEST) + # resnet50 int8 + set(INT8_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50") + inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) + + # mobilenetv2 int8 + set(INT8_MOBILENETV2_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv2") + inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) + + # resnet101 int8 + set(INT8_RESNET101_MODEL_DIR "${INT8_INSTALL_DIR}/resnet101") + inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) + + # vgg16 int8 + set(INT8_VGG16_MODEL_DIR "${INT8_INSTALL_DIR}/vgg16") + inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) + + # vgg19 int8 + set(INT8_VGG19_MODEL_DIR "${INT8_INSTALL_DIR}/vgg19") + inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) endif() + #### QAT FP32 & INT8 comparison python api tests + + set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat") + + ### QATv1 for image classification + # QAT ResNet50 - set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT") - if (NOT EXISTS ${QAT_RESNET50_MODEL_DIR}) - inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz" ) - endif() - inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) + set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT") + set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz") + download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE}) + inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH}) # QAT ResNet101 - set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT") - if (NOT EXISTS ${QAT_RESNET101_MODEL_DIR}) - inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz" ) - endif() - # inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) + set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT") + set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz") + download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE}) + # inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH}) # QAT GoogleNet - set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT") - if (NOT EXISTS ${QAT_GOOGLENET_MODEL_DIR}) - inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz" ) - endif() - inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) + set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT") + set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz") + download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE}) + inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH}) # QAT MobileNetV1 - set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT") - if (NOT EXISTS ${QAT_MOBILENETV1_MODEL_DIR}) - inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz" ) - endif() - inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) + set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT") + set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz") + download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE}) + inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${IMAGENET_DATA_PATH}) # QAT MobileNetV2 - set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT") - if (NOT EXISTS ${QAT_MOBILENETV2_MODEL_DIR}) - inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz" ) - endif() - inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) + set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT") + set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz") + download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE}) + inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH}) # QAT VGG16 - set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT") - if (NOT EXISTS ${QAT_VGG16_MODEL_DIR}) - inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz" ) - endif() - # inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) + set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT") + set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz") + download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE}) + # inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH}) # QAT VGG19 - set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT") - if (NOT EXISTS ${QAT_VGG19_MODEL_DIR}) - inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" ) - endif() - # inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) - - set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf") - if (NOT EXISTS ${QAT2_RESNET50_MODEL_DIR}) - inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz" ) - endif() - inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) - - set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf") - if (NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR}) - inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz" ) - endif() - inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) - - # Save qat2 fp32 model or qat2 int8 model + set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT") + set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz") + download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE}) + # inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH}) + + ### QATv2 for image classification + + set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d") + + # QAT2 ResNet50 + set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf") + set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz") + download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE}) + inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS}) + + # QAT2 MobileNetV1 + set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf") + set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz") + download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE}) + inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS}) + + ### QATv2 for NLP + + set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2") + + set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz") + set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset") + set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1") + set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev") + download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE}) + + # QAT2 Ernie + set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz") + set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat") + download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE}) + inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QAT2_NLP_QUANTIZED_OPS}) + + # Save QAT2 FP32 model or QAT2 INT8 model set(QAT2_INT8_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_int8") set(QAT2_FP32_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_fp32") set(SAVE_QAT2_MODEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py") - save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_SAVE_PATH} ${QAT2_INT8_SAVE_PATH} ${SAVE_QAT2_MODEL_SCRIPT} true) + save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_SAVE_PATH} ${QAT2_INT8_SAVE_PATH} ${SAVE_QAT2_MODEL_SCRIPT} true) + + set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8") + set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32") + save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS}) endif() -# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux +# Since the tests for QAT FP32 & INT8 comparison support only testing on Linux # with MKL-DNN, we remove it here to not test it on other systems. -list(REMOVE_ITEM TEST_OPS qat_int8_comparison.py) +list(REMOVE_ITEM TEST_OPS + test_mkldnn_int8_quantization_strategy + qat_int8_image_classification_comparison + qat_int8_nlp_comparison) foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) diff --git a/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py b/python/paddle/fluid/contrib/slim/tests/qat_int8_image_classification_comparison.py similarity index 93% rename from python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py rename to python/paddle/fluid/contrib/slim/tests/qat_int8_image_classification_comparison.py index 9f713450684904dd95c0d18d4018adcc27151a84..0a0359dc6e32ca996db3385a3098795c14b2706d 100644 --- a/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py +++ b/python/paddle/fluid/contrib/slim/tests/qat_int8_image_classification_comparison.py @@ -24,8 +24,8 @@ import time import paddle import paddle.fluid as fluid from paddle.fluid.framework import IrGraph -from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass -from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass +from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass +from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass from paddle.fluid import core logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') @@ -53,10 +53,6 @@ def parse_args(): action='store_true', help='If used, the QAT model is treated as a second generation model for performance optimization.' ) - parser.add_argument( - '--save_model', - action='store_true', - help='If used, the QAT model will be saved after all transformations') parser.add_argument('--infer_data', type=str, default='', help='Data file.') parser.add_argument( '--batch_num', @@ -68,15 +64,20 @@ def parse_args(): type=float, default=0.01, help='Accepted accuracy difference threshold.') + parser.add_argument( + '--quantized_ops', + type=str, + default='', + help='A comma separated list of quantized operators.') test_args, args = parser.parse_known_args(namespace=unittest) return test_args, sys.argv[:1] + args -class TestQatInt8Comparison(unittest.TestCase): +class QatInt8ImageClassificationComparisonTest(unittest.TestCase): """ - Test for accuracy comparison of QAT FP32 and INT8 inference. + Test for accuracy comparison of QAT FP32 and INT8 Image Classification inference. """ def _reader_creator(self, data_file='data.bin'): @@ -182,14 +183,15 @@ class TestQatInt8Comparison(unittest.TestCase): graph.draw('.', 'qat_orig', graph.all_op_nodes()) if (transform_to_int8): if (test_case_args.qat2): - transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass( + transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass( + self._quantized_ops, _scope=inference_scope, _place=place, _core=core, _debug=self._debug) graph = transform_to_mkldnn_int8_pass.apply(graph) else: - mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass( + mkldnn_int8_pass = QatInt8MkldnnPass( _scope=inference_scope, _place=place) graph = mkldnn_int8_pass.apply(graph) @@ -256,12 +258,6 @@ class TestQatInt8Comparison(unittest.TestCase): _logger.info('Total inference run time: {:.2f} s'.format( infer_total_time)) - if test_case_args.save_model: - with fluid.scope_guard(inference_scope): - fluid.io.save_inference_model( - 'transformed_qat_int8_model', feed_target_names, - fetch_targets, exe, inference_program) - return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat): @@ -298,6 +294,7 @@ class TestQatInt8Comparison(unittest.TestCase): skip_batch_num = test_case_args.skip_batch_num acc_diff_threshold = test_case_args.acc_diff_threshold self._debug = test_case_args.debug + self._quantized_ops = set(test_case_args.quantized_ops.split(',')) _logger.info('QAT FP32 & INT8 prediction run.') _logger.info('QAT model: {0}'.format(qat_model_path)) @@ -305,6 +302,7 @@ class TestQatInt8Comparison(unittest.TestCase): _logger.info('Batch size: {0}'.format(batch_size)) _logger.info('Batch number: {0}'.format(batch_num)) _logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold)) + _logger.info('Quantized ops: {0}.'.format(self._quantized_ops)) _logger.info('--- QAT FP32 prediction start ---') val_reader = paddle.batch( diff --git a/python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py b/python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..b6eb825b09d2c0aad0d45049feafc09106d1e8c7 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py @@ -0,0 +1,288 @@ +# copyright (c) 2020 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import unittest +import os +import sys +import argparse +import logging +import struct +import six +import numpy as np +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass +from paddle.fluid import core + +logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', type=int, default=1, help='Batch size.') + parser.add_argument( + '--skip_batch_num', + type=int, + default=0, + help='Number of the first minibatches to skip in performance statistics.' + ) + parser.add_argument( + '--debug', + action='store_true', + help='If used, the graph of QAT model is drawn.') + parser.add_argument( + '--qat_model', type=str, default='', help='A path to a QAT model.') + parser.add_argument( + '--save_model', + action='store_true', + help='If used, the QAT model will be saved after all transformations') + parser.add_argument('--infer_data', type=str, default='', help='Data file.') + parser.add_argument( + '--labels', type=str, default='', help='File with labels.') + parser.add_argument( + '--batch_num', + type=int, + default=1, + help='Number of batches to process. 0 or less means all.') + parser.add_argument( + '--acc_diff_threshold', + type=float, + default=0.01, + help='Accepted accuracy difference threshold.') + parser.add_argument( + '--quantized_ops', + type=str, + default='', + help='A comma separated list of quantized operators.') + + test_args, args = parser.parse_known_args(namespace=unittest) + + return test_args, sys.argv[:1] + args + + +class QatInt8NLPComparisonTest(unittest.TestCase): + """ + Test for accuracy comparison of QAT FP32 and INT8 NLP inference. + """ + + def _reader_creator(self, data_file=None, labels_file=None): + assert data_file, "The dataset file is missing." + assert labels_file, "The labels file is missing." + + def reader(): + with open(data_file, 'r') as df: + with open(labels_file, 'r') as lf: + data_lines = df.readlines() + labels_lines = lf.readlines() + assert len(data_lines) == len( + labels_lines + ), "The number of labels does not match the length of the dataset." + + for i in range(len(data_lines)): + data_fields = data_lines[i].split(';') + assert len( + data_fields + ) >= 2, "The number of data fields in the dataset is less than 2" + buffers = [] + shape = [] + for j in range(2): + data = data_fields[j].split(':') + assert len( + data + ) >= 2, "Size of data in the dataset is less than 2" + # Shape is stored under index 0, while data under 1 + shape = data[0].split() + shape.pop(0) + shape_np = np.array(shape).astype("int64") + buffer_i = data[1].split() + buffer_np = np.array(buffer_i).astype("int64") + buffer_np.shape = tuple(shape_np) + buffers.append(buffer_np) + label = labels_lines[i] + yield buffers[0], buffers[1], int(label) + + return reader + + def _get_batch_correct(self, batch_output=None, labels=None): + total = len(batch_output) + assert total > 0, "The batch output is empty." + correct = 0 + for n, output in enumerate(batch_output[0]): + max_idx = np.where(output == output.max()) + if max_idx == labels[n]: + correct += 1 + return correct + + def _predict(self, + test_reader=None, + model_path=None, + batch_size=1, + batch_num=1, + skip_batch_num=0, + transform_to_int8=False): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + inference_scope = fluid.executor.global_scope() + with fluid.scope_guard(inference_scope): + if os.path.exists(os.path.join(model_path, '__model__')): + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(model_path, exe) + else: + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model( + model_path, exe, 'model', 'params') + + graph = IrGraph(core.Graph(inference_program.desc), for_test=True) + if (self._debug): + graph.draw('.', 'qat_orig', graph.all_op_nodes()) + if (transform_to_int8): + transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass( + self._quantized_ops, + _scope=inference_scope, + _place=place, + _core=core, + _debug=self._debug) + graph = transform_to_mkldnn_int8_pass.apply(graph) + + inference_program = graph.to_program() + + total_correct = 0 + total_samples = 0 + batch_times = [] + ppses = [] # predictions per second + iters = 0 + infer_start_time = time.time() + for data in test_reader(): + if batch_num > 0 and iters >= batch_num: + break + if iters == skip_batch_num: + total_samples = 0 + infer_start_time = time.time() + input0 = np.array([x[0] for x in data]).astype('int64') + input1 = np.array([x[1] for x in data]).astype('int64') + labels = np.array([x[2] for x in data]).astype('int64') + + start = time.time() + out = exe.run(inference_program, + feed={ + feed_target_names[0]: input0, + feed_target_names[1]: input1 + }, + fetch_list=fetch_targets) + batch_time = (time.time() - start) * 1000 # in miliseconds + batch_times.append(batch_time) + batch_correct = self._get_batch_correct(out, labels) + batch_len = len(data) + total_samples += batch_len + total_correct += batch_correct + batch_acc = float(batch_correct) / float(batch_len) + pps = batch_len / batch_time * 1000 + ppses.append(pps) + latency = batch_time / batch_len + iters += 1 + appx = ' (warm-up)' if iters <= skip_batch_num else '' + _logger.info( + 'batch {0}{4}, acc: {1:.4f}, latency: {2:.4f} ms, predictions per sec: {3:.2f}' + .format(iters, batch_acc, latency, pps, appx)) + + # Postprocess benchmark data + infer_total_time = time.time() - infer_start_time + batch_latencies = batch_times[skip_batch_num:] + batch_latency_avg = np.average(batch_latencies) + latency_avg = batch_latency_avg / batch_size + ppses = ppses[skip_batch_num:] + pps_avg = np.average(ppses) + acc_avg = float(np.sum(total_correct)) / float(total_samples) + _logger.info('Total inference run time: {:.2f} s'.format( + infer_total_time)) + + return acc_avg, pps_avg, latency_avg + + def _summarize_performance(self, fp32_pps, fp32_lat, int8_pps, int8_lat): + _logger.info('--- Performance summary ---') + _logger.info( + 'FP32: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'. + format(fp32_pps, fp32_lat)) + _logger.info( + 'INT8: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'. + format(int8_pps, int8_lat)) + + def _compare_accuracy(self, fp32_acc, int8_acc, threshold): + _logger.info('--- Accuracy summary ---') + _logger.info( + 'Accepted accuracy drop threshold: {0}. (condition: (FP32_acc - INT8_acc) <= threshold)' + .format(threshold)) + _logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc)) + _logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc)) + # Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5 + assert fp32_acc > 0.5 + assert int8_acc > 0.5 + assert fp32_acc - int8_acc <= threshold + + def test_graph_transformation(self): + if not fluid.core.is_compiled_with_mkldnn(): + return + + qat_model_path = test_case_args.qat_model + data_path = test_case_args.infer_data + labels_path = test_case_args.labels + batch_size = test_case_args.batch_size + batch_num = test_case_args.batch_num + skip_batch_num = test_case_args.skip_batch_num + acc_diff_threshold = test_case_args.acc_diff_threshold + self._debug = test_case_args.debug + self._quantized_ops = set(test_case_args.quantized_ops.split(',')) + + _logger.info('QAT FP32 & INT8 prediction run.') + _logger.info('QAT model: {0}'.format(qat_model_path)) + _logger.info('Dataset: {0}'.format(data_path)) + _logger.info('Labels: {0}'.format(labels_path)) + _logger.info('Batch size: {0}'.format(batch_size)) + _logger.info('Batch number: {0}'.format(batch_num)) + _logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold)) + _logger.info('Quantized ops: {0}.'.format(self._quantized_ops)) + + _logger.info('--- QAT FP32 prediction start ---') + val_reader = paddle.batch( + self._reader_creator(data_path, labels_path), batch_size=batch_size) + fp32_acc, fp32_pps, fp32_lat = self._predict( + val_reader, + qat_model_path, + batch_size, + batch_num, + skip_batch_num, + transform_to_int8=False) + _logger.info('--- QAT INT8 prediction start ---') + val_reader = paddle.batch( + self._reader_creator(data_path, labels_path), batch_size=batch_size) + int8_acc, int8_pps, int8_lat = self._predict( + val_reader, + qat_model_path, + batch_size, + batch_num, + skip_batch_num, + transform_to_int8=True) + + self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat) + self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold) + + +if __name__ == '__main__': + global test_case_args + test_case_args, remaining_args = parse_args() + unittest.main(argv=remaining_args) diff --git a/python/paddle/fluid/contrib/slim/tests/save_qat_model.py b/python/paddle/fluid/contrib/slim/tests/save_qat_model.py index 03db63fc103b5d9b7ed063d0bb8cadf45600d2e9..7e275bd1ed7dabc788266724ba32926d93b1201a 100644 --- a/python/paddle/fluid/contrib/slim/tests/save_qat_model.py +++ b/python/paddle/fluid/contrib/slim/tests/save_qat_model.py @@ -24,7 +24,7 @@ import time import paddle import paddle.fluid as fluid from paddle.fluid.framework import IrGraph -from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass +from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass from paddle.fluid import core @@ -42,6 +42,11 @@ def parse_args(): type=str, default='', help='Saved optimized and quantized INT8 model') + parser.add_argument( + '--quantized_ops', + type=str, + default='', + help='A comma separated list of quantized operators.') test_args, args = parser.parse_known_args(namespace=unittest) return test_args, sys.argv[:1] + args @@ -60,8 +65,9 @@ def transform_and_save_model(original_path, save_path, save_type): fetch_targets] = fluid.io.load_inference_model(original_path, exe, 'model', 'params') - transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass( - _scope=inference_scope, _place=place, _core=core) + quantized_ops = set(test_args.quantized_ops.split(',')) + transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass( + quantized_ops, _scope=inference_scope, _place=place, _core=core) graph = IrGraph(core.Graph(inference_program.desc), for_test=True) if save_type == 'FP32': diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py index 7ccf67d9788251fa6b5589cbd2e56a152976fb76..eb75070e45c9d62830bc5c66a41f54afc5a0ff5d 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py @@ -22,7 +22,7 @@ import paddle from paddle.fluid.framework import IrGraph from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass -from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass +from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass from paddle.fluid import core os.environ["CPU_NUM"] = "1" @@ -149,8 +149,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): freeze_pass.apply(test_graph) # Transform quantized graph for MKL-DNN INT8 inference - mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass( - _scope=scope, _place=place) + mkldnn_int8_pass = QatInt8MkldnnPass(_scope=scope, _place=place) mkldnn_int8_pass.apply(test_graph) dev_name = '_cpu_' if not for_ci: