未验证 提交 d6442df6 编写于 作者: G Guanghua Yu 提交者: GitHub

support fuse conv and bn in QAT (#42255)

上级 b621a4f1
...@@ -28,6 +28,27 @@ class Identity(nn.Layer): ...@@ -28,6 +28,27 @@ class Identity(nn.Layer):
return input return input
def fuse_conv_bn(model):
is_train = False
if model.training:
model.eval()
is_train = True
fuse_list = []
tmp_pair = [None, None]
for name, layer in model.named_sublayers():
if isinstance(layer, nn.Conv2D):
tmp_pair[0] = name
if isinstance(layer, nn.BatchNorm2D):
tmp_pair[1] = name
if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
fuse_list.append(tmp_pair)
tmp_pair = [None, None]
model = fuse_layers(model, fuse_list)
if is_train:
model.train()
def fuse_layers(model, layers_to_fuse, inplace=False): def fuse_layers(model, layers_to_fuse, inplace=False):
''' '''
fuse layers in layers_to_fuse fuse layers in layers_to_fuse
......
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import warnings import warnings
import paddle import paddle
import paddle.nn as nn
import paddle.nn.quant.quant_layers as quant_layers import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid import dygraph, core, framework, unique_name from paddle.fluid import dygraph, core, framework, unique_name
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
...@@ -32,6 +33,7 @@ from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass ...@@ -32,6 +33,7 @@ from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
from .. import quantization_pass from .. import quantization_pass
from . import utils from . import utils
from . import fuse_utils
__all__ = ['ImperativeQuantAware'] __all__ = ['ImperativeQuantAware']
...@@ -52,6 +54,7 @@ class ImperativeQuantAware(object): ...@@ -52,6 +54,7 @@ class ImperativeQuantAware(object):
weight_bits=8, weight_bits=8,
activation_bits=8, activation_bits=8,
moving_rate=0.9, moving_rate=0.9,
fuse_conv_bn=False,
weight_preprocess_layer=None, weight_preprocess_layer=None,
act_preprocess_layer=None, act_preprocess_layer=None,
weight_quantize_layer=None, weight_quantize_layer=None,
...@@ -76,6 +79,7 @@ class ImperativeQuantAware(object): ...@@ -76,6 +79,7 @@ class ImperativeQuantAware(object):
activation_bits(int): quantization bit number for activations. activation_bits(int): quantization bit number for activations.
moving_rate(float): the parameter for 'moving_average_abs_max' moving_rate(float): the parameter for 'moving_average_abs_max'
quantization. quantization.
fuse_conv_bn(bool): Whether to fuse conv and bn, default is False.
weight_preprocess_layer(paddle.nn.Layer, optional): A paddle weight_preprocess_layer(paddle.nn.Layer, optional): A paddle
Layer that defines how to preprocess weight before quantization. Layer that defines how to preprocess weight before quantization.
Using this can quickly test if user's preprocess method works Using this can quickly test if user's preprocess method works
...@@ -188,6 +192,7 @@ class ImperativeQuantAware(object): ...@@ -188,6 +192,7 @@ class ImperativeQuantAware(object):
model_path="./imperative_model_qat") model_path="./imperative_model_qat")
""" """
super(ImperativeQuantAware, self).__init__() super(ImperativeQuantAware, self).__init__()
self.fuse_conv_bn = fuse_conv_bn
kwargs = { kwargs = {
"quantizable_layer_type": quantizable_layer_type, "quantizable_layer_type": quantizable_layer_type,
...@@ -256,8 +261,13 @@ class ImperativeQuantAware(object): ...@@ -256,8 +261,13 @@ class ImperativeQuantAware(object):
""" """
assert isinstance(model, dygraph.Layer), \ assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer." "The model must be the instance of dygraph.Layer."
if self.fuse_conv_bn:
fuse_utils.fuse_conv_bn(model)
self._quantize_inputs.apply(model) self._quantize_inputs.apply(model)
self._quantize_outputs.apply(model) self._quantize_outputs.apply(model)
return model
def save_quantized_model(self, layer, path, input_spec=None, **config): def save_quantized_model(self, layer, path, input_spec=None, **config):
self._quantize_outputs.save_quantized_model(layer, path, input_spec, self._quantize_outputs.save_quantized_model(layer, path, input_spec,
......
...@@ -354,6 +354,7 @@ set_tests_properties(test_quantization_pass PROPERTIES TIMEOUT 120) ...@@ -354,6 +354,7 @@ set_tests_properties(test_quantization_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_qat_channelwise PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat_channelwise PROPERTIES TIMEOUT 200)
set_tests_properties(test_user_defined_quantization PROPERTIES TIMEOUT 200) set_tests_properties(test_user_defined_quantization PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)
......
...@@ -56,13 +56,15 @@ class TestImperativeQat(unittest.TestCase): ...@@ -56,13 +56,15 @@ class TestImperativeQat(unittest.TestCase):
self.onnx_format = False self.onnx_format = False
self.check_export_model_accuracy = True self.check_export_model_accuracy = True
self.diff_threshold = 0.01 self.diff_threshold = 0.01
self.fuse_conv_bn = False
def func_qat(self): def func_qat(self):
self.set_vars() self.set_vars()
imperative_qat = ImperativeQuantAware( imperative_qat = ImperativeQuantAware(
weight_quantize_type=self.weight_quantize_type, weight_quantize_type=self.weight_quantize_type,
activation_quantize_type=self.activation_quantize_type) activation_quantize_type=self.activation_quantize_type,
fuse_conv_bn=self.fuse_conv_bn)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
# For CI coverage # For CI coverage
...@@ -214,6 +216,7 @@ class TestImperativeQatONNXFormat(unittest.TestCase): ...@@ -214,6 +216,7 @@ class TestImperativeQatONNXFormat(unittest.TestCase):
self.activation_quantize_type = 'moving_average_abs_max' self.activation_quantize_type = 'moving_average_abs_max'
self.onnx_format = True self.onnx_format = True
self.diff_threshold = 0.025 self.diff_threshold = 0.025
self.fuse_conv_bn = False
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -43,6 +43,7 @@ class TestImperativeQatChannelWise(TestImperativeQat): ...@@ -43,6 +43,7 @@ class TestImperativeQatChannelWise(TestImperativeQat):
self.activation_quantize_type = 'moving_average_abs_max' self.activation_quantize_type = 'moving_average_abs_max'
self.diff_threshold = 0.01 self.diff_threshold = 0.01
self.onnx_format = False self.onnx_format = False
self.fuse_conv_bn = False
print('weight_quantize_type', self.weight_quantize_type) print('weight_quantize_type', self.weight_quantize_type)
...@@ -52,6 +53,7 @@ class TestImperativeQatChannelWiseONNXFormat(TestImperativeQat): ...@@ -52,6 +53,7 @@ class TestImperativeQatChannelWiseONNXFormat(TestImperativeQat):
self.activation_quantize_type = 'moving_average_abs_max' self.activation_quantize_type = 'moving_average_abs_max'
self.onnx_format = True self.onnx_format = True
self.diff_threshold = 0.025 self.diff_threshold = 0.025
self.fuse_conv_bn = False
print('weight_quantize_type', self.weight_quantize_type) print('weight_quantize_type', self.weight_quantize_type)
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from test_imperative_qat import TestImperativeQat
paddle.enable_static()
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class TestImperativeQatfuseBN(TestImperativeQat):
def set_vars(self):
self.weight_quantize_type = 'abs_max'
self.activation_quantize_type = 'moving_average_abs_max'
self.diff_threshold = 0.01
self.onnx_format = False
self.fuse_conv_bn = True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册