From 7d27520ab8ddd85b262dea9c38edf1235c47c296 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Wed, 21 Sep 2022 18:57:58 +0800
Subject: [PATCH] fix ptq unittest (#46345)

---
 ...test_post_training_quantization_mobilenetv1.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
index 00a4e2c2aa4..41add5e8b8f 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
@@ -146,8 +146,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
             os.system(cmd)
 
         self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50
-        self.sample_iterations = 50 if os.environ.get(
-            'DATASET') == 'full' else 2
         self.infer_iterations = 50000 if os.environ.get(
             'DATASET') == 'full' else 2
 
@@ -291,7 +289,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
                  batch_nums=10):
         infer_iterations = self.infer_iterations
         batch_size = self.batch_size
-        sample_iterations = self.sample_iterations
 
         model_cache_folder = self.download_data(data_urls, data_md5s, model)
 
@@ -302,13 +299,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
             infer_iterations)
 
         print("Start INT8 post training quantization for {0} on {1} images ...".
-              format(model, sample_iterations * batch_size))
+              format(model, batch_nums * batch_size))
         self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
-                                      quantizable_op_type, batch_size,
-                                      sample_iterations, algo, round_type,
-                                      is_full_quantize, is_use_cache_file,
-                                      is_optimize_model, batch_nums,
-                                      onnx_format)
+                                      quantizable_op_type, batch_size, algo,
+                                      round_type, is_full_quantize,
+                                      is_use_cache_file, is_optimize_model,
+                                      batch_nums, onnx_format)
 
         print("Start INT8 inference for {0} on {1} images ...".format(
             model, infer_iterations * batch_size))
@@ -351,6 +347,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
         is_use_cache_file = False
         is_optimize_model = True
         diff_threshold = 0.025
+        batch_nums = 3
         self.run_test(model, algo, round_type, data_urls, data_md5s,
                       quantizable_op_type, is_full_quantize,
                       is_use_cache_file, is_optimize_model, diff_threshold)
--
GitLab