Unverified commit 0a00fc4e, authored by Guanghua Yu, committed by GitHub

cherry pick #42255 (fuse conv + bn in QAT) and #42378 (support skip_op_list in PTQ) (#43301)

* support fuse conv and bn in QAT (#42255)

* support skip_op_list in PostTrainingQuantization (#42378)

* fix unittest
Parent f4e09397
...@@ -28,6 +28,27 @@ class Identity(nn.Layer):
        return input
def fuse_conv_bn(model):
    # Collect names of adjacent Conv2D/BatchNorm2D sublayers and fuse each pair.
    is_train = False
    if model.training:
        model.eval()
        is_train = True
    fuse_list = []
    tmp_pair = [None, None]
    for name, layer in model.named_sublayers():
        if isinstance(layer, nn.Conv2D):
            tmp_pair[0] = name
        if isinstance(layer, nn.BatchNorm2D):
            tmp_pair[1] = name

        if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
            fuse_list.append(tmp_pair)
            tmp_pair = [None, None]
    model = fuse_layers(model, fuse_list)
    if is_train:
        model.train()
def fuse_layers(model, layers_to_fuse, inplace=False):
    '''
    fuse layers in layers_to_fuse
...
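The fuse_layers body is collapsed in this hunk. As background, below is a minimal sketch of the standard conv + BN folding such a fusion performs; the helper name fold_bn_into_conv and the use of NumPy are illustrative assumptions, not the committed implementation.

# Illustrative sketch only, NOT the code added by this commit.
# BN(x) = gamma * (x - mean) / sqrt(var + eps) + beta, so fusing scales each
# output filter of the conv and folds the BN shift into the conv bias.
import numpy as np

def fold_bn_into_conv(conv_w, conv_b, gamma, beta, mean, var, eps=1e-5):
    # conv_w: [out_channels, in_channels, kh, kw]; conv_b, gamma, beta,
    # mean, var: [out_channels] (assumes the conv has a bias).
    scale = gamma / np.sqrt(var + eps)
    fused_w = conv_w * scale.reshape(-1, 1, 1, 1)
    fused_b = (conv_b - mean) * scale + beta
    return fused_w, fused_b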
...@@ -20,6 +20,7 @@ import os
import warnings
import paddle
import paddle.nn as nn
import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid import dygraph, core, framework, unique_name
from paddle.fluid.framework import IrGraph
...@@ -32,6 +33,7 @@ from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass
from paddle.fluid.log_helper import get_logger
from .. import quantization_pass
from . import utils
from . import fuse_utils

__all__ = ['ImperativeQuantAware']
...@@ -52,6 +54,7 @@ class ImperativeQuantAware(object):
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 fuse_conv_bn=False,
                 weight_preprocess_layer=None,
                 act_preprocess_layer=None,
                 weight_quantize_layer=None,
...@@ -76,6 +79,7 @@ class ImperativeQuantAware(object):
            activation_bits(int): quantization bit number for activations.
            moving_rate(float): the parameter for 'moving_average_abs_max'
                quantization.
            fuse_conv_bn(bool): Whether to fuse conv and bn, default is False.
            weight_preprocess_layer(paddle.nn.Layer, optional): A paddle
                Layer that defines how to preprocess weight before quantization.
                Using this can quickly test if user's preprocess method works
...@@ -188,6 +192,7 @@ class ImperativeQuantAware(object):
                    model_path="./imperative_model_qat")
        """
        super(ImperativeQuantAware, self).__init__()
        self.fuse_conv_bn = fuse_conv_bn

        kwargs = {
            "quantizable_layer_type": quantizable_layer_type,
...@@ -256,8 +261,13 @@ class ImperativeQuantAware(object):
        """
        assert isinstance(model, dygraph.Layer), \
            "The model must be the instance of dygraph.Layer."

        if self.fuse_conv_bn:
            fuse_utils.fuse_conv_bn(model)

        self._quantize_inputs.apply(model)
        self._quantize_outputs.apply(model)
        return model

    def save_quantized_model(self, layer, path, input_spec=None, **config):
        self._quantize_outputs.save_quantized_model(layer, path, input_spec,
...
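A minimal usage sketch of the new option follows; `net` stands for an assumed user-defined dygraph model, and the surrounding training loop is omitted.

# Illustrative sketch, not part of this commit; `net` is an assumed model.
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

imperative_qat = ImperativeQuantAware(
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max',
    fuse_conv_bn=True)  # new option: fold BN into conv before inserting quant ops
net = imperative_qat.quantize(net)  # quantize() now also returns the model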
...@@ -126,6 +126,7 @@ class PostTrainingQuantization(object):
                 onnx_format=False,
                 optimize_model=False,
                 is_use_cache_file=False,
                 skip_tensor_list=None,
                 cache_dir=None):
        '''
        Constructor.
...@@ -198,6 +199,7 @@ class PostTrainingQuantization(object):
                the model accuracy is usually higher when it is 'channel_wise_abs_max'.
            onnx_format(bool): Whether to export the quantized model with format of ONNX.
                Default is False.
            skip_tensor_list(list): List of names of tensors to skip quantization. Default is None.
            optimize_model(bool, optional): If set optimize_model as True, it applies
                some passes to the model before quantization, and it supports
                `conv2d/depthwise_conv2d + bn` pass so far. Some targets require the
...@@ -301,6 +303,7 @@ class PostTrainingQuantization(object):
        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._onnx_format = onnx_format
        self._skip_tensor_list = skip_tensor_list
        self._is_full_quantize = is_full_quantize
        if is_full_quantize:
            self._quantizable_op_type = self._support_quantize_op_type
...@@ -547,6 +550,12 @@ class PostTrainingQuantization(object):
        persistable_var_names = _all_persistable_var_names(self._program)
        for block_id in range(len(self._program.blocks)):
            for op in self._program.blocks[block_id].ops:
                # skip quantization for ops whose inputs are in self._skip_tensor_list
                if self._skip_tensor_list is not None:
                    for inp_name in utils._get_op_input_var_names(op):
                        if inp_name in self._skip_tensor_list:
                            op._set_attr("op_namescope", "skip_quant")

                op_type = op.type
                if self._is_full_quantize and \
                    op_type not in self._quantizable_op_type:
...
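A minimal usage sketch of skip_tensor_list follows; the model directory './fp32_model' and the calibration reader `calib_reader` are placeholders, not values from this commit.

# Illustrative sketch, not part of this commit; paths and reader are placeholders.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir='./fp32_model',
    sample_generator=calib_reader,  # user-provided calibration data reader
    batch_nums=10,
    algo='avg',
    skip_tensor_list=['fc_0.w_0'])  # ops consuming these tensors stay FP32
ptq.quantize()
ptq.save_quantized_model('./int8_model')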
...@@ -354,6 +354,7 @@ set_tests_properties(test_quantization_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_qat_channelwise PROPERTIES TIMEOUT 200)
set_tests_properties(test_user_defined_quantization PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)
...
...@@ -56,13 +56,15 @@ class TestImperativeQat(unittest.TestCase):
        self.onnx_format = False
        self.check_export_model_accuracy = True
        self.diff_threshold = 0.01
        self.fuse_conv_bn = False

    def func_qat(self):
        self.set_vars()

        imperative_qat = ImperativeQuantAware(
            weight_quantize_type=self.weight_quantize_type,
            activation_quantize_type=self.activation_quantize_type,
            fuse_conv_bn=self.fuse_conv_bn)

        with fluid.dygraph.guard():
            # For CI coverage
...@@ -214,6 +216,7 @@ class TestImperativeQatONNXFormat(unittest.TestCase):
        self.activation_quantize_type = 'moving_average_abs_max'
        self.onnx_format = True
        self.diff_threshold = 0.025
        self.fuse_conv_bn = False


if __name__ == '__main__':
...
...@@ -43,6 +43,7 @@ class TestImperativeQatChannelWise(TestImperativeQat):
        self.activation_quantize_type = 'moving_average_abs_max'
        self.diff_threshold = 0.01
        self.onnx_format = False
        self.fuse_conv_bn = False
        print('weight_quantize_type', self.weight_quantize_type)
...@@ -52,6 +53,7 @@ class TestImperativeQatChannelWiseONNXFormat(TestImperativeQat):
        self.activation_quantize_type = 'moving_average_abs_max'
        self.onnx_format = True
        self.diff_threshold = 0.025
        self.fuse_conv_bn = False
        print('weight_quantize_type', self.weight_quantize_type)
...
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from test_imperative_qat import TestImperativeQat
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
    fluid.set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class TestImperativeQatfuseBN(TestImperativeQat):
    def set_vars(self):
        self.weight_quantize_type = 'abs_max'
        self.activation_quantize_type = 'moving_average_abs_max'
        self.diff_threshold = 0.01
        self.onnx_format = False
        self.fuse_conv_bn = True


if __name__ == '__main__':
    unittest.main()
...@@ -114,7 +114,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  is_optimize_model=False,
                                  batch_size=10,
                                  batch_nums=10,
                                  onnx_format=False,
                                  skip_tensor_list=None):
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
...@@ -132,6 +133,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
            is_full_quantize=is_full_quantize,
            optimize_model=is_optimize_model,
            onnx_format=onnx_format,
            skip_tensor_list=skip_tensor_list,
            is_use_cache_file=is_use_cache_file)
        ptq.quantize()
        ptq.save_quantized_model(self.int8_model_path)
...@@ -150,7 +152,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                 batch_size=10,
                 infer_iterations=10,
                 quant_iterations=5,
                 onnx_format=False,
                 skip_tensor_list=None):
        origin_model_path = self.download_model(data_url, data_md5, model_name)
        origin_model_path = os.path.join(origin_model_path, model_name)
...@@ -162,10 +165,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
        print("Start INT8 post training quantization for {0} on {1} images ...".
              format(model_name, quant_iterations * batch_size))
        self.generate_quantized_model(
            origin_model_path, algo, round_type, quantizable_op_type,
            is_full_quantize, is_use_cache_file, is_optimize_model, batch_size,
            quant_iterations, onnx_format, skip_tensor_list)

        print("Start INT8 inference for {0} on {1} images ...".format(
            model_name, infer_iterations * batch_size))
...@@ -422,5 +425,38 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
            onnx_format=onnx_format)
class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
    def test_post_training_avg_skip_op(self):
        model_name = "mnist_model"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "avg"
        round_type = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        skip_tensor_list = ["fc_0.w_0"]
        self.run_test(
            model_name,
            data_url,
            data_md5,
            algo,
            round_type,
            quantizable_op_type,
            is_full_quantize,
            is_use_cache_file,
            is_optimize_model,
            diff_threshold,
            batch_size,
            infer_iterations,
            quant_iterations,
            skip_tensor_list=skip_tensor_list)
if __name__ == '__main__':
    unittest.main()
...@@ -241,7 +241,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  is_full_quantize=False,
                                  is_use_cache_file=False,
                                  is_optimize_model=False,
                                  onnx_format=False,
                                  skip_tensor_list=None):
        try:
            os.system("mkdir " + self.int8_model)
        except Exception as e:
...@@ -264,6 +265,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
            is_full_quantize=is_full_quantize,
            optimize_model=is_optimize_model,
            onnx_format=onnx_format,
            skip_tensor_list=skip_tensor_list,
            is_use_cache_file=is_use_cache_file)
        ptq.quantize()
        ptq.save_quantized_model(self.int8_model)
...@@ -279,7 +281,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                 is_use_cache_file,
                 is_optimize_model,
                 diff_threshold,
                 onnx_format=False,
                 skip_tensor_list=None):
        infer_iterations = self.infer_iterations
        batch_size = self.batch_size
        sample_iterations = self.sample_iterations
...@@ -293,10 +296,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
        print("Start INT8 post training quantization for {0} on {1} images ...".
              format(model, sample_iterations * batch_size))
        self.generate_quantized_model(
            model_cache_folder + "/model", quantizable_op_type, algo,
            round_type, is_full_quantize, is_use_cache_file, is_optimize_model,
            onnx_format, skip_tensor_list)

        print("Start INT8 inference for {0} on {1} images ...".format(
            model, infer_iterations * batch_size))
...@@ -444,5 +447,38 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
            onnx_format=onnx_format)
class TestPostTrainingForMobilenetv1SkipOP(TestPostTrainingQuantization):
    def test_post_training_mobilenetv1_skip(self):
        model = "MobileNet-V1"
        algo = "avg"
        round_type = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
        quantizable_op_type = [
            "conv2d",
            "depthwise_conv2d",
            "mul",
        ]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.025
        skip_tensor_list = ["fc_0.w_0"]
        self.run_test(
            model,
            algo,
            round_type,
            data_urls,
            data_md5s,
            quantizable_op_type,
            is_full_quantize,
            is_use_cache_file,
            is_optimize_model,
            diff_threshold,
            skip_tensor_list=skip_tensor_list)
if __name__ == '__main__':
    unittest.main()