From ef53625063697be5a398df44669c187575045534 Mon Sep 17 00:00:00 2001 From: XGZhang <46363693+XGZhang11@users.noreply.github.com> Date: Tue, 31 Aug 2021 13:14:37 +0800 Subject: [PATCH] support fuse layers for ptq (#35015) --- .../quantization/imperative/fuse_utils.py | 175 ++++++++++++++++++ .../slim/quantization/imperative/ptq.py | 17 +- .../slim/tests/imperative_test_utils.py | 60 ++++++ .../contrib/slim/tests/test_imperative_ptq.py | 94 +++++++++- 4 files changed, 340 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py new file mode 100644 index 00000000000..14282df23d3 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/fuse_utils.py @@ -0,0 +1,175 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import paddle +import paddle.nn as nn +from . import utils + + +class Identity(nn.Layer): + '''a layer to replace bn or relu layers''' + + def __init__(self, *args, **kwargs): + super(Identity, self).__init__() + + def forward(self, input): + return input + + +def fuse_layers(model, layers_to_fuse, inplace=False): + ''' + fuse layers in layers_to_fuse + + Args: + model(paddle.nn.Layer): The model to be fused. + layers_to_fuse(list): The layers' names to be fused. For + example,"fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]". + A TypeError would be raised if "fuse" was set as + True but "fuse_list" was None. + Default: None. + inplace(bool): Whether apply fusing to the input model. + Default: False. + + Return + fused_model(paddle.nn.Layer): The fused model. + ''' + if inplace == False: + model = copy.deepcopy(model) + for layers in layers_to_fuse: + _fuse_layers(model, layers) + return model + + +def _fuse_layers(model, layers_list): + '''fuse all the layers in layers_list''' + layer_list = [] + for layer_name in layers_list: + parent_layer, sub_name = utils.find_parent_layer_and_sub_name( + model, layer_name) + layer_list.append(getattr(parent_layer, sub_name)) + new_layers = _fuse_func(layer_list) + for i, item in enumerate(layers_list): + parent_layer, sub_name = utils.find_parent_layer_and_sub_name(model, + item) + setattr(parent_layer, sub_name, new_layers[i]) + + +def _fuse_func(layer_list): + '''choose the fuser method and fuse layers''' + types = tuple(type(m) for m in layer_list) + fusion_method = types_to_fusion_method.get(types, None) + new_layers = [None] * len(layer_list) + fused_layer = fusion_method(*layer_list) + for handle_id, pre_hook_fn in layer_list[0]._forward_pre_hooks.items(): + fused_layer.register_forward_pre_hook(pre_hook_fn) + del layer_list[0]._forward_pre_hooks[handle_id] + for handle_id, hook_fn in layer_list[-1]._forward_post_hooks.items(): + fused_layer.register_forward_post_hook(hook_fn) + del layer_list[-1]._forward_post_hooks[handle_id] + new_layers[0] = fused_layer + for i in range(1, len(layer_list)): + identity = Identity() + identity.training = layer_list[0].training + new_layers[i] = identity + return new_layers + + +def _fuse_conv_bn(conv, bn): + '''fuse conv and bn for train or eval''' + assert(conv.training == bn.training),\ + "Conv and BN both must be in the same mode (train or eval)." + if conv.training: + assert bn._num_features == conv._out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' + raise NotImplementedError + else: + return _fuse_conv_bn_eval(conv, bn) + + +def _fuse_conv_bn_eval(conv, bn): + '''fuse conv and bn for eval''' + assert (not (conv.training or bn.training)), "Fusion only for eval!" + fused_conv = copy.deepcopy(conv) + + fused_weight, fused_bias = _fuse_conv_bn_weights( + fused_conv.weight, fused_conv.bias, bn._mean, bn._variance, bn._epsilon, + bn.weight, bn.bias) + fused_conv.weight.set_value(fused_weight) + if fused_conv.bias is None: + fused_conv.bias = paddle.create_parameter( + shape=[fused_conv._out_channels], is_bias=True, dtype=bn.bias.dtype) + fused_conv.bias.set_value(fused_bias) + return fused_conv + + +def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): + '''fuse weights and bias of conv and bn''' + if conv_b is None: + conv_b = paddle.zeros_like(bn_rm) + if bn_w is None: + bn_w = paddle.ones_like(bn_rm) + if bn_b is None: + bn_b = paddle.zeros_like(bn_rm) + bn_var_rsqrt = paddle.rsqrt(bn_rv + bn_eps) + conv_w = conv_w * \ + (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) + conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b + return conv_w, conv_b + + +def _fuse_linear_bn(linear, bn): + '''fuse linear and bn''' + assert (linear.training == bn.training),\ + "Linear and BN both must be in the same mode (train or eval)." + if linear.training: + assert bn._num_features == linear.weight.shape[ + 1], 'Output channel of Linear must match num_features of BatchNorm' + raise NotImplementedError + else: + return _fuse_linear_bn_eval(linear, bn) + + +def _fuse_linear_bn_eval(linear, bn): + '''fuse linear and bn for eval''' + assert (not (linear.training or bn.training)), "Fusion only for eval!" + fused_linear = copy.deepcopy(linear) + + fused_weight, fused_bias = _fuse_linear_bn_weights( + fused_linear.weight, fused_linear.bias, bn._mean, bn._variance, + bn._epsilon, bn.weight, bn.bias) + fused_linear.weight.set_value(fused_weight) + if fused_linear.bias is None: + fused_linear.bias = paddle.create_parameter( + shape=[fused_linear.weight.shape[1]], + is_bias=True, + dtype=bn.bias.dtype) + fused_linear.bias.set_value(fused_bias) + return fused_linear + + +def _fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, + bn_b): + '''fuse weights and bias of linear and bn''' + if linear_b is None: + linear_b = paddle.zeros_like(bn_rm) + bn_scale = bn_w * paddle.rsqrt(bn_rv + bn_eps) + fused_w = linear_w * bn_scale.unsqueeze(-1) + fused_b = (linear_b - bn_rm) * bn_scale + bn_b + return fused_w, fused_b + + +types_to_fusion_method = { + (nn.Conv2D, nn.BatchNorm2D): _fuse_conv_bn, + (nn.Linear, nn.BatchNorm1D): _fuse_linear_bn, +} diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py index 3a536ab1d20..64d9cd32101 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/ptq.py @@ -22,6 +22,7 @@ import paddle.nn.quant.quant_layers as quant_layers from paddle.fluid.log_helper import get_logger from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX +from . import fuse_utils from . import utils from . import ptq_hooks from . import ptq_config @@ -55,7 +56,7 @@ class ImperativePTQ(object): self._quant_config = quant_config - def quantize(self, model, inplace=False): + def quantize(self, model, inplace=False, fuse=False, fuse_list=None): """ Add quant config and hook to the target layer. @@ -63,15 +64,23 @@ class ImperativePTQ(object): model(paddle.nn.Layer): The model to be quantized. inplace(bool): Whether apply quantization to the input model. Default: False. - Returns: + fuse(bool): Whether to fuse layers. + Default: False. + fuse_list(list): The layers' names to be fused. For example, + "fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]". + A TypeError would be raised if "fuse" was set as + True but "fuse_list" was None. + Default: None. + Return quantized_model(paddle.nn.Layer): The quantized model. """ assert isinstance(model, paddle.nn.Layer), \ "The model must be the instance of paddle.nn.Layer." - if not inplace: model = copy.deepcopy(model) - + if fuse: + model.eval() + model = fuse_utils.fuse_layers(model, fuse_list) for name, layer in model.named_sublayers(): if PTQRegistry.is_supported_layer(layer) \ and utils.is_leaf_layer(layer) \ diff --git a/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py b/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py index 5c91f01d0bd..466cc14eae0 100644 --- a/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py +++ b/python/paddle/fluid/contrib/slim/tests/imperative_test_utils.py @@ -20,6 +20,7 @@ from paddle.fluid import core from paddle.fluid.dygraph.container import Sequential from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D +from paddle.nn import BatchNorm1D from paddle.fluid.log_helper import get_logger @@ -43,6 +44,15 @@ def fix_model_dict(model): return model +def pre_hook(layer, input): + input_return = (input[0] * 2) + return input_return + + +def post_hook(layer, input, output): + return output * 2 + + def train_lenet(lenet, reader, optimizer): loss_list = [] lenet.train() @@ -224,3 +234,53 @@ class ImperativeLenetWithSkipQuant(fluid.dygraph.Layer): x = self.softmax_0(x) return x + + +class ImperativeLinearBn(fluid.dygraph.Layer): + def __init__(self): + super(ImperativeLinearBn, self).__init__() + + fc_w_attr = paddle.ParamAttr( + name="fc_weight", + initializer=paddle.nn.initializer.Constant(value=0.5)) + fc_b_attr = paddle.ParamAttr( + name="fc_bias", + initializer=paddle.nn.initializer.Constant(value=1.0)) + bn_w_attr = paddle.ParamAttr( + name="bn_weight", + initializer=paddle.nn.initializer.Constant(value=0.5)) + + self.linear = Linear( + in_features=10, + out_features=10, + weight_attr=fc_w_attr, + bias_attr=fc_b_attr) + self.bn = BatchNorm1D(10, weight_attr=bn_w_attr) + + def forward(self, inputs): + x = self.linear(inputs) + x = self.bn(x) + + return x + + +class ImperativeLinearBn_hook(fluid.dygraph.Layer): + def __init__(self): + super(ImperativeLinearBn_hook, self).__init__() + + fc_w_attr = paddle.ParamAttr( + name="linear_weight", + initializer=paddle.nn.initializer.Constant(value=0.5)) + + self.linear = Linear( + in_features=10, out_features=10, weight_attr=fc_w_attr) + self.bn = BatchNorm1D(10) + + forward_pre = self.linear.register_forward_pre_hook(pre_hook) + forward_post = self.bn.register_forward_post_hook(post_hook) + + def forward(self, inputs): + x = self.linear(inputs) + x = self.bn(x) + + return x diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py index 575a91642a7..fb92b12cb0d 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py @@ -23,18 +23,48 @@ import unittest import copy import logging +import paddle.nn as nn import paddle import paddle.fluid as fluid from paddle.fluid.contrib.slim.quantization import * from paddle.fluid.log_helper import get_logger from paddle.dataset.common import download -from imperative_test_utils import fix_model_dict, ImperativeLenet +from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn +from imperative_test_utils import ImperativeLinearBn_hook _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +class TestFuseLinearBn(unittest.TestCase): + """ + Fuse the linear and bn layers, and then quantize the model. + """ + + def test_fuse(self): + model = ImperativeLinearBn() + model_h = ImperativeLinearBn_hook() + inputs = paddle.randn((3, 10), dtype="float32") + config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()) + ptq = ImperativePTQ(config) + f_l = [['linear', 'bn']] + quant_model = ptq.quantize(model, fuse=True, fuse_list=f_l) + quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l) + for name, layer in quant_model.named_sublayers(): + if name in f_l: + assert not (isinstance(layer, nn.BatchNorm1D) or + isinstance(layer, nn.BatchNorm2D)) + out = model(inputs) + out_h = model_h(inputs) + out_quant = quant_model(inputs) + out_quant_h = quant_h(inputs) + cos_sim_func = nn.CosineSimilarity(axis=0) + print('fuse linear+bn', + cos_sim_func(out.flatten(), out_quant.flatten())) + print(cos_sim_func(out_h.flatten(), out_quant_h.flatten())) + + class TestImperativePTQ(unittest.TestCase): """ """ @@ -177,7 +207,6 @@ class TestImperativePTQ(unittest.TestCase): model = ImperativeLenet() model_state_dict = paddle.load(params_path) model.set_state_dict(model_state_dict) - # Quantize, calibrate and save quant_model = self.ptq.quantize(model) before_acc_top1 = self.model_test(quant_model, self.batch_num, @@ -216,6 +245,67 @@ class TestImperativePTQ(unittest.TestCase): print("total time: %ss \n" % (end_time - start_time)) +class TestImperativePTQfuse(TestImperativePTQ): + def test_ptq(self): + start_time = time.time() + + self.set_vars() + + # Load model + params_path = self.download_model(self.lenet_url, self.lenet_md5, + "lenet") + params_path += "/lenet_pretrained/lenet.pdparams" + + model = ImperativeLenet() + model_state_dict = paddle.load(params_path) + model.set_state_dict(model_state_dict) + # Quantize, calibrate and save + f_l = [['features.0', 'features.1'], ['features.4', 'features.5']] + quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l) + for name, layer in quant_model.named_sublayers(): + if name in f_l: + assert not (isinstance(layer, nn.BatchNorm1D) or + isinstance(layer, nn.BatchNorm2D)) + before_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) + + input_spec = [ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ] + self.ptq.save_quantized_model( + model=quant_model, path=self.save_path, input_spec=input_spec) + print('Quantized model saved in {%s}' % self.save_path) + + after_acc_top1 = self.model_test(quant_model, self.batch_num, + self.batch_size) + + paddle.enable_static() + infer_acc_top1 = self.program_test(self.save_path, self.batch_num, + self.batch_size) + paddle.disable_static() + + # Check + print('Before converted acc_top1: %s' % before_acc_top1) + print('After converted acc_top1: %s' % after_acc_top1) + print('Infer acc_top1: %s' % infer_acc_top1) + + #Check whether the quant_model is correct after converting. + #The acc of quantized model should be higher than 0.95. + self.assertTrue( + after_acc_top1 >= self.eval_acc_top1, + msg="The test acc {%f} is less than {%f}." % + (after_acc_top1, self.eval_acc_top1)) + #Check the saved infer_model.The acc of infer model + #should not be lower than the one of dygraph model. + self.assertTrue( + infer_acc_top1 >= after_acc_top1, + msg='The acc is lower after converting model.') + + end_time = time.time() + print("total time: %ss \n" % (end_time - start_time)) + + class TestImperativePTQHist(TestImperativePTQ): def set_vars(self): config = PTQConfig(HistQuantizer(), AbsmaxQuantizer()) -- GitLab