diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
index 95ec17ac7f5f89dab6cf8d43893f93d8c1e4431a..1750c3fdc4a2fa79d20915cebfbf2ecc171f082e 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
+#include <sstream>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
@@ -169,13 +170,31 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
   if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
 }
 
+bool CPUQuantizePass::AreScalesPresentForNodes(
+    const Node* op_node, std::initializer_list<Node*> nodes) const {
+  auto& scales = Get<VarQuantScale>("quant_var_scales");
+  bool present = true;
+  for (auto node : nodes) {
+    if (scales.count(node->Name()) == 0) {
+      present = false;
+      std::stringstream msg_ss;
+      msg_ss << "Quantization scale for the variable " << node->Name()
+             << " is missing.";
+      PrettyLogDetail(msg_ss.str().c_str());
+    }
+  }
+  if (!present) {
+    std::stringstream msg_ss;
+    msg_ss << "Cannot quantize operator " << op_node->Name()
+           << " (type: " << op_node->Op()->Type() << ").";
+    PrettyLogDetail(msg_ss.str().c_str());
+  }
+  return present;
+}
+
 std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataForNode(
     const Node* node) const {
   auto& scales = Get<VarQuantScale>("quant_var_scales");
-  PADDLE_ENFORCE_EQ(
-      scales.count(node->Name()), 1,
-      platform::errors::InvalidArgument(
-          "Quantization scale for the variable %s is missing.", node->Name()));
   return scales[node->Name()];
 }
 
@@ -221,6 +240,25 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
 
+    if (with_residual_data) {
+      GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
+                                conv_pattern);
+      if (!AreScalesPresentForNodes(conv_op, {conv_input, conv_filter,
+                                              conv_residual_data, conv_output}))
+        return;
+
+      bool is_residual_unsigned{false};
+      auto residual_scale =
+          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);
+
+      QuantizeInput(g, conv_op, conv_residual_data, "ResidualData",
+                    residual_scale, is_residual_unsigned, "Scale_in_eltwise");
+    } else {
+      if (!AreScalesPresentForNodes(conv_op,
+                                    {conv_input, conv_filter, conv_output}))
+        return;
+    }
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
     QuantizeInput(g, conv_op, conv_input, "Input", input_scale,
@@ -236,17 +274,6 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
     conv_op->Op()->SetAttr("Scale_weights", filter_scale);
 
-    if (with_residual_data) {
-      GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
-                                conv_pattern);
-      bool is_residual_unsigned{false};
-      auto residual_scale =
-          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);
-
-      QuantizeInput(g, conv_op, conv_residual_data, "ResidualData",
-                    residual_scale, is_residual_unsigned, "Scale_in_eltwise");
-    }
-
     bool is_output_unsigned{false};
     auto output_scale = GetScaleValueForNode(conv_output, &is_output_unsigned);
     DequantizeOutput(g, conv_op, conv_output, "Output", output_scale,
@@ -298,6 +325,8 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);
 
+    if (!AreScalesPresentForNodes(fc, {input, weights, output})) return;
+
     bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
     QuantizeInput(g, fc, input, "Input", input_scale, is_input_unsigned,
@@ -348,6 +377,8 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);
 
+    if (!AreScalesPresentForNodes(pool_op, {pool_input, pool_output})) return;
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(pool_input, &is_input_unsigned);
     QuantizeInput(g, pool_op, pool_input, "X", input_scale, is_input_unsigned);
@@ -384,6 +415,8 @@ void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
 
     GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);
 
+    if (!AreScalesPresentForNodes(concat_op, {concat_out})) return;
+
     // if all inputs were unsigned, then the output was set to unsigned
     // during the scale calculation step
     bool are_all_inputs_unsigned{false};
@@ -423,6 +456,8 @@ void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(prior_box_input, prior_box_input,
                               prior_box_pattern);
 
+    if (!AreScalesPresentForNodes(prior_box_op, {prior_box_input})) return;
+
     bool is_input_unsigned{false};
     auto input_scale =
         GetScaleValueForNode(prior_box_input, &is_input_unsigned);
@@ -466,6 +501,9 @@ void CPUQuantizePass::QuantizeTranspose(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(transpose_in, transpose_in, transpose_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, transpose_pattern);
 
+    if (!AreScalesPresentForNodes(transpose_op, {transpose_in, transpose_out}))
+      return;
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(transpose_in, &is_input_unsigned);
     QuantizeInput(g, transpose_op, transpose_in, "X", input_scale,
@@ -515,6 +553,9 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, reshape_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, reshape_pattern);
 
+    if (!AreScalesPresentForNodes(reshape_op, {reshape_in, reshape_out}))
+      return;
+
     bool is_input_unsigned{false};
     auto input_scale = GetScaleValueForNode(reshape_in, &is_input_unsigned);
     QuantizeInput(g, reshape_op, reshape_in, "X", input_scale,
@@ -562,6 +603,10 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
 
+    if (!AreScalesPresentForNodes(matmul_op,
+                                  {matmul_in_x, matmul_in_y, matmul_out}))
+      return;
+
     bool is_x_unsigned{false}, is_y_unsigned{false};
     auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
     auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
index cca691d443200e7ddea745c97ea5875e65f84010..cd5c673061b79602f6eceda55fb0107d2a41535c 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
@@ -74,6 +74,8 @@ class CPUQuantizePass : public FusePassBase {
                      bool is_unsigned,
                      std::string scale_attr_name = "") const;
 
+  bool AreScalesPresentForNodes(const Node* op_node,
+                                std::initializer_list<Node*> nodes) const;
   std::pair<bool, LoDTensor> GetScaleDataForNode(const Node* node) const;
   LoDTensor GetScaleTensorForNode(const Node* node) const;
   double GetScaleValueForNode(const Node* node,
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
index 8a9a431e4d95d7825a3c3c52ba7879474c23dccc..c6264d503a012faa4a95143278308eb8d0edba74 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
@@ -486,41 +486,6 @@ TEST(CpuQuantizePass, reshapeBetweenNonQuantizedOp) {
                 added_nodes_count, 2.0f * 127);
 }
 
-void MainTestCheckScales(
-    const ProgramDesc& prog,
-    const std::initializer_list<std::string> variable_names,
-    const std::string& var_without_scale) {
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
-  std::stringstream error_msg_ss;
-  error_msg_ss << "Quantization scale for the variable " << var_without_scale
-               << " is missing.";
-  bool caught_exception = false;
-  try {
-    int original_nodes_num, current_nodes_num;
-    PreparePass(&graph, prog, variable_names, &original_nodes_num,
-                &current_nodes_num, var_without_scale);
-  } catch (paddle::platform::EnforceNotMet& error) {
-    caught_exception = true;
-    std::string ex_msg = error.what();
-    EXPECT_NE(ex_msg.find(error_msg_ss.str()), std::string::npos);
-  }
-  EXPECT_TRUE(caught_exception);
-}
-
-// (a, w)->Conv->o
-ProgramDesc BuildProgramDescCheckScalesConv() {
-  ProgramDesc prog;
-  SetOp(&prog, "conv2d", "Conv", {"a", "w"}, {"o"}, true, true);
-  return prog;
-}
-
-// Check if an exception with a proper message is thrown when quantization scale
-// is missing for a variable
-TEST(CPUQuantizePass, check_scales) {
-  const std::initializer_list<std::string> var_names = {"a", "w", "o"};
-  MainTestCheckScales(BuildProgramDescCheckScalesConv(), var_names, "a");
-}
-
 static const std::initializer_list<std::string> variable_names_matmul = {
     "a", "b", "c", "d", "e", "f"};
 
diff --git a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
index 64a1c3cf4edb43c35fd37e91454a931c0d241740..72dd668898c5c07689985a1a15e9b84f4d6db0e3 100644
--- a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
@@ -36,7 +36,7 @@ class Qat2Int8MkldnnPass(object):
     """
 
     def __init__(self,
-                 _quantized_ops,
+                 _ops_to_quantize,
                  _scope=None,
                  _place=None,
                  _core=None,
@@ -53,7 +53,7 @@ class Qat2Int8MkldnnPass(object):
         self._fake_dequantize_types = [
             'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
         ]
-        self._quantized_ops = _quantized_ops
+        self._ops_to_quantize = _ops_to_quantize
         self._scale_immutable_ops = [
             'transpose2', 'reshape2', 'pool2d', 'scale'
         ]
@@ -101,10 +101,11 @@ class Qat2Int8MkldnnPass(object):
         return tensor
 
     def _is_conv_quantized(self):
-        return any(op_type in self._quantized_ops for op_type in self._conv_ops)
+        return any(op_type in self._ops_to_quantize
+                   for op_type in self._conv_ops)
 
     def _is_fc_quantized(self):
-        return 'fc' in self._quantized_ops
+        return 'fc' in self._ops_to_quantize
 
     def _gather_input_scales_from_fake(self, graph):
         def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
@@ -238,27 +239,13 @@ class Qat2Int8MkldnnPass(object):
         return np.array(scope.find_var(param_name).get_tensor())
 
     def _remove_fake_ops(self, graph):
-        '''
-        When FC isn't quantized:
-        Remove fake (de)quantize ops that do not surround mul.
-        When FC is quantized:
-        Remove all fake (de)quantize ops.
-        '''
-        is_fc_quantized = self._is_fc_quantized()
         for op in graph.all_op_nodes():
             if op.name() in self._fake_quantize_types:
-                op_out = graph._find_node_by_name(op.outputs,
-                                                  op.output("Out")[0])
-                next_op = op_out.outputs[0]
-                if next_op.name() not in self._mul_ops or is_fc_quantized:
-                    self._remove_fake_quantize(graph, op)
+                self._remove_fake_quantize(graph, op)
 
         for op in graph.all_op_nodes():
             if op.name() in self._fake_dequantize_types:
-                op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
-                prev_op = op_in.inputs[0]
-                if prev_op.name() not in self._mul_ops or is_fc_quantized:
-                    self._remove_fake_dequantize(graph, op)
+                self._remove_fake_dequantize(graph, op)
 
         return graph
 
@@ -305,7 +292,7 @@ class Qat2Int8MkldnnPass(object):
         for op in graph.all_op_nodes():
             if op.name() in self._conv_ops:
                 self._dequantize_op_weights(graph, op, "Filter", "Output")
-            elif self._is_fc_quantized() and op.name() in self._mul_ops:
+            elif op.name() in self._mul_ops:
                 self._dequantize_op_weights(graph, op, "Y", "Out")
 
         return graph
@@ -357,19 +344,16 @@ class Qat2Int8MkldnnPass(object):
         graph = self._remove_ctrl_vars(graph)
         graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                  ['mkldnn_enabled_op_types'], [set()])
-        if self._is_conv_quantized():
-            graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
-            graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
-            graph = self._apply_pass(graph,
-                                     'conv_elementwise_add_mkldnn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
-            graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
-        if self._is_fc_quantized():
-            graph = self._apply_pass(graph, 'fc_fuse_pass',
-                                     ['use_gpu', 'use_fc_padding'],
-                                     [False, False])
+        graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
+        graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'fc_fuse_pass',
+                                 ['use_gpu', 'use_fc_padding'], [False, False])
+        if self._is_fc_quantized():
            graph = self._apply_pass(graph, 'fc_mkldnn_pass')
         graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
         return graph
@@ -492,7 +476,7 @@ class Qat2Int8MkldnnPass(object):
     def _quantize_fp32_graph(self, graph):
         ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
         cpp_graph = graph.graph
-        ir_pass.set('quantize_enabled_op_types', self._quantized_ops)
+        ir_pass.set('quantize_enabled_op_types', self._ops_to_quantize)
         ir_pass.set('quantize_excluded_op_ids',
                     self._find_avg_pooling_ids(graph))
         ir_pass.apply(cpp_graph)
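For context on how the renamed `_ops_to_quantize` constructor argument is consumed, here is a minimal, hedged sketch (not part of the patch). It assumes a Paddle build with MKL-DNN support, uses a placeholder model path, and follows the same import path and calls that `save_qat_model.py` makes further down in this diff; `apply()` is assumed to be the pass entry point used by the slim test scripts.

```python
# Sketch: turn a QAT inference model into an MKL-DNN INT8 graph with an
# explicit set of operator types (the image-classification flow).
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()

with fluid.scope_guard(scope):
    # Hypothetical path to a QAT-trained inference model.
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        '/PATH/TO/QAT/MODEL', exe, 'model', 'params')

    # Quantize only conv2d and pool2d; an empty set would mean "try every
    # quantizable operator whose scales are present in the QAT model".
    ops_to_quantize = {'conv2d', 'pool2d'}
    qat_pass = Qat2Int8MkldnnPass(
        ops_to_quantize, _scope=scope, _place=place, _core=core)

    graph = IrGraph(core.Graph(program.desc), for_test=True)
    graph = qat_pass.apply(graph)  # transformed INT8 graph
```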
diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index 737d67051a5fd997f9179a8765cb8b3a1f4f7680..f401d64b73b89327afa1a4b81de78118e1f7ce3b 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -57,7 +57,7 @@ endfunction()
 
 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
-function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path quantized_ops)
+function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path ops_to_quantize)
     py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_image_classification_comparison.py"
             ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                  OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -68,11 +68,11 @@ function(inference_qat2_int8_image_classification_test target qat_model_dir fp32
         --batch_size 10
         --batch_num 2
         --acc_diff_threshold 0.1
-        --quantized_ops ${quantized_ops})
+        --ops_to_quantize ${ops_to_quantize})
 endfunction()
 
 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
-function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path quantized_ops)
+function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path)
     py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat2_int8_nlp_comparison.py"
             ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                  OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -83,8 +83,7 @@ function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir datase
         --labels ${labels_path}
         --batch_size 10
         --batch_num 2
-        --acc_diff_threshold 0.1
-        --quantized_ops ${quantized_ops})
+        --acc_diff_threshold 0.1)
 endfunction()
 
 function(download_qat_data install_dir data_file)
@@ -99,12 +98,19 @@ function(download_qat_model install_dir data_file)
   endif()
 endfunction()
 
-function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops)
+function(save_qat_ic_model_test target qat_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize)
   py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
           ARGS --qat_model_path ${qat_model_dir}
               --fp32_model_save_path ${fp32_model_save_path}
              --int8_model_save_path ${int8_model_save_path}
-              --quantized_ops ${quantized_ops})
+              --ops_to_quantize ${ops_to_quantize})
+endfunction()
+
+function(save_qat_nlp_model_test target qat_model_dir fp32_model_save_path int8_model_save_path)
+  py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
+          ARGS --qat_model_path ${qat_model_dir}
+              --fp32_model_save_path ${fp32_model_save_path}
+              --int8_model_save_path ${int8_model_save_path})
 endfunction()
 
 if(WIN32)
@@ -213,7 +219,7 @@ if(LINUX AND WITH_MKLDNN)
 
     ### QATv2 for image classification
 
-    set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")
+    set(QAT2_IC_OPS_TO_QUANTIZE "conv2d,pool2d")
 
     # QAT2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
     # with weight scales in `fake_dequantize_max_abs` operators
@@ -221,33 +227,31 @@ if(LINUX AND WITH_MKLDNN)
     set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
     set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
     download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
-    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
     # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
     # with weight scales in `fake_dequantize_max_abs` operators
     set(QAT2_RESNET50_RANGE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_range")
     set(QAT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
     download_qat_model(${QAT2_RESNET50_RANGE_MODEL_DIR} ${QAT2_RESNET50_RANGE_MODEL_ARCHIVE})
-    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
     # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
     # with weight scales in `fake_channel_wise_dequantize_max_abs` operators
     set(QAT2_RESNET50_CHANNELWISE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_channelwise")
     set(QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
     download_qat_model(${QAT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
-    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+    inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
     # QAT2 MobileNetV1
     set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
     set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
     set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
     download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
-    inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
+    inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
     ### QATv2 for NLP
 
-    set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2,matmul")
-
     set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
     set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
     set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
@@ -261,17 +265,17 @@ if(LINUX AND WITH_MKLDNN)
     set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
     set(FP32_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_float")
     download_qat_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
-    inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QAT2_NLP_QUANTIZED_OPS})
+    inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH})
 
     ### Save QAT2 FP32 model or QAT2 INT8 model
 
     set(QAT2_INT8_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_int8")
     set(QAT2_FP32_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_fp32")
-    save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_QUANTIZED_OPS})
+    save_qat_ic_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
 
     set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
     set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
-    save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS})
+    save_qat_nlp_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH})
 
 endif()
diff --git a/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md b/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md
index 797e6c43435aac628dec34e7d90011f4abc12087..634c685fd32583d6f286ae4eeb4c7efab5deadd1 100644
--- a/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md
+++ b/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md
@@ -270,15 +270,17 @@ You can use the `qat2_int8_image_classification_comparison.py` script to reprodu
 * `--qat_model` - a path to a QAT model that will be transformed into INT8 model.
 * `--fp32_model` - a path to an FP32 model whose accuracy will be measured and compared to the accuracy of the INT8 model.
-* `--quantized_ops` - a comma-separated list of names of operators to be quantized. When deciding which operators to put on the list, the following have to be considered:
-  * Only operators which support quantization will be taken into account.
-  * All the quantizable operators from the list, which are present in the model, must have quantization scales provided in the model. Otherwise, the quantization procedure will fail with a message saying which variable is missing a quantization scale.
-  * Sometimes it may be suboptimal to quantize all quantizable operators in the model (cf. *Notes* in the **Gathering scales** section above). To find the optimal configuration for this option, user can run benchmark a few times with different lists of quantized operators present in the model and compare the results. For Image Classification models mentioned above the list comprises of `conv2d` and `pool2d` operators.
 * `--infer_data` - a path to the validation dataset.
 
+The following option is also accepted:
+* `--ops_to_quantize` - a comma-separated list of operator types to quantize. If the option is not used, an attempt to quantize all quantizable operators will be made, and in that case only quantizable operators which have quantization scales provided in the QAT model will be quantized. When deciding which operators to put on the list, the following have to be considered:
+  * Only operators which support quantization will be taken into account.
+  * All the quantizable operators from the list, which are present in the model, must have quantization scales provided in the model. Otherwise, quantization of the operator will be skipped with a message saying which variable is missing a quantization scale.
+  * Sometimes it may be suboptimal to quantize all quantizable operators in the model (cf. *Notes* in the **Gathering scales** section above). To find the optimal configuration for this option, the user can run the benchmark a few times with different lists of quantized operators present in the model and compare the results. For the Image Classification models mentioned above, the list usually consists of the `conv2d` and `pool2d` operators.
+
 ```bash
 cd /PATH/TO/PADDLE
-OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py --qat_model=/PATH/TO/DOWNLOADED/QAT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --quantized_ops="conv2d,pool2d"
+OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py --qat_model=/PATH/TO/DOWNLOADED/QAT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --ops_to_quantize="conv2d,pool2d"
 ```
 
 > Notes: Due to a large amount of images in the `int8_full_val.bin` dataset (50 000), the accuracy benchmark may last long. To accelerate accuracy measuring, it is recommended to set `OMP_NUM_THREADS` to the maximum number of physical cores available on the server.
@@ -287,11 +289,11 @@ OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim
 
 To reproduce the performance results, the environment variable `OMP_NUM_THREADS=1` and `--batch_size=1` option should be set.
 
-1. Transform the QAT model into INT8 model by applying the `Qat2Int8MkldnnPass` pass and save the result. You can use the script `save_qat_model.py` for this purpose. It also requires the option `--quantized_ops` with a list of operators to be quantized.
+1. Transform the QAT model into an INT8 model by applying the `Qat2Int8MkldnnPass` pass and save the result. You can use the script `save_qat_model.py` for this purpose. It also accepts the option `--ops_to_quantize` with a list of operators to quantize.
 
 ```bash
 cd /PATH/TO/PADDLE/build
-python ../python/paddle/fluid/contrib/slim/tests/save_qat_model.py --qat_model_path=/PATH/TO/DOWNLOADED/QAT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QAT/INT8/MODEL --quantized_ops="conv2d,pool2d"
+python ../python/paddle/fluid/contrib/slim/tests/save_qat_model.py --qat_model_path=/PATH/TO/DOWNLOADED/QAT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QAT/INT8/MODEL --ops_to_quantize="conv2d,pool2d"
 ```
 
 2. Run the C-API test for performance benchmark.
diff --git a/python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py b/python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py
index 954524c6bcffb499da54b44c0b1fa3d4b1aa2ede..63d0c7fdb78f1843e1e6c8cdc261c654a891c4f6 100644
--- a/python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py
+++ b/python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py
@@ -62,10 +62,11 @@ def parse_args():
         default=0.01,
         help='Accepted accuracy difference threshold.')
     parser.add_argument(
-        '--quantized_ops',
+        '--ops_to_quantize',
         type=str,
         default='',
-        help='A comma separated list of quantized operators.')
+        help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
+    )
 
     test_args, args = parser.parse_known_args(namespace=unittest)
     return test_args, sys.argv[:1] + args
@@ -305,7 +306,9 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
         skip_batch_num = test_case_args.skip_batch_num
         acc_diff_threshold = test_case_args.acc_diff_threshold
         self._debug = test_case_args.debug
-        self._quantized_ops = set(test_case_args.quantized_ops.split(','))
+        self._quantized_ops = set()
+        if len(test_case_args.ops_to_quantize) > 0:
+            self._quantized_ops = set(test_case_args.ops_to_quantize.split(','))
 
         _logger.info('FP32 & QAT INT8 prediction run.')
         _logger.info('QAT model: {0}'.format(qat_model_path))
diff --git a/python/paddle/fluid/contrib/slim/tests/qat2_int8_nlp_comparison.py b/python/paddle/fluid/contrib/slim/tests/qat2_int8_nlp_comparison.py
index 98c7d0f8323181fbd6564745b94ea673f6afb5df..5c6d82271643b92ba56a9afec8d28415f300947d 100644
--- a/python/paddle/fluid/contrib/slim/tests/qat2_int8_nlp_comparison.py
+++ b/python/paddle/fluid/contrib/slim/tests/qat2_int8_nlp_comparison.py
@@ -68,10 +68,11 @@ def parse_args():
         default=0.01,
         help='Accepted accuracy difference threshold.')
     parser.add_argument(
-        '--quantized_ops',
+        '--ops_to_quantize',
         type=str,
         default='',
-        help='A comma separated list of quantized operators.')
+        help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
+    )
 
     test_args, args = parser.parse_known_args(namespace=unittest)
 
@@ -252,7 +253,9 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
         skip_batch_num = test_case_args.skip_batch_num
         acc_diff_threshold = test_case_args.acc_diff_threshold
         self._debug = test_case_args.debug
-        self._quantized_ops = set(test_case_args.quantized_ops.split(','))
+        self._quantized_ops = set()
+        if len(test_case_args.ops_to_quantize) > 0:
+            self._quantized_ops = set(test_case_args.ops_to_quantize.split(','))
 
         _logger.info('FP32 & QAT INT8 prediction run.')
         _logger.info('QAT model: {0}'.format(qat_model_path))
diff --git a/python/paddle/fluid/contrib/slim/tests/save_qat_model.py b/python/paddle/fluid/contrib/slim/tests/save_qat_model.py
index 7e275bd1ed7dabc788266724ba32926d93b1201a..74a70736cbb9a33626940a072fb8a23897677601 100644
--- a/python/paddle/fluid/contrib/slim/tests/save_qat_model.py
+++ b/python/paddle/fluid/contrib/slim/tests/save_qat_model.py
@@ -43,10 +43,11 @@ def parse_args():
         default='',
         help='Saved optimized and quantized INT8 model')
     parser.add_argument(
-        '--quantized_ops',
+        '--ops_to_quantize',
         type=str,
         default='',
-        help='A comma separated list of quantized operators.')
+        help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
+    )
 
     test_args, args = parser.parse_known_args(namespace=unittest)
     return test_args, sys.argv[:1] + args
@@ -65,9 +66,12 @@ def transform_and_save_model(original_path, save_path, save_type):
         fetch_targets] = fluid.io.load_inference_model(original_path, exe,
                                                        'model', 'params')
 
-        quantized_ops = set(test_args.quantized_ops.split(','))
+        ops_to_quantize = set()
+        if len(test_args.ops_to_quantize) > 0:
+            ops_to_quantize = set(test_args.ops_to_quantize.split(','))
+
         transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
-            quantized_ops, _scope=inference_scope, _place=place, _core=core)
+            ops_to_quantize, _scope=inference_scope, _place=place, _core=core)
 
         graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
         if save_type == 'FP32':
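One reason the scripts above now guard the `split(',')` call: splitting an empty string yields `{''}`, not an empty set, so an unconditional split would register a bogus empty operator type. The following standalone snippet (an illustration, not part of the patch) shows the behaviour the guard avoids; `cmdline` stands in for a hypothetical empty `--ops_to_quantize` value.

```python
# Why the guard matters: ''.split(',') is [''], not [].
assert set(''.split(',')) == {''}   # what the old unconditional split produced

ops_to_quantize = set()
cmdline = ''                        # hypothetical empty --ops_to_quantize value
if len(cmdline) > 0:
    ops_to_quantize = set(cmdline.split(','))

# Empty set -> the pass attempts to quantize every quantizable operator whose
# scales are present in the QAT model (the new default / NLP flow).
assert ops_to_quantize == set()
```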