diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 00a4e2c2aa49e03f7dfb41bca640624d6862f84c..41add5e8b8fb1eb41835873b086013cea83ad4a2 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -146,8 +146,6 @@ class TestPostTrainingQuantization(unittest.TestCase): os.system(cmd) self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50 - self.sample_iterations = 50 if os.environ.get( - 'DATASET') == 'full' else 2 self.infer_iterations = 50000 if os.environ.get( 'DATASET') == 'full' else 2 @@ -291,7 +289,6 @@ class TestPostTrainingQuantization(unittest.TestCase): batch_nums=10): infer_iterations = self.infer_iterations batch_size = self.batch_size - sample_iterations = self.sample_iterations model_cache_folder = self.download_data(data_urls, data_md5s, model) @@ -302,13 +299,12 @@ class TestPostTrainingQuantization(unittest.TestCase): infer_iterations) print("Start INT8 post training quantization for {0} on {1} images ...". - format(model, sample_iterations * batch_size)) + format(model, batch_nums * batch_size)) self.generate_quantized_model(os.path.join(model_cache_folder, "model"), - quantizable_op_type, batch_size, - sample_iterations, algo, round_type, - is_full_quantize, is_use_cache_file, - is_optimize_model, batch_nums, - onnx_format) + quantizable_op_type, batch_size, algo, + round_type, is_full_quantize, + is_use_cache_file, is_optimize_model, + batch_nums, onnx_format) print("Start INT8 inference for {0} on {1} images ...".format( model, infer_iterations * batch_size)) @@ -351,6 +347,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): is_use_cache_file = False is_optimize_model = True diff_threshold = 0.025 + batch_nums = 3 self.run_test(model, algo, round_type, data_urls, data_md5s, quantizable_op_type, is_full_quantize, is_use_cache_file, is_optimize_model, diff_threshold)