From 3d744162ddebfae0ff54ccee908b346cdec7e6a6 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Thu, 23 Apr 2020 15:49:15 +0200 Subject: [PATCH] QAT: support for new models (#23928) * QAT: support range-based quantization and scales from attribute * added support for channelwise --- .../quantization/qat2_int8_mkldnn_pass.py | 102 +++++++++++------- .../fluid/contrib/slim/tests/CMakeLists.txt | 19 +++- 2 files changed, 82 insertions(+), 39 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py index 43b173f434d..8f217083865 100644 --- a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py @@ -45,16 +45,14 @@ class Qat2Int8MkldnnPass(object): self._place = _place self._core = _core self._debug = _debug - self._quantize_types = [ + self._fake_quantize_types = [ 'fake_quantize_moving_average_abs_max', 'fake_quantize_range_abs_max', 'fake_quantize_dequantize_moving_average_abs_max' ] - self._fake_quantize_types = [ - 'fake_quantize_moving_average_abs_max', - 'fake_quantize_dequantize_moving_average_abs_max' + self._fake_dequantize_types = [ + 'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs' ] - self._fake_dequantize_types = ['fake_dequantize_max_abs'] self._quantized_ops = _quantized_ops self._scale_immutable_ops = [ 'transpose2', 'reshape2', 'pool2d', 'scale' @@ -74,7 +72,9 @@ class Qat2Int8MkldnnPass(object): assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' - graph = self._gather_scales(graph) + graph = self._gather_weight_scales_from_fake(graph) + graph = self._gather_output_scales_from_attr(graph) + graph = self._gather_input_scales_from_fake(graph) graph = self._remove_fake_ops(graph) graph = self._dequantize_weights(graph) graph = self._optimize_fp32_graph(graph) @@ -83,6 +83,7 @@ class Qat2Int8MkldnnPass(object): graph = self._propagate_scales(graph) graph = self._set_dummy_out_scales(graph) graph = self._quantize_fp32_graph(graph) + graph = self._optimize_int8_graph(graph) graph = self._cleanup(graph) return graph @@ -90,9 +91,6 @@ class Qat2Int8MkldnnPass(object): assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' - graph = self._gather_scales(graph) - graph = self._remove_fake_ops(graph) - graph = self._dequantize_weights(graph) graph = self._optimize_fp32_graph(graph) graph = self._cleanup(graph) return graph @@ -108,29 +106,61 @@ class Qat2Int8MkldnnPass(object): def _is_fc_quantized(self): return 'fc' in self._quantized_ops - def _gather_scales(self, graph): + def _gather_input_scales_from_fake(self, graph): + def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor): + scales = self._var_quant_scales + for var_name in var_names: + scales[var_name] = (use_unsigned_int, lod_tensor) + for op in graph.all_op_nodes(): - if op.name() in self._quantize_types: + if op.name() in self._fake_quantize_types: bit_length = op.op().attr("bit_length") assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format( bit_length) input_name = op.input("X")[0] scale_name = op.input("InScale")[0] + output_name = op.output("Out")[0] # Gather new weights scale after folding batchnorm in convolution scale = np.array(1.0 / self._load_param( self._scope, scale_name)[0]).astype(np.float64) lod_tensor = self._convert_scale2tensor(scale) use_unsigned_int = False - self._var_quant_scales[input_name] = (use_unsigned_int, - lod_tensor) - self._var_quant_scales[scale_name.replace(".scale", "")] = ( - use_unsigned_int, lod_tensor) + _add_scale_for_vars([input_name, output_name], use_unsigned_int, + lod_tensor) + + return graph + def _gather_weight_scales_from_fake(self, graph): + for op in graph.all_op_nodes(): if op.name() in self._fake_dequantize_types: input_name = op.input("X")[0] - _max_range = op.op().attr("max_range") - self._weight_scales[input_name] = _max_range + if op.op().has_attr("max_range"): + _max_range = np.array(op.op().attr("max_range")).astype( + np.float64) + self._weight_scales[input_name] = _max_range + else: + scale_name = op.input("Scales")[0] + scale = np.array( + self._s8_max * self._s8_max / self._load_param( + self._scope, scale_name)).astype(np.float64) + self._weight_scales[input_name] = scale + + return graph + + def _gather_output_scales_from_attr(self, graph): + for op in graph.all_op_nodes(): + if op.op().has_attr("out_threshold"): + attr_scale = op.op().attr("out_threshold") + if attr_scale == 0.0: continue + scale = np.array(1.0 / attr_scale).astype(np.float64) + scale_lod_tensor = self._convert_scale2tensor(scale) + use_unsigned_int = False + for output_name in op.op().outputs(): + for out_var_name in op.op().output(output_name): + self._var_quant_scales[out_var_name] = ( + use_unsigned_int, scale_lod_tensor) + return graph def _propagate_scales(self, graph): @@ -274,29 +304,24 @@ class Qat2Int8MkldnnPass(object): def _dequantize_weights(self, graph): for op in graph.all_op_nodes(): if op.name() in self._conv_ops: - self._dequantize_conv_weights(graph, op) + self._dequantize_op_weights(graph, op, "Filter", "Output") elif self._is_fc_quantized() and op.name() in self._mul_ops: - self._dequantize_mul_weights(graph, op) + self._dequantize_op_weights(graph, op, "Y", "Out") return graph - def _dequantize_conv_weights(self, graph, op_node): - weight_name = op_node.input("Filter")[0] - output_name = op_node.output("Output")[0] + def _dequantize_op_weights(self, graph, op_node, weight_name, output_name): + weight_var_name = op_node.input(weight_name)[0] + output_var_name = op_node.output(output_name)[0] # Convert int8 range weights to fp32 range weights - scales = self._weight_scales[output_name] - weight = self._load_param(self._scope, weight_name) - w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales) - w_fp32 = w_fp32.reshape(weight.shape) - self._restore_var(weight_name, w_fp32) - - def _dequantize_mul_weights(self, graph, op_node): - weight_name = op_node.input("Y")[0] - output_name = op_node.output("Out")[0] - scales = self._weight_scales[output_name] - weight = self._load_param(self._scope, weight_name) - w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales) - w_fp32 = w_fp32.reshape(weight.shape) - self._restore_var(weight_name, w_fp32) + scales = self._weight_scales[output_var_name] + weight = self._load_param(self._scope, weight_var_name) + assert scales.size == 1 or scales.size == len( + weight + ), "The size of weight scales vector ({}) does not match the number of output channels ({}) in the weights tensor {}.".format( + scales.size, len(weight), weight_var_name) + w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T + w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32) + self._restore_var(weight_var_name, w_fp32) def _restore_var(self, name, array): tensor = self._scope.find_var(name).get_tensor() @@ -366,11 +391,14 @@ class Qat2Int8MkldnnPass(object): self._remove_unused_var_nodes(graph) return graph - def _cleanup(self, graph): + def _optimize_int8_graph(self, graph): # remove dropout ops graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass') # make some MKL-DNN ops working inplace graph = self._apply_pass(graph, 'mkldnn_inplace_pass') + return graph + + def _cleanup(self, graph): graph = self._remove_unused_var_nodes(graph) graph = self._set_op_role_forward(graph) return graph diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index ec3d23a1c06..737d67051a5 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -68,7 +68,7 @@ function(inference_qat2_int8_image_classification_test target qat_model_dir fp32 --batch_size 10 --batch_num 2 --acc_diff_threshold 0.1 - --quantized_ops ${quantized_ops}) + --quantized_ops ${quantized_ops}) endfunction() # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20 @@ -215,13 +215,28 @@ if(LINUX AND WITH_MKLDNN) set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d") - # QAT2 ResNet50 + # QAT2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators, + # with weight scales in `fake_dequantize_max_abs` operators set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf") set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50") set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz") download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE}) inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS}) + # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, + # with weight scales in `fake_dequantize_max_abs` operators + set(QAT2_RESNET50_RANGE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_range") + set(QAT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz") + download_qat_model(${QAT2_RESNET50_RANGE_MODEL_DIR} ${QAT2_RESNET50_RANGE_MODEL_ARCHIVE}) + inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS}) + + # QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, + # with weight scales in `fake_channel_wise_dequantize_max_abs` operators + set(QAT2_RESNET50_CHANNELWISE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_channelwise") + set(QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz") + download_qat_model(${QAT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE}) + inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS}) + # QAT2 MobileNetV1 set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf") set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1") -- GitLab