diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 0cd804cc514a5caf2277b71cb304e7bc60d5e62f..e1be7c6809d4ac0c0d2a622a55161cfcca894f42 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -236,6 +236,7 @@ class TestPostTrainingQuantization(unittest.TestCase): def generate_quantized_model(self, model_path, + quantizable_op_type, algo="KL", is_full_quantize=False, is_use_cache_file=False): @@ -250,9 +251,6 @@ class TestPostTrainingQuantization(unittest.TestCase): exe = fluid.Executor(place) scope = fluid.global_scope() val_reader = val() - quantizable_op_type = [ - "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" - ] ptq = PostTrainingQuantization( executor=exe, @@ -265,8 +263,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ptq.quantize() ptq.save_quantized_model(self.int8_model) - def run_test(self, model, algo, data_urls, data_md5s, is_full_quantize, - is_use_cache_file): + def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type, + is_full_quantize, is_use_cache_file): infer_iterations = self.infer_iterations batch_size = self.batch_size sample_iterations = self.sample_iterations @@ -280,7 +278,8 @@ class TestPostTrainingQuantization(unittest.TestCase): print("Start INT8 post training quantization for {0} on {1} images ...". format(model, sample_iterations * batch_size)) - self.generate_quantized_model(model_cache_folder + "/model", algo, + self.generate_quantized_model(model_cache_folder + "/model", + quantizable_op_type, algo, is_full_quantize, is_use_cache_file) print("Start INT8 inference for {0} on {1} images ...".format( @@ -308,10 +307,13 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization): 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' ] data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] + quantizable_op_type = [ + "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" + ] is_full_quantize = True is_use_cache_file = False - self.run_test(model, algo, data_urls, data_md5s, is_full_quantize, - is_use_cache_file) + self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, + is_full_quantize, is_use_cache_file) if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py index 815f2e4332955ae0096124cdfda57b0aa3e872e1..93d84112524e7e302ec22f99354e6169c512800e 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py @@ -25,10 +25,11 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' ] data_md5s = ['4a5194524823d9b76da6e738e1367881'] + quantizable_op_type = ["conv2d", "mul"] is_full_quantize = False - is_use_cache_file = True - self.run_test(model, algo, data_urls, data_md5s, is_full_quantize, - is_use_cache_file) + is_use_cache_file = False + self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, + is_full_quantize, is_use_cache_file) if __name__ == '__main__':