未验证 提交 b339dff2 编写于 作者: J juncaipeng 提交者: GitHub

fix use cache file, test=develop (#22240)

上级 895f8da7
...@@ -236,6 +236,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -236,6 +236,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(self, def generate_quantized_model(self,
model_path, model_path,
quantizable_op_type,
algo="KL", algo="KL",
is_full_quantize=False, is_full_quantize=False,
is_use_cache_file=False): is_use_cache_file=False):
...@@ -250,9 +251,6 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -250,9 +251,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
scope = fluid.global_scope() scope = fluid.global_scope()
val_reader = val() val_reader = val()
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
ptq = PostTrainingQuantization( ptq = PostTrainingQuantization(
executor=exe, executor=exe,
...@@ -265,8 +263,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -265,8 +263,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.quantize() ptq.quantize()
ptq.save_quantized_model(self.int8_model) ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, is_full_quantize, def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_use_cache_file): is_full_quantize, is_use_cache_file):
infer_iterations = self.infer_iterations infer_iterations = self.infer_iterations
batch_size = self.batch_size batch_size = self.batch_size
sample_iterations = self.sample_iterations sample_iterations = self.sample_iterations
...@@ -280,7 +278,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -280,7 +278,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size)) format(model, sample_iterations * batch_size))
self.generate_quantized_model(model_cache_folder + "/model", algo, self.generate_quantized_model(model_cache_folder + "/model",
quantizable_op_type, algo,
is_full_quantize, is_use_cache_file) is_full_quantize, is_use_cache_file)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
...@@ -308,10 +307,13 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization): ...@@ -308,10 +307,13 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
] ]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b'] data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = True is_full_quantize = True
is_use_cache_file = False is_use_cache_file = False
self.run_test(model, algo, data_urls, data_md5s, is_full_quantize, self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_use_cache_file) is_full_quantize, is_use_cache_file)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -25,10 +25,11 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): ...@@ -25,10 +25,11 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
] ]
data_md5s = ['4a5194524823d9b76da6e738e1367881'] data_md5s = ['4a5194524823d9b76da6e738e1367881']
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = True is_use_cache_file = False
self.run_test(model, algo, data_urls, data_md5s, is_full_quantize, self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_use_cache_file) is_full_quantize, is_use_cache_file)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册