From d6442df69c9bff4ca3d502d514d9a9d7959c1228 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Wed, 4 May 2022 16:03:39 +0800
Subject: [PATCH] support fuse conv and bn in QAT (#42255)

---
 .../quantization/imperative/fuse_utils.py      | 21 ++++++++
 .../slim/quantization/imperative/qat.py        | 10 ++++
 .../fluid/contrib/slim/tests/CMakeLists.txt    |  1 +
 .../contrib/slim/tests/test_imperative_qat.py  |  5 +-
 .../tests/test_imperative_qat_channelwise.py   |  2 +
 .../slim/tests/test_imperative_qat_fuse.py     | 50 +++++++++++++++++++
 6 files changed, 88 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/contrib/slim/tests/test_imperative_qat_fuse.py

diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py
index 14282df23d3..1f7a01f17b0 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py
@@ -28,6 +28,27 @@ class Identity(nn.Layer):
         return input
 
 
+def fuse_conv_bn(model):
+    is_train = False
+    if model.training:
+        model.eval()
+        is_train = True
+    fuse_list = []
+    tmp_pair = [None, None]
+    for name, layer in model.named_sublayers():
+        if isinstance(layer, nn.Conv2D):
+            tmp_pair[0] = name
+        if isinstance(layer, nn.BatchNorm2D):
+            tmp_pair[1] = name
+
+        if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
+            fuse_list.append(tmp_pair)
+            tmp_pair = [None, None]
+    model = fuse_layers(model, fuse_list)
+    if is_train:
+        model.train()
+
+
 def fuse_layers(model, layers_to_fuse, inplace=False):
     '''
     fuse layers in layers_to_fuse
diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
index 059cb7b0dd1..d5c3d9ab82d 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -20,6 +20,7 @@ import os
 import warnings
 
 import paddle
+import paddle.nn as nn
 import paddle.nn.quant.quant_layers as quant_layers
 from paddle.fluid import dygraph, core, framework, unique_name
 from paddle.fluid.framework import IrGraph
@@ -32,6 +33,7 @@ from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass
 from paddle.fluid.log_helper import get_logger
 from .. import quantization_pass
 from . import utils
+from . import fuse_utils
 
 __all__ = ['ImperativeQuantAware']
 
@@ -52,6 +54,7 @@
                  weight_bits=8,
                  activation_bits=8,
                  moving_rate=0.9,
+                 fuse_conv_bn=False,
                  weight_preprocess_layer=None,
                  act_preprocess_layer=None,
                  weight_quantize_layer=None,
@@ -76,6 +79,7 @@
             activation_bits(int): quantization bit number for activations.
             moving_rate(float): the parameter for 'moving_average_abs_max'
                 quantization.
+            fuse_conv_bn(bool): Whether to fuse conv and bn, default is False.
             weight_preprocess_layer(paddle.nn.Layer, optional): A paddle
                 Layer that defines how to preprocess weight before quantization.
                 Using this can quickly test if user's preprocess method works
@@ -188,6 +192,7 @@
                 model_path="./imperative_model_qat")
         """
         super(ImperativeQuantAware, self).__init__()
+        self.fuse_conv_bn = fuse_conv_bn
 
         kwargs = {
             "quantizable_layer_type": quantizable_layer_type,
@@ -256,8 +261,13 @@
         """
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."
+
+        if self.fuse_conv_bn:
+            fuse_utils.fuse_conv_bn(model)
+
         self._quantize_inputs.apply(model)
         self._quantize_outputs.apply(model)
+        return model
 
     def save_quantized_model(self, layer, path, input_spec=None, **config):
         self._quantize_outputs.save_quantized_model(layer, path, input_spec,
diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index 30e2b4613b1..0140283b915 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -354,6 +354,7 @@
 set_tests_properties(test_quantization_pass PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_qat_channelwise PROPERTIES TIMEOUT 200)
 set_tests_properties(test_user_defined_quantization PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200)
+set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)
diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
index 015ecb3d4a4..0d035390e2c 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
@@ -56,13 +56,15 @@ class TestImperativeQat(unittest.TestCase):
         self.onnx_format = False
         self.check_export_model_accuracy = True
         self.diff_threshold = 0.01
+        self.fuse_conv_bn = False
 
     def func_qat(self):
         self.set_vars()
 
         imperative_qat = ImperativeQuantAware(
             weight_quantize_type=self.weight_quantize_type,
-            activation_quantize_type=self.activation_quantize_type)
+            activation_quantize_type=self.activation_quantize_type,
+            fuse_conv_bn=self.fuse_conv_bn)
 
         with fluid.dygraph.guard():
             # For CI coverage
@@ -214,6 +216,7 @@ class TestImperativeQatONNXFormat(unittest.TestCase):
         self.activation_quantize_type = 'moving_average_abs_max'
         self.onnx_format = True
         self.diff_threshold = 0.025
+        self.fuse_conv_bn = False
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py
index ff40b170345..94e0681d1f5 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py
@@ -43,6 +43,7 @@ class TestImperativeQatChannelWise(TestImperativeQat):
         self.activation_quantize_type = 'moving_average_abs_max'
         self.diff_threshold = 0.01
         self.onnx_format = False
+        self.fuse_conv_bn = False
         print('weight_quantize_type', self.weight_quantize_type)
 
 
@@ -52,6 +53,7 @@ class TestImperativeQatChannelWiseONNXFormat(TestImperativeQat):
         self.activation_quantize_type = 'moving_average_abs_max'
         self.onnx_format = True
         self.diff_threshold = 0.025
+        self.fuse_conv_bn = False
         print('weight_quantize_type', self.weight_quantize_type)
 
 
diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_fuse.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_fuse.py
new file mode 100644
index 00000000000..d580eb7ae7a
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_fuse.py
@@ -0,0 +1,50 @@
+# copyright (c) 2018 paddlepaddle authors. all rights reserved.
+#
+# licensed under the apache license, version 2.0 (the "license");
+# you may not use this file except in compliance with the license.
+# you may obtain a copy of the license at
+#
+#     http://www.apache.org/licenses/license-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the license is distributed on an "as is" basis,
+# without warranties or conditions of any kind, either express or implied.
+# see the license for the specific language governing permissions and
+# limitations under the license.
+
+from __future__ import print_function
+
+import os
+import numpy as np
+import random
+import unittest
+import logging
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import core
+from paddle.fluid.log_helper import get_logger
+
+from test_imperative_qat import TestImperativeQat
+
+paddle.enable_static()
+
+os.environ["CPU_NUM"] = "1"
+if core.is_compiled_with_cuda():
+    fluid.set_flags({"FLAGS_cudnn_deterministic": True})
+
+_logger = get_logger(
+    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class TestImperativeQatfuseBN(TestImperativeQat):
+    def set_vars(self):
+        self.weight_quantize_type = 'abs_max'
+        self.activation_quantize_type = 'moving_average_abs_max'
+        self.diff_threshold = 0.01
+        self.onnx_format = False
+        self.fuse_conv_bn = True
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
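
A note on what the new fuse_conv_bn pass computes. In eval mode, BatchNorm2D
applies y = gamma * (x - mean) / sqrt(var + eps) + beta per output channel, so
it can be folded into the preceding Conv2D by rescaling the filters and
shifting the bias; the folding helper that fuse_layers dispatches to lives in
fuse_utils and is not shown in the hunks above. Below is a minimal standalone
numpy sketch of the standard fold; the function name and signature are
illustrative, not Paddle's API.

    import numpy as np

    def fold_bn_into_conv(conv_w, conv_b, gamma, beta, mean, var, eps=1e-5):
        # Standard eval-mode conv+bn fold (illustrative helper, not Paddle's
        # own). conv_w: (out_c, in_c, kh, kw) filters; conv_b, gamma, beta,
        # mean, var: (out_c,) per-channel BN statistics and affine parameters.
        scale = gamma / np.sqrt(var + eps)             # per-out-channel factor
        fused_w = conv_w * scale.reshape(-1, 1, 1, 1)  # rescale each filter
        fused_b = (conv_b - mean) * scale + beta       # fold shift into bias
        return fused_w, fused_b

Fusing before the fake-quant wrappers are inserted means the weight quantizer
observes the fused, inference-time weights, matching what a deployment engine
that fuses conv+bn will actually execute.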
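
For trying the patch end to end, a minimal usage sketch. Assumptions: a Paddle
2.x dygraph model (resnet18 here is arbitrary; per the pairing loop in
fuse_conv_bn, any Conv2D immediately followed by a BatchNorm2D in
named_sublayers() order gets fused), and the same import path the tests above
use.

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    model = paddle.vision.models.resnet18()

    quanter = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max',
        fuse_conv_bn=True)  # flag added by this patch; defaults to False
    model = quanter.quantize(model)  # quantize() now also returns the model

    # ... run the QAT training loop, then export the quantized model with
    # quanter.save_quantized_model(model, path=..., input_spec=...) ...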