Unverified · Commit d6442df6 authored by Guanghua Yu, committed by GitHub

support fuse conv and bn in QAT (#42255)

Parent b621a4f1
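Background for this change: quantization-aware training simulates the inference-time kernels, and at inference a Conv2D followed by BatchNorm2D is typically deployed as a single fused convolution. Folding the BN statistics into the conv before the fake-quant ops are inserted makes the quantized weights match what will actually run. With BN scale $\gamma$, shift $\beta$, running mean $\mu$, running variance $\sigma^2$, and epsilon $\varepsilon$, the standard per-output-channel folding is

$$W' = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\,W, \qquad b' = \beta + \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\,(b - \mu).$$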
@@ -28,6 +28,27 @@ class Identity(nn.Layer):
         return input
 
 
+def fuse_conv_bn(model):
+    is_train = False
+    if model.training:
+        model.eval()
+        is_train = True
+    fuse_list = []
+    tmp_pair = [None, None]
+    for name, layer in model.named_sublayers():
+        if isinstance(layer, nn.Conv2D):
+            tmp_pair[0] = name
+        if isinstance(layer, nn.BatchNorm2D):
+            tmp_pair[1] = name
+
+        if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
+            fuse_list.append(tmp_pair)
+            tmp_pair = [None, None]
+    model = fuse_layers(model, fuse_list)
+    if is_train:
+        model.train()
+
+
 def fuse_layers(model, layers_to_fuse, inplace=False):
     '''
     fuse layers in layers_to_fuse
...
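The body of `fuse_layers` is collapsed in this view. For reference, a minimal sketch of the per-channel folding that conv+bn fusion performs; `fold_bn_params` and its signature are illustrative, not Paddle's internal API:

```python
import paddle


def fold_bn_params(conv_w, conv_b, gamma, beta, mean, var, eps):
    """Fold BatchNorm statistics into a conv's weight/bias (per output channel)."""
    # BN computes y = gamma * (x - mean) / sqrt(var + eps) + beta, so the
    # affine transform can be absorbed into each output filter of the conv.
    scale = gamma / paddle.sqrt(var + eps)
    folded_w = conv_w * scale.reshape([-1, 1, 1, 1])  # scale each output filter
    if conv_b is None:
        conv_b = paddle.zeros_like(mean)
    folded_b = (conv_b - mean) * scale + beta
    return folded_w, folded_b


# Assuming Paddle's BatchNorm2D internals (_mean, _variance, _epsilon):
# w, b = fold_bn_params(conv.weight, conv.bias, bn.weight, bn.bias,
#                       bn._mean, bn._variance, bn._epsilon)
```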
@@ -20,6 +20,7 @@ import os
 import warnings
 import paddle
 import paddle.nn as nn
+import paddle.nn.quant.quant_layers as quant_layers
 from paddle.fluid import dygraph, core, framework, unique_name
 from paddle.fluid.framework import IrGraph
@@ -32,6 +33,7 @@ from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass
 from paddle.fluid.log_helper import get_logger
 from .. import quantization_pass
 from . import utils
+from . import fuse_utils
 
 __all__ = ['ImperativeQuantAware']
@@ -52,6 +54,7 @@ class ImperativeQuantAware(object):
                  weight_bits=8,
                  activation_bits=8,
                  moving_rate=0.9,
+                 fuse_conv_bn=False,
                  weight_preprocess_layer=None,
                  act_preprocess_layer=None,
                  weight_quantize_layer=None,
@@ -76,6 +79,7 @@ class ImperativeQuantAware(object):
             activation_bits(int): quantization bit number for activations.
             moving_rate(float): the parameter for 'moving_average_abs_max'
                 quantization.
+            fuse_conv_bn(bool): Whether to fuse conv and bn, default is False.
             weight_preprocess_layer(paddle.nn.Layer, optional): A paddle
                 Layer that defines how to preprocess weight before quantization.
                 Using this can quickly test if user's preprocess method works
@@ -188,6 +192,7 @@ class ImperativeQuantAware(object):
                 model_path="./imperative_model_qat")
         """
         super(ImperativeQuantAware, self).__init__()
+        self.fuse_conv_bn = fuse_conv_bn
 
         kwargs = {
             "quantizable_layer_type": quantizable_layer_type,
@@ -256,8 +261,13 @@ class ImperativeQuantAware(object):
         """
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."
+
+        if self.fuse_conv_bn:
+            fuse_utils.fuse_conv_bn(model)
+
         self._quantize_inputs.apply(model)
         self._quantize_outputs.apply(model)
+        return model
 
     def save_quantized_model(self, layer, path, input_spec=None, **config):
         self._quantize_outputs.save_quantized_model(layer, path, input_spec,
...
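A minimal usage sketch of the new flag, following the pattern of the class docstring's example (the model and paths here are illustrative):

```python
import paddle
from paddle.vision.models import resnet18  # any model with Conv2D + BatchNorm2D pairs
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

model = resnet18()

# fuse_conv_bn=True folds each BatchNorm2D into its preceding Conv2D
# before the fake-quant wrappers are applied.
imperative_qat = ImperativeQuantAware(
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max',
    fuse_conv_bn=True)

model = imperative_qat.quantize(model)  # quantize() now also returns the model
# ... train as usual, then export with
# imperative_qat.save_quantized_model(model, path="./imperative_model_qat",
#     input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')])
```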
@@ -354,6 +354,7 @@ set_tests_properties(test_quantization_pass PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_qat_channelwise PROPERTIES TIMEOUT 200)
 set_tests_properties(test_user_defined_quantization PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200)
+set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)
...
@@ -56,13 +56,15 @@ class TestImperativeQat(unittest.TestCase):
         self.onnx_format = False
         self.check_export_model_accuracy = True
         self.diff_threshold = 0.01
+        self.fuse_conv_bn = False
 
     def func_qat(self):
         self.set_vars()
 
         imperative_qat = ImperativeQuantAware(
             weight_quantize_type=self.weight_quantize_type,
-            activation_quantize_type=self.activation_quantize_type)
+            activation_quantize_type=self.activation_quantize_type,
+            fuse_conv_bn=self.fuse_conv_bn)
 
         with fluid.dygraph.guard():
             # For CI coverage
@@ -214,6 +216,7 @@ class TestImperativeQatONNXFormat(unittest.TestCase):
         self.activation_quantize_type = 'moving_average_abs_max'
         self.onnx_format = True
         self.diff_threshold = 0.025
+        self.fuse_conv_bn = False
 
 
 if __name__ == '__main__':
...
@@ -43,6 +43,7 @@ class TestImperativeQatChannelWise(TestImperativeQat):
         self.activation_quantize_type = 'moving_average_abs_max'
         self.diff_threshold = 0.01
         self.onnx_format = False
+        self.fuse_conv_bn = False
         print('weight_quantize_type', self.weight_quantize_type)
@@ -52,6 +53,7 @@ class TestImperativeQatChannelWiseONNXFormat(TestImperativeQat):
         self.activation_quantize_type = 'moving_average_abs_max'
         self.onnx_format = True
         self.diff_threshold = 0.025
+        self.fuse_conv_bn = False
         print('weight_quantize_type', self.weight_quantize_type)
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import os
import numpy as np
import random
import unittest
import logging

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
from test_imperative_qat import TestImperativeQat

paddle.enable_static()

os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
    fluid.set_flags({"FLAGS_cudnn_deterministic": True})

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class TestImperativeQatfuseBN(TestImperativeQat):
    def set_vars(self):
        self.weight_quantize_type = 'abs_max'
        self.activation_quantize_type = 'moving_average_abs_max'
        self.diff_threshold = 0.01
        self.onnx_format = False
        self.fuse_conv_bn = True


if __name__ == '__main__':
    unittest.main()