From 14f6c74bc106c01986e8636ec30fe876460ad828 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Fri, 26 Aug 2022 16:21:03 +0800
Subject: [PATCH] fix ptq unittest (#45447)

---
 .../post_training_quantization.py             |  8 ++++--
 ...t_post_training_quantization_lstm_model.py |  4 ++-
 ..._post_training_quantization_mobilenetv1.py | 27 ++++++++++++++-----
 3 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
index bd8fa2a072d..668aa12210c 100644
--- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
+++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
@@ -326,6 +326,7 @@ class PostTrainingQuantization(object):
         self._activation_quantize_type = activation_quantize_type
         self._weight_quantize_type = weight_quantize_type
         self._onnx_format = onnx_format
+        self._clip_extra = True if self._onnx_format else False
         self._skip_tensor_list = skip_tensor_list
         self._is_full_quantize = is_full_quantize
         if is_full_quantize:
@@ -505,7 +506,6 @@ class PostTrainingQuantization(object):
         Returns:
             None
         '''
-        clip_extra = True if self._onnx_format else False
         io.save_inference_model(dirname=save_model_path,
                                 model_filename=model_filename,
                                 params_filename=params_filename,
@@ -513,7 +513,7 @@ class PostTrainingQuantization(object):
                                 target_vars=self._fetch_list,
                                 executor=self._executor,
                                 main_program=self._program,
-                                clip_extra=clip_extra)
+                                clip_extra=self._clip_extra)
         _logger.info("The quantized model is saved in " + save_model_path)

     def _load_model_data(self):
@@ -535,6 +535,8 @@ class PostTrainingQuantization(object):
             for var_name in self._feed_list]

         if self._data_loader is not None:
+            self._batch_nums = self._batch_nums if self._batch_nums else len(
+                self._data_loader)
             return
         self._data_loader = io.DataLoader.from_generator(feed_list=feed_vars,
                                                          capacity=3 *
@@ -548,6 +550,8 @@ class PostTrainingQuantization(object):
         elif self._batch_generator is not None:
             self._data_loader.set_batch_generator(self._batch_generator,
                                                   places=self._place)
+        self._batch_nums = self._batch_nums if self._batch_nums else len(
+            list(self._data_loader))

     def _optimize_fp32_model(self):
         '''
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
index 6100ed4f82a..575d0826b27 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
@@ -191,6 +191,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
             onnx_format=onnx_format,
             is_use_cache_file=is_use_cache_file)
         ptq.quantize()
+        if onnx_format:
+            ptq._clip_extra = False
         ptq.save_quantized_model(self.int8_model_path)

     def run_test(self,
@@ -226,7 +228,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.generate_quantized_model(fp32_model_path, data_path, algo,
                                       round_type, quantizable_op_type,
                                       is_full_quantize, is_use_cache_file,
-                                      is_optimize_model, quant_iterations,
+                                      is_optimize_model, 10, quant_iterations,
                                       onnx_format)

         print("Start INT8 inference for {0} on {1} samples ...".format(
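The library-side half of the fix is visible above: `clip_extra` is now resolved once in `__init__` and kept as `self._clip_extra`, so callers (including the unit tests) can override it before `save_quantized_model()` runs, and `_batch_nums` now falls back to the number of batches the calibration data loader yields whenever the caller does not pass `batch_nums`. A minimal runnable sketch of that defaulting logic follows; `PtqFlagsSketch` is a hypothetical stand-in, not the real `PostTrainingQuantization` class:

# Stand-in sketch (not the real class): mirrors only the two
# attributes touched by the patch above.
class PtqFlagsSketch(object):

    def __init__(self, data_loader, batch_nums=None, onnx_format=False):
        self._data_loader = data_loader
        self._batch_nums = batch_nums
        # Resolved once at construction time, as in the patched __init__,
        # so a test can flip it before the model is saved.
        self._clip_extra = True if onnx_format else False

    def _load_model_data(self):
        # As in the patched _load_model_data: when no batch_nums was
        # given, default to the number of batches the loader yields.
        self._batch_nums = self._batch_nums if self._batch_nums else len(
            list(self._data_loader))


loader = [("batch", i) for i in range(5)]  # pretend DataLoader: 5 batches
ptq = PtqFlagsSketch(loader)
ptq._load_model_data()
assert ptq._batch_nums == 5  # defaulted from the loader length

ptq = PtqFlagsSketch(loader, batch_nums=3, onnx_format=True)
ptq._load_model_data()
assert ptq._batch_nums == 3 and ptq._clip_extra is True  # explicit value wins

Without this fallback, a `None` batch_nums previously meant the calibration loop had no iteration bound of its own, which is what made the unit tests slow and flaky enough to need this patch.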
diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
index cb6d685f721..fc675ed4a07 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
@@ -246,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  is_full_quantize=False,
                                  is_use_cache_file=False,
                                  is_optimize_model=False,
+                                 batch_nums=10,
                                  onnx_format=False):
         try:
             os.system("mkdir " + self.int8_model)
@@ -263,6 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
             sample_generator=val_reader,
             model_dir=model_path,
             algo=algo,
+            batch_nums=batch_nums,
             quantizable_op_type=quantizable_op_type,
             round_type=round_type,
             is_full_quantize=is_full_quantize,
@@ -283,7 +285,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                  is_use_cache_file,
                  is_optimize_model,
                  diff_threshold,
-                 onnx_format=False):
+                 onnx_format=False,
+                 batch_nums=10):
         infer_iterations = self.infer_iterations
         batch_size = self.batch_size
         sample_iterations = self.sample_iterations
@@ -301,7 +304,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
                                       quantizable_op_type, algo, round_type,
                                       is_full_quantize, is_use_cache_file,
-                                      is_optimize_model, onnx_format)
+                                      is_optimize_model, batch_nums,
+                                      onnx_format)

         print("Start INT8 inference for {0} on {1} images ...".format(
             model, infer_iterations * batch_size))
@@ -392,9 +396,18 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
         is_use_cache_file = False
         is_optimize_model = True
         diff_threshold = 0.03
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
-                      quantizable_op_type, is_full_quantize, is_use_cache_file,
-                      is_optimize_model, diff_threshold)
+        batch_nums = 3
+        self.run_test(model,
+                      algo,
+                      round_type,
+                      data_urls,
+                      data_md5s,
+                      quantizable_op_type,
+                      is_full_quantize,
+                      is_use_cache_file,
+                      is_optimize_model,
+                      diff_threshold,
+                      batch_nums=batch_nums)


 class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
@@ -441,6 +454,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
         is_optimize_model = True
         onnx_format = True
         diff_threshold = 0.05
+        batch_nums = 3
         self.run_test(model,
                       algo,
                       round_type,
@@ -451,7 +465,8 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
                       is_use_cache_file,
                       is_optimize_model,
                       diff_threshold,
-                      onnx_format=onnx_format)
+                      onnx_format=onnx_format,
+                      batch_nums=batch_nums)


 if __name__ == '__main__':
--
GitLab
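On the test side, `batch_nums` is now threaded from `run_test` through `generate_quantized_model` into the `PostTrainingQuantization` constructor, so calibration stops after a fixed number of batches (3 for the MobileNetV1 cases) instead of walking the whole reader. Below is a hedged sketch of that call pattern using only keyword arguments and methods that appear in this patch; `model_dir`, `sample_generator`, and `save_path` are placeholders the caller must point at a real FP32 inference model and calibration reader, so the function compiles but cannot run without those assets:

import paddle
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization


def quantize_with_batch_nums(model_dir, sample_generator, save_path):
    # Placeholder-driven sketch of the call pattern used by the tests above.
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    ptq = PostTrainingQuantization(
        executor=exe,
        sample_generator=sample_generator,  # calibration data reader
        model_dir=model_dir,                # directory of the FP32 model
        algo="avg",
        batch_nums=3,   # calibrate on 3 batches; None means the whole loader
        onnx_format=True)
    ptq.quantize()
    # Mirrors the LSTM test above, which overrides the private flag after
    # quantize() so the saved ONNX-format model keeps its extra attributes.
    ptq._clip_extra = False
    ptq.save_quantized_model(save_path)

Capping calibration at a few batches is what turns these unit tests from dataset-length runs into bounded ones, which is the practical point of the fix.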