From 43f3d0cce37982157708ac2cdd293767efb9c1a1 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Mon, 27 Jul 2020 15:18:45 +0200 Subject: [PATCH] Add an option to choose inference targets in Quant tests (#25582) test=develop --- ...t2_int8_image_classification_comparison.py | 92 ++++++++++++------- .../slim/tests/quant2_int8_nlp_comparison.py | 88 +++++++++++------- 2 files changed, 113 insertions(+), 67 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py b/python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py index 77c925a1b11..17e0f452e98 100644 --- a/python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py +++ b/python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py @@ -68,6 +68,12 @@ def parse_args(): type=str, default='', help='A comma separated list of operator ids to skip in quantization.') + parser.add_argument( + '--targets', + type=str, + default='quant,int8,fp32', + help='A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"' + ) parser.add_argument( '--debug', action='store_true', @@ -310,6 +316,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase): assert int8_acc1 > 0.5 assert quant_acc1 - int8_acc1 <= threshold + def _strings_from_csv(self, string): + return set(s.strip() for s in string.split(',')) + + def _ints_from_csv(self, string): + return set(map(int, string.split(','))) + def test_graph_transformation(self): if not fluid.core.is_compiled_with_mkldnn(): return @@ -326,14 +338,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase): self._debug = test_case_args.debug self._quantized_ops = set() - if len(test_case_args.ops_to_quantize) > 0: - self._quantized_ops = set( - op.strip() for op in test_case_args.ops_to_quantize.split(',')) + if test_case_args.ops_to_quantize: + self._quantized_ops = self._strings_from_csv( + test_case_args.ops_to_quantize) self._op_ids_to_skip = set([-1]) - if len(test_case_args.op_ids_to_skip) > 0: - self._op_ids_to_skip = set( - map(int, test_case_args.op_ids_to_skip.split(','))) + if test_case_args.op_ids_to_skip: + self._op_ids_to_skip = self._ints_from_csv( + test_case_args.op_ids_to_skip) + + self._targets = self._strings_from_csv(test_case_args.targets) + assert self._targets.intersection( + {'quant', 'int8', 'fp32'} + ), 'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".' _logger.info('Quant & INT8 prediction run.') _logger.info('Quant model: {}'.format(quant_model_path)) @@ -348,35 +365,38 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase): _logger.info('Op ids to skip quantization: {}.'.format(','.join( map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip else 'none')) + _logger.info('Targets: {}.'.format(','.join(self._targets))) - _logger.info('--- Quant prediction start ---') - val_reader = paddle.batch( - self._reader_creator(data_path), batch_size=batch_size) - quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict( - val_reader, - quant_model_path, - batch_size, - batch_num, - skip_batch_num, - target='quant') - self._print_performance('Quant', quant_fps, quant_lat) - self._print_accuracy('Quant', quant_acc1, quant_acc5) + if 'quant' in self._targets: + _logger.info('--- Quant prediction start ---') + val_reader = paddle.batch( + self._reader_creator(data_path), batch_size=batch_size) + quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict( + val_reader, + quant_model_path, + batch_size, + batch_num, + skip_batch_num, + target='quant') + self._print_performance('Quant', quant_fps, quant_lat) + self._print_accuracy('Quant', quant_acc1, quant_acc5) - _logger.info('--- INT8 prediction start ---') - val_reader = paddle.batch( - self._reader_creator(data_path), batch_size=batch_size) - int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict( - val_reader, - quant_model_path, - batch_size, - batch_num, - skip_batch_num, - target='int8') - self._print_performance('INT8', int8_fps, int8_lat) - self._print_accuracy('INT8', int8_acc1, int8_acc5) + if 'int8' in self._targets: + _logger.info('--- INT8 prediction start ---') + val_reader = paddle.batch( + self._reader_creator(data_path), batch_size=batch_size) + int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict( + val_reader, + quant_model_path, + batch_size, + batch_num, + skip_batch_num, + target='int8') + self._print_performance('INT8', int8_fps, int8_lat) + self._print_accuracy('INT8', int8_acc1, int8_acc5) fp32_acc1 = fp32_acc5 = fp32_fps = fp32_lat = -1 - if fp32_model_path: + if 'fp32' in self._targets and fp32_model_path: _logger.info('--- FP32 prediction start ---') val_reader = paddle.batch( self._reader_creator(data_path), batch_size=batch_size) @@ -390,10 +410,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase): self._print_performance('FP32', fp32_fps, fp32_lat) self._print_accuracy('FP32', fp32_acc1, fp32_acc5) - self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat) - self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1, int8_acc5, - fp32_acc1, fp32_acc5) - self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1) + if {'int8', 'fp32'}.issubset(self._targets): + self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat) + if {'int8', 'quant'}.issubset(self._targets): + self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1, + int8_acc5, fp32_acc1, fp32_acc5) + self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1) if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py b/python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py index 640d500152d..a534edb7efd 100644 --- a/python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py +++ b/python/paddle/fluid/contrib/slim/tests/quant2_int8_nlp_comparison.py @@ -72,6 +72,12 @@ def parse_args(): type=str, default='', help='A comma separated list of operator ids to skip in quantization.') + parser.add_argument( + '--targets', + type=str, + default='quant,int8,fp32', + help='A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"' + ) parser.add_argument( '--debug', action='store_true', @@ -256,6 +262,12 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): assert int8_acc > 0.5 assert quant_acc - int8_acc <= threshold + def _strings_from_csv(self, string): + return set(s.strip() for s in string.split(',')) + + def _ints_from_csv(self, string): + return set(map(int, string.split(','))) + def test_graph_transformation(self): if not fluid.core.is_compiled_with_mkldnn(): return @@ -274,13 +286,18 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): self._quantized_ops = set() if test_case_args.ops_to_quantize: - self._quantized_ops = set( - op.strip() for op in test_case_args.ops_to_quantize.split(',')) + self._quantized_ops = self._strings_from_csv( + test_case_args.ops_to_quantize) self._op_ids_to_skip = set([-1]) if test_case_args.op_ids_to_skip: - self._op_ids_to_skip = set( - map(int, test_case_args.op_ids_to_skip.split(','))) + self._op_ids_to_skip = self._ints_from_csv( + test_case_args.op_ids_to_skip) + + self._targets = self._strings_from_csv(test_case_args.targets) + assert self._targets.intersection( + {'quant', 'int8', 'fp32'} + ), 'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".' _logger.info('Quant & INT8 prediction run.') _logger.info('Quant model: {}'.format(quant_model_path)) @@ -296,35 +313,40 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): _logger.info('Op ids to skip quantization: {}.'.format(','.join( map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip else 'none')) + _logger.info('Targets: {}.'.format(','.join(self._targets))) - _logger.info('--- Quant prediction start ---') - val_reader = paddle.batch( - self._reader_creator(data_path, labels_path), batch_size=batch_size) - quant_acc, quant_pps, quant_lat = self._predict( - val_reader, - quant_model_path, - batch_size, - batch_num, - skip_batch_num, - target='quant') - self._print_performance('Quant', quant_pps, quant_lat) - self._print_accuracy('Quant', quant_acc) + if 'quant' in self._targets: + _logger.info('--- Quant prediction start ---') + val_reader = paddle.batch( + self._reader_creator(data_path, labels_path), + batch_size=batch_size) + quant_acc, quant_pps, quant_lat = self._predict( + val_reader, + quant_model_path, + batch_size, + batch_num, + skip_batch_num, + target='quant') + self._print_performance('Quant', quant_pps, quant_lat) + self._print_accuracy('Quant', quant_acc) - _logger.info('--- INT8 prediction start ---') - val_reader = paddle.batch( - self._reader_creator(data_path, labels_path), batch_size=batch_size) - int8_acc, int8_pps, int8_lat = self._predict( - val_reader, - quant_model_path, - batch_size, - batch_num, - skip_batch_num, - target='int8') - self._print_performance('INT8', int8_pps, int8_lat) - self._print_accuracy('INT8', int8_acc) + if 'int8' in self._targets: + _logger.info('--- INT8 prediction start ---') + val_reader = paddle.batch( + self._reader_creator(data_path, labels_path), + batch_size=batch_size) + int8_acc, int8_pps, int8_lat = self._predict( + val_reader, + quant_model_path, + batch_size, + batch_num, + skip_batch_num, + target='int8') + self._print_performance('INT8', int8_pps, int8_lat) + self._print_accuracy('INT8', int8_acc) fp32_acc = fp32_pps = fp32_lat = -1 - if fp32_model_path: + if 'fp32' in self._targets and fp32_model_path: _logger.info('--- FP32 prediction start ---') val_reader = paddle.batch( self._reader_creator(data_path, labels_path), @@ -339,9 +361,11 @@ class QuantInt8NLPComparisonTest(unittest.TestCase): self._print_performance('FP32', fp32_pps, fp32_lat) self._print_accuracy('FP32', fp32_acc) - self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat) - self._summarize_accuracy(quant_acc, int8_acc, fp32_acc) - self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc) + if {'int8', 'fp32'}.issubset(self._targets): + self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat) + if {'int8', 'quant'}.issubset(self._targets): + self._summarize_accuracy(quant_acc, int8_acc, fp32_acc) + self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc) if __name__ == '__main__': -- GitLab