未验证 提交 14f6c74b 编写于 作者: G Guanghua Yu 提交者: GitHub

fix ptq unittest (#45447)

上级 126940b3
......@@ -326,6 +326,7 @@ class PostTrainingQuantization(object):
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._onnx_format = onnx_format
self._clip_extra = True if self._onnx_format else False
self._skip_tensor_list = skip_tensor_list
self._is_full_quantize = is_full_quantize
if is_full_quantize:
......@@ -505,7 +506,6 @@ class PostTrainingQuantization(object):
Returns:
None
'''
clip_extra = True if self._onnx_format else False
io.save_inference_model(dirname=save_model_path,
model_filename=model_filename,
params_filename=params_filename,
......@@ -513,7 +513,7 @@ class PostTrainingQuantization(object):
target_vars=self._fetch_list,
executor=self._executor,
main_program=self._program,
clip_extra=clip_extra)
clip_extra=self._clip_extra)
_logger.info("The quantized model is saved in " + save_model_path)
def _load_model_data(self):
......@@ -535,6 +535,8 @@ class PostTrainingQuantization(object):
for var_name in self._feed_list]
if self._data_loader is not None:
self._batch_nums = self._batch_nums if self._batch_nums else len(
self._data_loader)
return
self._data_loader = io.DataLoader.from_generator(feed_list=feed_vars,
capacity=3 *
......@@ -548,6 +550,8 @@ class PostTrainingQuantization(object):
elif self._batch_generator is not None:
self._data_loader.set_batch_generator(self._batch_generator,
places=self._place)
self._batch_nums = self._batch_nums if self._batch_nums else len(
list(self._data_loader))
def _optimize_fp32_model(self):
'''
......
......@@ -191,6 +191,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
onnx_format=onnx_format,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
if onnx_format:
ptq._clip_extra = False
ptq.save_quantized_model(self.int8_model_path)
def run_test(self,
......@@ -226,7 +228,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.generate_quantized_model(fp32_model_path, data_path, algo,
round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, quant_iterations,
is_optimize_model, 10, quant_iterations,
onnx_format)
print("Start INT8 inference for {0} on {1} samples ...".format(
......
......@@ -246,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False,
batch_nums=10,
onnx_format=False):
try:
os.system("mkdir " + self.int8_model)
......@@ -263,6 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
sample_generator=val_reader,
model_dir=model_path,
algo=algo,
batch_nums=batch_nums,
quantizable_op_type=quantizable_op_type,
round_type=round_type,
is_full_quantize=is_full_quantize,
......@@ -283,7 +285,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False):
onnx_format=False,
batch_nums=10):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
......@@ -301,7 +304,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
quantizable_op_type, algo, round_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, onnx_format)
is_optimize_model, batch_nums,
onnx_format)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
......@@ -392,9 +396,18 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.03
self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold)
batch_nums = 3
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_nums=batch_nums)
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
......@@ -441,6 +454,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
is_optimize_model = True
onnx_format = True
diff_threshold = 0.05
batch_nums = 3
self.run_test(model,
algo,
round_type,
......@@ -451,7 +465,8 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=onnx_format)
onnx_format=onnx_format,
batch_nums=batch_nums)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册