Unverified · Commit 14f6c74b authored by Guanghua Yu, committed by GitHub

fix ptq unittest (#45447)

Parent: 126940b3
@@ -326,6 +326,7 @@ class PostTrainingQuantization(object):
         self._activation_quantize_type = activation_quantize_type
         self._weight_quantize_type = weight_quantize_type
         self._onnx_format = onnx_format
+        self._clip_extra = True if self._onnx_format else False
         self._skip_tensor_list = skip_tensor_list
         self._is_full_quantize = is_full_quantize
         if is_full_quantize:
@@ -505,7 +506,6 @@ class PostTrainingQuantization(object):
         Returns:
             None
         '''
-        clip_extra = True if self._onnx_format else False
         io.save_inference_model(dirname=save_model_path,
                                 model_filename=model_filename,
                                 params_filename=params_filename,
@@ -513,7 +513,7 @@ class PostTrainingQuantization(object):
                                 target_vars=self._fetch_list,
                                 executor=self._executor,
                                 main_program=self._program,
-                                clip_extra=clip_extra)
+                                clip_extra=self._clip_extra)
         _logger.info("The quantized model is saved in " + save_model_path)

     def _load_model_data(self):
@@ -535,6 +535,8 @@ class PostTrainingQuantization(object):
                      for var_name in self._feed_list]
         if self._data_loader is not None:
+            self._batch_nums = self._batch_nums if self._batch_nums else len(
+                self._data_loader)
             return
         self._data_loader = io.DataLoader.from_generator(feed_list=feed_vars,
                                                          capacity=3 *
@@ -548,6 +550,8 @@ class PostTrainingQuantization(object):
         elif self._batch_generator is not None:
             self._data_loader.set_batch_generator(self._batch_generator,
                                                   places=self._place)
+        self._batch_nums = self._batch_nums if self._batch_nums else len(
+            list(self._data_loader))

     def _optimize_fp32_model(self):
         '''
...
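In short, the first file's changes cache the clip flag on the instance at construction time (`self._clip_extra = True if self._onnx_format else False`) and let `_batch_nums` default to the number of batches in the calibration data loader when the caller leaves it unset. Below is a minimal, standalone sketch of that fallback pattern; the helper name and arguments are illustrative only, not part of the Paddle code above.

```python
# Illustrative sketch of the batch_nums fallback introduced above
# (hypothetical helper, not Paddle API): when the caller does not pass
# batch_nums, derive it from how many batches the calibration loader yields.
def resolve_batch_nums(batch_nums, data_loader):
    if batch_nums:
        return batch_nums
    try:
        return len(data_loader)          # loaders that implement __len__
    except TypeError:
        return len(list(data_loader))    # otherwise count by materializing once
```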
@@ -191,6 +191,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
             onnx_format=onnx_format,
             is_use_cache_file=is_use_cache_file)
         ptq.quantize()
+        if onnx_format:
+            ptq._clip_extra = False
         ptq.save_quantized_model(self.int8_model_path)

     def run_test(self,
@@ -226,7 +228,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.generate_quantized_model(fp32_model_path, data_path, algo,
                                       round_type, quantizable_op_type,
                                       is_full_quantize, is_use_cache_file,
-                                      is_optimize_model, quant_iterations,
+                                      is_optimize_model, 10, quant_iterations,
                                       onnx_format)

         print("Start INT8 inference for {0} on {1} samples ...".format(
...
@@ -246,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  is_full_quantize=False,
                                  is_use_cache_file=False,
                                  is_optimize_model=False,
+                                 batch_nums=10,
                                  onnx_format=False):
         try:
             os.system("mkdir " + self.int8_model)
@@ -263,6 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
             sample_generator=val_reader,
             model_dir=model_path,
             algo=algo,
+            batch_nums=batch_nums,
             quantizable_op_type=quantizable_op_type,
             round_type=round_type,
             is_full_quantize=is_full_quantize,
@@ -283,7 +285,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                  is_use_cache_file,
                  is_optimize_model,
                  diff_threshold,
-                 onnx_format=False):
+                 onnx_format=False,
+                 batch_nums=10):
         infer_iterations = self.infer_iterations
         batch_size = self.batch_size
         sample_iterations = self.sample_iterations
@@ -301,7 +304,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
                                       quantizable_op_type, algo, round_type,
                                       is_full_quantize, is_use_cache_file,
-                                      is_optimize_model, onnx_format)
+                                      is_optimize_model, batch_nums,
+                                      onnx_format)

         print("Start INT8 inference for {0} on {1} images ...".format(
             model, infer_iterations * batch_size))
@@ -392,9 +396,18 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
         is_use_cache_file = False
         is_optimize_model = True
         diff_threshold = 0.03
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
-                      quantizable_op_type, is_full_quantize, is_use_cache_file,
-                      is_optimize_model, diff_threshold)
+        batch_nums = 3
+        self.run_test(model,
+                      algo,
+                      round_type,
+                      data_urls,
+                      data_md5s,
+                      quantizable_op_type,
+                      is_full_quantize,
+                      is_use_cache_file,
+                      is_optimize_model,
+                      diff_threshold,
+                      batch_nums=batch_nums)


 class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
@@ -441,6 +454,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
         is_optimize_model = True
         onnx_format = True
         diff_threshold = 0.05
+        batch_nums = 3
         self.run_test(model,
                       algo,
                       round_type,
@@ -451,7 +465,8 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
                       is_use_cache_file,
                       is_optimize_model,
                       diff_threshold,
-                      onnx_format=onnx_format)
+                      onnx_format=onnx_format,
+                      batch_nums=batch_nums)

 if __name__ == '__main__':
...
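Taken together, the test changes thread a `batch_nums` argument from `run_test` through `generate_quantized_model` into the quantizer, and the ONNX-format path resets `ptq._clip_extra = False` before saving. A hedged usage sketch follows; the import path, executor setup, model directory, and the random calibration reader are assumptions made for illustration and are not taken from this commit.

```python
# Hedged sketch: exercises the constructor arguments visible in the diff.
# The import path, model directory and calibration reader are placeholders.
import numpy as np
import paddle
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())


def val_reader():
    # Placeholder calibration samples; a real run would read validation data
    # whose shape matches the FP32 model's input.
    for _ in range(32):
        yield np.random.random((3, 224, 224)).astype("float32"), 0


ptq = PostTrainingQuantization(
    executor=exe,
    sample_generator=val_reader,
    model_dir="./fp32_model",       # placeholder FP32 inference model directory
    algo="avg",
    batch_nums=3,                   # limit calibration to 3 batches
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    round_type="round",
    is_full_quantize=False,
    onnx_format=True,
    is_use_cache_file=False)
ptq.quantize()
ptq._clip_extra = False             # mirrors the ONNX-format unittest path above
ptq.save_quantized_model("./int8_model")
```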