未验证 提交 b621a4f1 编写于 作者: G Guanghua Yu 提交者: GitHub

support skip_op_list in PostTrainingQuantization (#42378)

上级 87afccb2
......@@ -126,6 +126,7 @@ class PostTrainingQuantization(object):
onnx_format=False,
optimize_model=False,
is_use_cache_file=False,
skip_tensor_list=None,
cache_dir=None):
'''
Constructor.
......@@ -198,6 +199,7 @@ class PostTrainingQuantization(object):
the model accuracy is usually higher when it is 'channel_wise_abs_max'.
onnx_format(bool): Whether to export the quantized model with format of ONNX.
Default is False.
skip_tensor_list(list): List of skip quant tensor name.
optimize_model(bool, optional): If set optimize_model as True, it applies
some passes to the model before quantization, and it supports
`conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
......@@ -301,6 +303,7 @@ class PostTrainingQuantization(object):
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._onnx_format = onnx_format
self._skip_tensor_list = skip_tensor_list
self._is_full_quantize = is_full_quantize
if is_full_quantize:
self._quantizable_op_type = self._support_quantize_op_type
......@@ -547,6 +550,12 @@ class PostTrainingQuantization(object):
persistable_var_names = _all_persistable_var_names(self._program)
for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops:
# skip quant form self._skip_tensor_list
if self._skip_tensor_list is not None:
for inp_name in utils._get_op_input_var_names(op):
if inp_name in self._skip_tensor_list:
op._set_attr("op_namescope", "skip_quant")
op_type = op.type
if self._is_full_quantize and \
op_type not in self._quantizable_op_type:
......
......@@ -117,7 +117,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_optimize_model=False,
batch_size=10,
batch_nums=10,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -136,6 +137,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model_path)
......@@ -154,7 +156,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_size=10,
infer_iterations=10,
quant_iterations=5,
onnx_format=False):
onnx_format=False,
skip_tensor_list=None):
origin_model_path = self.download_model(data_url, data_md5, model_name)
origin_model_path = os.path.join(origin_model_path, model_name)
......@@ -166,10 +169,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size))
self.generate_quantized_model(origin_model_path, algo, round_type,
quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
batch_size, quant_iterations, onnx_format)
self.generate_quantized_model(
origin_model_path, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model, batch_size,
quant_iterations, onnx_format, skip_tensor_list)
print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size))
......@@ -426,5 +429,38 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
onnx_format=onnx_format)
class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
def test_post_training_avg_skip_op(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "avg"
round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.01
batch_size = 10
infer_iterations = 50
quant_iterations = 5
skip_tensor_list = ["fc_0.w_0"]
self.run_test(
model_name,
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
skip_tensor_list=skip_tensor_list)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册