未验证 提交 7d27520a 编写于 作者: G Guanghua Yu 提交者: GitHub

fix ptq unittest (#46345)

上级 3aa6bd57
...@@ -146,8 +146,6 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -146,8 +146,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
os.system(cmd) os.system(cmd)
self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50 self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50
self.sample_iterations = 50 if os.environ.get(
'DATASET') == 'full' else 2
self.infer_iterations = 50000 if os.environ.get( self.infer_iterations = 50000 if os.environ.get(
'DATASET') == 'full' else 2 'DATASET') == 'full' else 2
...@@ -291,7 +289,6 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -291,7 +289,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums=10): batch_nums=10):
infer_iterations = self.infer_iterations infer_iterations = self.infer_iterations
batch_size = self.batch_size batch_size = self.batch_size
sample_iterations = self.sample_iterations
model_cache_folder = self.download_data(data_urls, data_md5s, model) model_cache_folder = self.download_data(data_urls, data_md5s, model)
...@@ -302,13 +299,12 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -302,13 +299,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
infer_iterations) infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size)) format(model, batch_nums * batch_size))
self.generate_quantized_model(os.path.join(model_cache_folder, "model"), self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
quantizable_op_type, batch_size, quantizable_op_type, batch_size, algo,
sample_iterations, algo, round_type, round_type, is_full_quantize,
is_full_quantize, is_use_cache_file, is_use_cache_file, is_optimize_model,
is_optimize_model, batch_nums, batch_nums, onnx_format)
onnx_format)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size)) model, infer_iterations * batch_size))
...@@ -351,6 +347,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): ...@@ -351,6 +347,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False is_use_cache_file = False
is_optimize_model = True is_optimize_model = True
diff_threshold = 0.025 diff_threshold = 0.025
batch_nums = 3
self.run_test(model, algo, round_type, data_urls, data_md5s, self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold) is_optimize_model, diff_threshold)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册