diff --git a/tests/test_quant_aware.py b/tests/test_quant_aware.py index 706302bc1b9776912ff11406dd8691765a8f22a1..f26e810a33bb65e823890ab0132e7141a80e4dcf 100644 --- a/tests/test_quant_aware.py +++ b/tests/test_quant_aware.py @@ -13,6 +13,7 @@ # limitations under the License. import sys +from typing import List sys.path.append("../") import unittest import paddle @@ -43,8 +44,8 @@ class TestQuantAwareCase(StaticCase): main_prog = paddle.static.default_main_program() val_prog = paddle.static.default_main_program().clone(for_test=True) - place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( - ) else paddle.CPUPlace() + place = paddle.CUDAPlace( + 0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) @@ -104,67 +105,87 @@ class TestQuantAwareCase(StaticCase): train(main_prog) top1_1, top5_1 = test(main_prog) + ops_with_weights = [ + 'depthwise_conv2d', + 'mul', + 'conv2d', + ] + ops_without_weights = [ + 'relu', + ] + config = { 'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max', - 'quantize_op_types': ['depthwise_conv2d', 'mul', 'conv2d'], + 'quantize_op_types': ops_with_weights + ops_without_weights, } quant_train_prog = quant_aware(main_prog, place, config, for_test=False) quant_eval_prog = quant_aware(val_prog, place, config, for_test=True) - op_nums_1, quant_op_nums_1 = self.get_op_number(quant_eval_prog) - # test quant_aware op numbers - self.assertTrue(op_nums_1 * 2 == quant_op_nums_1) + + # Step1: check the quantizers count in qat graph + quantizers_count_in_qat = self.count_op(quant_eval_prog, + ['quantize_linear']) + ops_with_weights_count = self.count_op(quant_eval_prog, + ops_with_weights) + ops_without_weights_count = self.count_op(quant_eval_prog, + ops_without_weights) + self.assertEqual(ops_with_weights_count * 2 + ops_without_weights_count, + quantizers_count_in_qat) + + with paddle.static.program_guard(quant_eval_prog): + paddle.static.save_inference_model("./models/mobilenet_qat", [ + image, label + ], [avg_cost, acc_top1, acc_top5], exe) train(quant_train_prog) convert_eval_prog = convert(quant_eval_prog, place, config) + with paddle.static.program_guard(convert_eval_prog): + paddle.static.save_inference_model("./models/mobilenet_onnx", [ + image, label + ], [avg_cost, acc_top1, acc_top5], exe) + top1_2, top5_2 = test(convert_eval_prog) # values before quantization and after quantization should be close print("before quantization: top1: {}, top5: {}".format(top1_1, top5_1)) print("after quantization: top1: {}, top5: {}".format(top1_2, top5_2)) - convert_op_nums_1, convert_quant_op_nums_1 = self.get_convert_op_number( - convert_eval_prog) - # test convert op numbers - self.assertTrue(convert_op_nums_1 + 25 == convert_quant_op_nums_1) + # Step2: check the quantizers count in onnx graph + quantizers_count = self.count_op(convert_eval_prog, ['quantize_linear']) + observers_count = self.count_op(quant_eval_prog, + ['moving_average_abs_max_scale']) + self.assertEqual(quantizers_count, ops_with_weights_count + + ops_without_weights_count + observers_count) + # Step3: check the quantization skipping config['not_quant_pattern'] = ['last_fc'] - quant_prog_2 = quant_aware( + skip_quant_prog = quant_aware( main_prog, place, config=config, for_test=True) - op_nums_2, quant_op_nums_2 = self.get_op_number(quant_prog_2) - convert_prog_2 = convert(quant_prog_2, place, config=config) - convert_op_nums_2, convert_quant_op_nums_2 = self.get_convert_op_number( - convert_prog_2) - - self.assertTrue(op_nums_1 == op_nums_2) - # test skip_quant - self.assertTrue(quant_op_nums_1 - 2 == quant_op_nums_2) - self.assertTrue(convert_quant_op_nums_1 == convert_quant_op_nums_2) - - def get_op_number(self, prog): + skip_quantizers_count_in_qat = self.count_op(skip_quant_prog, + ['quantize_linear']) + skip_ops_with_weights_count = self.count_op(skip_quant_prog, + ops_with_weights) + skip_ops_without_weights_count = self.count_op(skip_quant_prog, + ops_without_weights) + self.assertEqual(skip_ops_without_weights_count, + ops_without_weights_count) + self.assertEqual(skip_ops_with_weights_count, ops_with_weights_count) + self.assertEqual(skip_quantizers_count_in_qat + 2, + quantizers_count_in_qat) + + skip_quant_prog_onnx = convert(skip_quant_prog, place, config=config) + skip_quantizers_count_in_onnx = self.count_op(skip_quant_prog_onnx, + ['quantize_linear']) + self.assertEqual(quantizers_count, skip_quantizers_count_in_onnx) + + def count_op(self, prog, ops: List[str]): graph = paddle.fluid.framework.IrGraph( paddle.framework.core.Graph(prog.desc), for_test=False) - quant_op_nums = 0 - op_nums = 0 - for op in graph.all_op_nodes(): - if op.name() in ['conv2d', 'depthwise_conv2d', 'mul']: - op_nums += 1 - elif op.name() == 'quantize_linear': - quant_op_nums += 1 - return op_nums, quant_op_nums - - def get_convert_op_number(self, prog): - graph = paddle.fluid.framework.IrGraph( - paddle.framework.core.Graph(prog.desc), for_test=True) - quant_op_nums = 0 op_nums = 0 - dequant_num = 0 for op in graph.all_op_nodes(): - if op.name() not in ['quantize_linear', 'dequantize_linear']: + if op.name() in ops: op_nums += 1 - elif op.name() == 'quantize_linear': - quant_op_nums += 1 - return op_nums, quant_op_nums + return op_nums if __name__ == '__main__':