Unverified commit ef536250 authored by XGZhang, committed by GitHub

support fuse layers for ptq (#35015)

Parent 561841d2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

import paddle
import paddle.nn as nn

from . import utils


class Identity(nn.Layer):
    '''A layer to replace bn or relu layers.'''

    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


def fuse_layers(model, layers_to_fuse, inplace=False):
    '''
    Fuse the layers named in layers_to_fuse.

    Args:
        model(paddle.nn.Layer): The model to be fused.
        layers_to_fuse(list): Names of the layers to be fused, grouped
            per fusion. For example,
            "layers_to_fuse = [["conv1", "bn1"], ["conv2", "bn2"]]".
        inplace(bool): Whether to apply fusing to the input model.
            Default: False.

    Returns:
        fused_model(paddle.nn.Layer): The fused model.
    '''
    if not inplace:
        model = copy.deepcopy(model)
    for layers in layers_to_fuse:
        _fuse_layers(model, layers)
    return model


def _fuse_layers(model, layers_list):
    '''Fuse all the layers in layers_list into one fused layer.'''
    layer_list = []
    for layer_name in layers_list:
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
            model, layer_name)
        layer_list.append(getattr(parent_layer, sub_name))
    new_layers = _fuse_func(layer_list)
    for i, item in enumerate(layers_list):
        parent_layer, sub_name = utils.find_parent_layer_and_sub_name(
            model, item)
        setattr(parent_layer, sub_name, new_layers[i])


def _fuse_func(layer_list):
    '''Choose the fuser method and fuse the layers.'''
    types = tuple(type(m) for m in layer_list)
    fusion_method = types_to_fusion_method.get(types, None)
    assert fusion_method is not None, \
        'Fusion for the layer types {} is not supported.'.format(types)
    new_layers = [None] * len(layer_list)
    fused_layer = fusion_method(*layer_list)
    # Migrate the hooks of the original layers to the fused layer. Iterate
    # over a snapshot of each hook dict, since entries are deleted inside
    # the loop.
    for handle_id, pre_hook_fn in list(
            layer_list[0]._forward_pre_hooks.items()):
        fused_layer.register_forward_pre_hook(pre_hook_fn)
        del layer_list[0]._forward_pre_hooks[handle_id]
    for handle_id, hook_fn in list(
            layer_list[-1]._forward_post_hooks.items()):
        fused_layer.register_forward_post_hook(hook_fn)
        del layer_list[-1]._forward_post_hooks[handle_id]
    # The first layer is replaced by the fused layer; the remaining layers
    # become no-op Identity layers.
    new_layers[0] = fused_layer
    for i in range(1, len(layer_list)):
        identity = Identity()
        identity.training = layer_list[0].training
        new_layers[i] = identity
    return new_layers


def _fuse_conv_bn(conv, bn):
    '''Fuse conv and bn for train or eval.'''
    assert conv.training == bn.training, \
        "Conv and BN both must be in the same mode (train or eval)."
    if conv.training:
        assert bn._num_features == conv._out_channels, \
            'Output channel of Conv2d must match num_features of BatchNorm2d'
        raise NotImplementedError
    else:
        return _fuse_conv_bn_eval(conv, bn)


def _fuse_conv_bn_eval(conv, bn):
    '''Fuse conv and bn for eval.'''
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_weight, fused_bias = _fuse_conv_bn_weights(
        fused_conv.weight, fused_conv.bias, bn._mean, bn._variance,
        bn._epsilon, bn.weight, bn.bias)
    fused_conv.weight.set_value(fused_weight)
    if fused_conv.bias is None:
        fused_conv.bias = paddle.create_parameter(
            shape=[fused_conv._out_channels], is_bias=True,
            dtype=bn.bias.dtype)
    fused_conv.bias.set_value(fused_bias)
    return fused_conv


def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    '''Fuse the weights and bias of conv and bn.'''
    if conv_b is None:
        conv_b = paddle.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = paddle.ones_like(bn_rm)
    if bn_b is None:
        bn_b = paddle.zeros_like(bn_rm)
    bn_var_rsqrt = paddle.rsqrt(bn_rv + bn_eps)
    conv_w = conv_w * \
        (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return conv_w, conv_b
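
# Folding identity used above: for y = BN(Conv(x)) with BN's running
# statistics,
#   scale = bn_w / sqrt(bn_rv + bn_eps)
#   W'    = W * scale   (broadcast over axis 0, the out-channel axis of
#                        Paddle's Conv2D weight [out, in, kH, kW])
#   b'    = (conv_b - bn_rm) * scale + bn_b
# so the fused conv reproduces conv followed by bn in eval mode.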


def _fuse_linear_bn(linear, bn):
    '''Fuse linear and bn for train or eval.'''
    assert linear.training == bn.training, \
        "Linear and BN both must be in the same mode (train or eval)."
    if linear.training:
        assert bn._num_features == linear.weight.shape[1], \
            'Output channel of Linear must match num_features of BatchNorm'
        raise NotImplementedError
    else:
        return _fuse_linear_bn_eval(linear, bn)


def _fuse_linear_bn_eval(linear, bn):
    '''Fuse linear and bn for eval.'''
    assert not (linear.training or bn.training), "Fusion only for eval!"
    fused_linear = copy.deepcopy(linear)

    fused_weight, fused_bias = _fuse_linear_bn_weights(
        fused_linear.weight, fused_linear.bias, bn._mean, bn._variance,
        bn._epsilon, bn.weight, bn.bias)
    fused_linear.weight.set_value(fused_weight)
    if fused_linear.bias is None:
        fused_linear.bias = paddle.create_parameter(
            shape=[fused_linear.weight.shape[1]],
            is_bias=True,
            dtype=bn.bias.dtype)
    fused_linear.bias.set_value(fused_bias)
    return fused_linear


def _fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w,
                            bn_b):
    '''Fuse the weights and bias of linear and bn.'''
    if linear_b is None:
        linear_b = paddle.zeros_like(bn_rm)
    bn_scale = bn_w * paddle.rsqrt(bn_rv + bn_eps)
    # Paddle's Linear weight has shape [in_features, out_features], so the
    # per-feature BN scale must broadcast along the last (output) axis.
    fused_w = linear_w * bn_scale
    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
    return fused_w, fused_b
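
# Folding identity used above: for y = BN(Linear(x)),
#   scale    = bn_w / sqrt(bn_rv + bn_eps)
#   W'[:, j] = W[:, j] * scale[j]
#   b'[j]    = (linear_b[j] - bn_rm[j]) * scale[j] + bn_b[j]
# so Linear(x; W', b') equals BN(Linear(x; W, b)) under BN's running
# statistics.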


# Map from the tuple of layer types to the corresponding fusion function.
types_to_fusion_method = {
    (nn.Conv2D, nn.BatchNorm2D): _fuse_conv_bn,
    (nn.Linear, nn.BatchNorm1D): _fuse_linear_bn,
}
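
A minimal usage sketch for the module above (the import path is an assumption
based on the relative imports in this commit, and the Sequential layer names
'0' and '1' are just this example's names; train-mode fusion raises
NotImplementedError, so the model is put in eval mode first):

import paddle
import paddle.nn as nn
# Assumed module path for fuse_utils.py added by this commit.
from paddle.fluid.contrib.slim.quantization.imperative import fuse_utils

model = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8))
model.eval()  # fusion is only implemented for eval mode
fused = fuse_utils.fuse_layers(model, [['0', '1']])

x = paddle.randn([2, 3, 16, 16])
# The fused model should reproduce the original eval-mode outputs up to
# floating-point rounding.
print(float(paddle.abs(model(x) - fused(x)).max()))
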
@@ -22,6 +22,7 @@ import paddle.nn.quant.quant_layers as quant_layers
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

from . import fuse_utils
from . import utils
from . import ptq_hooks
from . import ptq_config
@@ -55,7 +56,7 @@ class ImperativePTQ(object):
        self._quant_config = quant_config

    def quantize(self, model, inplace=False, fuse=False, fuse_list=None):
        """
        Add quant config and hook to the target layer.
@@ -63,15 +64,23 @@
            model(paddle.nn.Layer): The model to be quantized.
            inplace(bool): Whether to apply quantization to the input model.
                Default: False.
            fuse(bool): Whether to fuse layers.
                Default: False.
            fuse_list(list): Names of the layers to be fused, grouped per
                fusion. For example,
                "fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
                A TypeError would be raised if "fuse" was set as
                True but "fuse_list" was None.
                Default: None.

        Returns:
            quantized_model(paddle.nn.Layer): The quantized model.
        """
        assert isinstance(model, paddle.nn.Layer), \
            "The model must be the instance of paddle.nn.Layer."

        if not inplace:
            model = copy.deepcopy(model)
        if fuse:
            model.eval()
            model = fuse_utils.fuse_layers(model, fuse_list)
        for name, layer in model.named_sublayers():
            if PTQRegistry.is_supported_layer(layer) \
                and utils.is_leaf_layer(layer) \
...
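
A minimal sketch of the new fuse path through ImperativePTQ.quantize. The
Sequential layer names '0' and '1' are this example's own; PTQConfig,
AbsmaxQuantizer, and ImperativePTQ are the names the tests below import
from paddle.fluid.contrib.slim.quantization:

import paddle.nn as nn
from paddle.fluid.contrib.slim.quantization import ImperativePTQ, \
    PTQConfig, AbsmaxQuantizer

model = nn.Sequential(nn.Conv2D(1, 6, 3), nn.BatchNorm2D(6))
ptq = ImperativePTQ(PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer()))
# quantize() switches the model to eval mode, folds each pair in
# fuse_list, and only then attaches the calibration hooks.
quant_model = ptq.quantize(model, fuse=True, fuse_list=[['0', '1']])
for name, layer in quant_model.named_sublayers():
    print(name, type(layer).__name__)  # '1' should be Identity, not BN
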
@@ -20,6 +20,7 @@ from paddle.fluid import core
from paddle.fluid.dygraph.container import Sequential
from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.nn import BatchNorm1D

from paddle.fluid.log_helper import get_logger
@@ -43,6 +44,15 @@ def fix_model_dict(model):
    return model


def pre_hook(layer, input):
    input_return = (input[0] * 2)
    return input_return


def post_hook(layer, input, output):
    return output * 2


def train_lenet(lenet, reader, optimizer):
    loss_list = []
    lenet.train()
@@ -224,3 +234,53 @@ class ImperativeLenetWithSkipQuant(fluid.dygraph.Layer):
        x = self.softmax_0(x)

        return x


class ImperativeLinearBn(fluid.dygraph.Layer):
    def __init__(self):
        super(ImperativeLinearBn, self).__init__()
        fc_w_attr = paddle.ParamAttr(
            name="fc_weight",
            initializer=paddle.nn.initializer.Constant(value=0.5))
        fc_b_attr = paddle.ParamAttr(
            name="fc_bias",
            initializer=paddle.nn.initializer.Constant(value=1.0))
        bn_w_attr = paddle.ParamAttr(
            name="bn_weight",
            initializer=paddle.nn.initializer.Constant(value=0.5))

        self.linear = Linear(
            in_features=10,
            out_features=10,
            weight_attr=fc_w_attr,
            bias_attr=fc_b_attr)
        self.bn = BatchNorm1D(10, weight_attr=bn_w_attr)

    def forward(self, inputs):
        x = self.linear(inputs)
        x = self.bn(x)

        return x


class ImperativeLinearBn_hook(fluid.dygraph.Layer):
    def __init__(self):
        super(ImperativeLinearBn_hook, self).__init__()
        fc_w_attr = paddle.ParamAttr(
            name="linear_weight",
            initializer=paddle.nn.initializer.Constant(value=0.5))

        self.linear = Linear(
            in_features=10, out_features=10, weight_attr=fc_w_attr)
        self.bn = BatchNorm1D(10)
        # Register hooks that fusion is expected to migrate onto the
        # fused layer.
        forward_pre = self.linear.register_forward_pre_hook(pre_hook)
        forward_post = self.bn.register_forward_post_hook(post_hook)

    def forward(self, inputs):
        x = self.linear(inputs)
        x = self.bn(x)

        return x
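
A hedged sketch of what these hook fixtures exercise: fuse_layers is expected
to migrate the pre-hook of the first layer and the post-hook of the last
layer onto the fused layer (the fuse_utils import path is an assumption):

import paddle
from paddle.fluid.contrib.slim.quantization.imperative import fuse_utils

model = ImperativeLinearBn_hook()
model.eval()
fused = fuse_utils.fuse_layers(model, [['linear', 'bn']])

x = paddle.randn([3, 10])
# pre_hook doubles the input and post_hook doubles the output; if both
# hooks survived the fusion, the two models should agree numerically.
print(float(paddle.abs(model(x) - fused(x)).max()))
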
@@ -23,18 +23,48 @@ import unittest
import copy
import logging

import paddle.nn as nn
import paddle
import paddle.fluid as fluid

from paddle.fluid.contrib.slim.quantization import *
from paddle.fluid.log_helper import get_logger
from paddle.dataset.common import download

from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn
from imperative_test_utils import ImperativeLinearBn_hook

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


class TestFuseLinearBn(unittest.TestCase):
    """
    Fuse the linear and bn layers, and then quantize the model.
    """

    def test_fuse(self):
        model = ImperativeLinearBn()
        model_h = ImperativeLinearBn_hook()
        inputs = paddle.randn((3, 10), dtype="float32")
        config = PTQConfig(AbsmaxQuantizer(), AbsmaxQuantizer())
        ptq = ImperativePTQ(config)
        f_l = [['linear', 'bn']]
        quant_model = ptq.quantize(model, fuse=True, fuse_list=f_l)
        quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l)
        # Check the names listed in f_l, not f_l itself (f_l is a list of
        # lists, so "name in f_l" would never match a layer name).
        for name, layer in quant_model.named_sublayers():
            if any(name in group for group in f_l):
                assert not (isinstance(layer, nn.BatchNorm1D) or
                            isinstance(layer, nn.BatchNorm2D))
        out = model(inputs)
        out_h = model_h(inputs)
        out_quant = quant_model(inputs)
        out_quant_h = quant_h(inputs)
        cos_sim_func = nn.CosineSimilarity(axis=0)
        print('fuse linear+bn',
              cos_sim_func(out.flatten(), out_quant.flatten()))
        print(cos_sim_func(out_h.flatten(), out_quant_h.flatten()))


class TestImperativePTQ(unittest.TestCase):
    """
    """
@@ -177,7 +207,6 @@ class TestImperativePTQ(unittest.TestCase):
        model = ImperativeLenet()
        model_state_dict = paddle.load(params_path)
        model.set_state_dict(model_state_dict)
        # Quantize, calibrate and save
        quant_model = self.ptq.quantize(model)
        before_acc_top1 = self.model_test(quant_model, self.batch_num,
@@ -216,6 +245,67 @@
        print("total time: %ss \n" % (end_time - start_time))


class TestImperativePTQfuse(TestImperativePTQ):
    def test_ptq(self):
        start_time = time.time()

        self.set_vars()

        # Load model
        params_path = self.download_model(self.lenet_url, self.lenet_md5,
                                          "lenet")
        params_path += "/lenet_pretrained/lenet.pdparams"

        model = ImperativeLenet()
        model_state_dict = paddle.load(params_path)
        model.set_state_dict(model_state_dict)
        # Quantize, calibrate and save
        f_l = [['features.0', 'features.1'], ['features.4', 'features.5']]
        quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l)
        # As above, check membership in each fused group rather than in
        # f_l itself.
        for name, layer in quant_model.named_sublayers():
            if any(name in group for group in f_l):
                assert not (isinstance(layer, nn.BatchNorm1D) or
                            isinstance(layer, nn.BatchNorm2D))
        before_acc_top1 = self.model_test(quant_model, self.batch_num,
                                          self.batch_size)

        input_spec = [
            paddle.static.InputSpec(
                shape=[None, 1, 28, 28], dtype='float32')
        ]
        self.ptq.save_quantized_model(
            model=quant_model, path=self.save_path, input_spec=input_spec)
        print('Quantized model saved in {%s}' % self.save_path)

        after_acc_top1 = self.model_test(quant_model, self.batch_num,
                                         self.batch_size)

        paddle.enable_static()
        infer_acc_top1 = self.program_test(self.save_path, self.batch_num,
                                           self.batch_size)
        paddle.disable_static()

        # Check
        print('Before converted acc_top1: %s' % before_acc_top1)
        print('After converted acc_top1: %s' % after_acc_top1)
        print('Infer acc_top1: %s' % infer_acc_top1)

        # Check whether the quant_model is correct after converting.
        # Its acc should not be lower than eval_acc_top1.
        self.assertTrue(
            after_acc_top1 >= self.eval_acc_top1,
            msg="The test acc {%f} is less than {%f}." %
            (after_acc_top1, self.eval_acc_top1))
        # Check the saved inference model. Its acc should not be lower
        # than that of the dygraph model.
        self.assertTrue(
            infer_acc_top1 >= after_acc_top1,
            msg='The acc is lower after converting model.')

        end_time = time.time()
        print("total time: %ss \n" % (end_time - start_time))


class TestImperativePTQHist(TestImperativePTQ):
    def set_vars(self):
        config = PTQConfig(HistQuantizer(), AbsmaxQuantizer())
...