From ed7956a816130f4eb37ba3e235c09d1105ed1807 Mon Sep 17 00:00:00 2001 From: guofei <52460041+gfwm2013@users.noreply.github.com> Date: Sun, 21 Mar 2021 19:59:44 +0800 Subject: [PATCH] Fix skip_quant in QAT (#31704) * Fix skip_quant in QAT --- .../slim/quantization/imperative/qat.py | 38 +++++++++++++++++-- .../slim/quantization/imperative/utils.py | 6 +++ .../slim/tests/test_imperative_out_scale.py | 7 ++++ .../slim/tests/test_imperative_skip_op.py | 16 ++++++-- 4 files changed, 60 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index abfe06a332..68b4cfdc66 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -515,6 +515,8 @@ class ImperativeCalcOutputScale(object): self._out_scale_dict[ops_list[op_count]]) op_count += 1 + self._set_skip_quant_attr(inference_program) + # save the final quantized model that has output scales save_inference_model( dirname=dirname, @@ -537,9 +539,12 @@ class ImperativeCalcOutputScale(object): Init the scale params for calculating output scales and save them in the target layer. After the users define the dygraph model, the hooks for calculating output - scales will not execute immediately. If the users load the checkpoint now, - the scale params have not been created, so them cann't be loaded. - Therefore, define the scale params in the beginning. + scales will not execute immediately. If the users load parameters form + checkpoint and save the quantized inference model immediately, the inference + model would not be saved successfully. Beacuse the dygraph_to_static requires + that the parameters created in __init__, but the uniqueness of hook make it + impossible to create parameters in __init__. To avoid this mistake, we define + the scale parameters in the beginning instead of hook. """ def _create_param(in_layer, first_name, last_name, dtype): @@ -587,6 +592,33 @@ class ImperativeCalcOutputScale(object): op_type = op_type.replace('relu', 're_lu') return op_type in layer_name + def _set_skip_quant_attr(self, program): + block = program.global_block() + for op in block.ops: + if self._is_skip_quant_op(block, op): + op._set_attr("skip_quant", True) + + def _is_skip_quant_op(self, block, in_op): + """ + The input op should be skipped quantization. + 1. the type of input op should be conv2d, depthwise_conv2d or matmul + 2. the previous ops of the input op are not fake_quantize_dequantize ops + """ + + def _find_previous_op(block, var_name): + for op in block.ops: + if var_name in op.output_arg_names: + return op + + target_op_types = ["conv2d", "depthwise_conv2d", "matmul"] + if in_op.type not in target_op_types: + return False + + previous_ops = [_find_previous_op(block, arg_name) \ + for arg_name in in_op.input_arg_names] + return any(op is not None and op.type not in utils.fake_quantize_dequantize_types \ + for op in previous_ops ) + def _calc_output_scale_hook(self, layer, input, output): """ Create the MovingAverageAbsMaxScale layer for the target layer if needed. diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py index 1ff4a408e0..3bf655265c 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py @@ -52,6 +52,12 @@ supported_quant_layers_map = { 'LayerNorm': paddle.nn.LayerNorm, } +fake_quantize_dequantize_types = [ + "fake_quantize_dequantize_abs_max", + "fake_quantize_dequantize_channel_wise_abs_max", + "fake_quantize_dequantize_moving_average_abs_max" +] + out_scale_layers_list = ( paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.MaxPool2D, paddle.nn.BatchNorm, paddle.nn.BatchNorm2D, paddle.nn.SyncBatchNorm, diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py index 83ddac4196..ed29375d22 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_out_scale.py @@ -393,12 +393,16 @@ class TestImperativeOutSclae(unittest.TestCase): if 'fake' in op.type: static_ops.remove(op) + op_count = 0 for i in range(len(dynamic_ops)): if dynamic_ops[i].has_attr("out_threshold"): + op_count += 1 self.assertTrue(dynamic_ops[i].type == static_ops[i].type) self.assertTrue(dynamic_ops[i].attr("out_threshold") == static_ops[i].attr("out_threshold")) + self.assertTrue(op_count == 13) + class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase): def test_save_quantized_model(self): @@ -459,11 +463,14 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase): if 'fake' in op.type: static_ops.remove(op) + op_count = 0 for i in range(len(dynamic_ops)): if dynamic_ops[i].has_attr("out_threshold"): + op_count += 1 self.assertTrue(dynamic_ops[i].type == static_ops[i].type) self.assertTrue(dynamic_ops[i].attr("out_threshold") == static_ops[i].attr("out_threshold")) + self.assertTrue(op_count == 13) class TestSaveQuantizedModel_Warning(unittest.TestCase): diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py index 0561055e6e..bda02769ce 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_skip_op.py @@ -200,9 +200,12 @@ class TestImperativeOutSclae(unittest.TestCase): params_filename="lenet" + INFER_PARAMS_SUFFIX)) model_ops = inference_program.global_block().ops - conv2d_count, mul_count = 0, 0 + conv2d_count, matmul_count = 0, 0 + conv2d_skip_count, matmul_skip_count = 0, 0 for i, op in enumerate(model_ops): if op.type == 'conv2d': + if op.has_attr("skip_quant"): + conv2d_skip_count += 1 if conv2d_count > 0: self.assertTrue( 'fake_quantize_dequantize' in model_ops[i - 1].type) @@ -211,14 +214,19 @@ class TestImperativeOutSclae(unittest.TestCase): 'fake_quantize_dequantize' not in model_ops[i - 1].type) conv2d_count += 1 - if op.type == 'mul': - if mul_count > 0: + if op.type == 'matmul': + if op.has_attr("skip_quant"): + matmul_skip_count += 1 + if matmul_count > 0: self.assertTrue( 'fake_quantize_dequantize' in model_ops[i - 1].type) else: self.assertTrue( 'fake_quantize_dequantize' not in model_ops[i - 1].type) - mul_count += 1 + matmul_count += 1 + + self.assertTrue(conv2d_skip_count == 1) + self.assertTrue(matmul_skip_count == 1) if __name__ == '__main__': -- GitLab