未验证 提交 b339dff2 编写于 作者: juncaipeng 提交者: GitHub

fix use cache file, test=develop (#22240)

上级 895f8da7
......@@ -236,6 +236,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(self,
model_path,
quantizable_op_type,
algo="KL",
is_full_quantize=False,
is_use_cache_file=False):
......@@ -250,9 +251,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
exe = fluid.Executor(place)
scope = fluid.global_scope()
val_reader = val()
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
ptq = PostTrainingQuantization(
executor=exe,
......@@ -265,8 +263,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, is_full_quantize,
is_use_cache_file):
def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
......@@ -280,7 +278,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(model_cache_folder + "/model", algo,
self.generate_quantized_model(model_cache_folder + "/model",
quantizable_op_type, algo,
is_full_quantize, is_use_cache_file)
print("Start INT8 inference for {0} on {1} images ...".format(
......@@ -308,10 +307,13 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = True
is_use_cache_file = False
self.run_test(model, algo, data_urls, data_md5s, is_full_quantize,
is_use_cache_file)
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
if __name__ == '__main__':
......
......@@ -25,10 +25,11 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = True
self.run_test(model, algo, data_urls, data_md5s, is_full_quantize,
is_use_cache_file)
is_use_cache_file = False
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册