Unverified commit ed7956a8 authored by guofei, committed by GitHub

Fix skip_quant in QAT (#31704)

* Fix skip_quant in QAT
Parent 8c19d7aa
@@ -515,6 +515,8 @@ class ImperativeCalcOutputScale(object):
self._out_scale_dict[ops_list[op_count]])
op_count += 1
self._set_skip_quant_attr(inference_program)
# save the final quantized model that has output scales
save_inference_model(
dirname=dirname,
@@ -537,9 +539,12 @@ class ImperativeCalcOutputScale(object):
Init the scale params for calculating output scales and save them in the
target layer.
After the users define the dygraph model, the hooks for calculating output
scales will not execute immediately. If the users load the checkpoint now,
the scale params have not been created, so they can't be loaded.
Therefore, define the scale params at the beginning.
scales will not execute immediately. If the users load parameters from a
checkpoint and then save the quantized inference model right away, saving
fails, because dygraph_to_static requires that parameters be created in
__init__, while a hook only runs during forward and therefore cannot
create parameters in __init__. To avoid this mistake, we define the scale
parameters in __init__ up front instead of in the hook.
"""
def _create_param(in_layer, first_name, last_name, dtype):
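Note: a minimal sketch (not part of this diff) of the constraint the docstring above describes. The layer name and shape are hypothetical; the point is that a parameter made with create_parameter in __init__ exists before the checkpoint is loaded and before dygraph_to_static traces the model, while one created lazily inside a forward hook does not.

import paddle

class EagerScaleLayer(paddle.nn.Layer):
    def __init__(self):
        super(EagerScaleLayer, self).__init__()
        # Created up front, so loading a state dict and jit-based
        # saving can both see this parameter.
        self._scale = self.create_parameter(
            shape=[1],
            dtype='float32',
            default_initializer=paddle.nn.initializer.Constant(0.0))

    def forward(self, x):
        # A hook-based design would create self._scale only here,
        # which is too late for checkpoint loading or model saving.
        return x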
@@ -587,6 +592,33 @@ class ImperativeCalcOutputScale(object):
op_type = op_type.replace('relu', 're_lu')
return op_type in layer_name
def _set_skip_quant_attr(self, program):
block = program.global_block()
for op in block.ops:
if self._is_skip_quant_op(block, op):
op._set_attr("skip_quant", True)
def _is_skip_quant_op(self, block, in_op):
"""
Whether quantization should be skipped for the input op:
1. the type of the input op is conv2d, depthwise_conv2d or matmul
2. at least one of the ops producing its inputs is not a
   fake_quantize_dequantize op
"""
def _find_previous_op(block, var_name):
for op in block.ops:
if var_name in op.output_arg_names:
return op
target_op_types = ["conv2d", "depthwise_conv2d", "matmul"]
if in_op.type not in target_op_types:
return False
previous_ops = [_find_previous_op(block, arg_name)
                for arg_name in in_op.input_arg_names]
return any(op is not None and
           op.type not in utils.fake_quantize_dequantize_types
           for op in previous_ops)
def _calc_output_scale_hook(self, layer, input, output):
"""
Create the MovingAverageAbsMaxScale layer for the target layer if needed.
......
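Note: the rule implemented by _is_skip_quant_op can be restated as: a conv2d, depthwise_conv2d or matmul op whose inputs are not all produced by fake quantize-dequantize ops was never quantized during training, so the pass marks it skip_quant. A standalone sketch of the same check, using plain dicts in place of Paddle's Block and Operator objects for illustration:

# Hypothetical standalone version of the skip-quant rule; "ops" is a
# list of dicts, each with "type", "inputs" and "outputs" keys.
FAKE_QDQ_TYPES = {
    "fake_quantize_dequantize_abs_max",
    "fake_quantize_dequantize_channel_wise_abs_max",
    "fake_quantize_dequantize_moving_average_abs_max",
}
TARGET_TYPES = {"conv2d", "depthwise_conv2d", "matmul"}

def find_previous_op(ops, var_name):
    # The producer of var_name is the op listing it among its outputs;
    # graph inputs have no producer, so None is returned for them.
    for op in ops:
        if var_name in op["outputs"]:
            return op
    return None

def is_skip_quant_op(ops, in_op):
    if in_op["type"] not in TARGET_TYPES:
        return False
    previous = [find_previous_op(ops, name) for name in in_op["inputs"]]
    # Skip quantization if any input was not produced by a fake
    # quantize-dequantize op, i.e. the op ran on raw float inputs.
    return any(op is not None and op["type"] not in FAKE_QDQ_TYPES
               for op in previous)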
@@ -52,6 +52,12 @@ supported_quant_layers_map = {
'LayerNorm': paddle.nn.LayerNorm,
}
fake_quantize_dequantize_types = [
"fake_quantize_dequantize_abs_max",
"fake_quantize_dequantize_channel_wise_abs_max",
"fake_quantize_dequantize_moving_average_abs_max"
]
out_scale_layers_list = (
paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.MaxPool2D,
paddle.nn.BatchNorm, paddle.nn.BatchNorm2D, paddle.nn.SyncBatchNorm,
......
@@ -393,12 +393,16 @@ class TestImperativeOutSclae(unittest.TestCase):
if 'fake' in op.type:
static_ops.remove(op)
op_count = 0
for i in range(len(dynamic_ops)):
if dynamic_ops[i].has_attr("out_threshold"):
op_count += 1
self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
static_ops[i].attr("out_threshold"))
self.assertTrue(op_count == 13)
class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
def test_save_quantized_model(self):
@@ -459,11 +463,14 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
if 'fake' in op.type:
static_ops.remove(op)
op_count = 0
for i in range(len(dynamic_ops)):
if dynamic_ops[i].has_attr("out_threshold"):
op_count += 1
self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
static_ops[i].attr("out_threshold"))
self.assertTrue(op_count == 13)
class TestSaveQuantizedModel_Warning(unittest.TestCase):
......
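Note: both tests above use the same lockstep pattern: walk the dygraph-saved and static-saved programs together and compare the out_threshold attribute on each op that carries one. A condensed, hypothetical helper showing that pattern:

def compare_out_thresholds(dynamic_ops, static_ops):
    # Count ops that carry an output scale and check that the two
    # programs agree op-by-op (hypothetical helper, mirroring the
    # assertions in the tests above).
    op_count = 0
    for dyn_op, sta_op in zip(dynamic_ops, static_ops):
        if dyn_op.has_attr("out_threshold"):
            op_count += 1
            assert dyn_op.type == sta_op.type
            assert dyn_op.attr("out_threshold") == sta_op.attr("out_threshold")
    return op_count  # the tests expect 13 for this LeNet model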
@@ -200,9 +200,12 @@ class TestImperativeOutSclae(unittest.TestCase):
params_filename="lenet" + INFER_PARAMS_SUFFIX))
model_ops = inference_program.global_block().ops
conv2d_count, mul_count = 0, 0
conv2d_count, matmul_count = 0, 0
conv2d_skip_count, matmul_skip_count = 0, 0
for i, op in enumerate(model_ops):
if op.type == 'conv2d':
if op.has_attr("skip_quant"):
conv2d_skip_count += 1
if conv2d_count > 0:
self.assertTrue(
'fake_quantize_dequantize' in model_ops[i - 1].type)
@@ -211,14 +214,19 @@ class TestImperativeOutSclae(unittest.TestCase):
'fake_quantize_dequantize' not in model_ops[i - 1].type)
conv2d_count += 1
if op.type == 'mul':
if mul_count > 0:
if op.type == 'matmul':
if op.has_attr("skip_quant"):
matmul_skip_count += 1
if matmul_count > 0:
self.assertTrue(
'fake_quantize_dequantize' in model_ops[i - 1].type)
else:
self.assertTrue(
'fake_quantize_dequantize' not in model_ops[i - 1].type)
mul_count += 1
matmul_count += 1
self.assertTrue(conv2d_skip_count == 1)
self.assertTrue(matmul_skip_count == 1)
if __name__ == '__main__':
......
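Note: the assertions above can be reproduced outside the test harness by loading the saved inference model and counting ops that carry the new attribute. A sketch, with a hypothetical model directory standing in for one written by the save call shown in the first hunk:

import paddle
import paddle.fluid as fluid

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# "./quantized_lenet" is a hypothetical path to a quantized model
# saved with output scales.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    dirname="./quantized_lenet", executor=exe)

skip_count = 0
for op in program.global_block().ops:
    if op.type in ("conv2d", "depthwise_conv2d", "matmul") and \
            op.has_attr("skip_quant"):
        skip_count += 1
print("ops marked skip_quant:", skip_count)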