Unverified commit 40e1cb01, authored by Guanghua Yu, committed by GitHub

support fuse Conv and BN before QAT (#1146)

Parent ce482365
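
Background: "fusing Conv and BN" folds a BatchNorm layer into the preceding convolution so that a single Conv2D reproduces the Conv→BN computation exactly. With BN scale γ, shift β, running mean μ, running variance σ², and stabilizer ε, the folded parameters are w' = w · γ/√(σ² + ε), applied per output channel, and b' = β + (b − μ) · γ/√(σ² + ε). Fusing before QAT means the fake-quant observers calibrate on the fused weights that actually run at inference time. A minimal NumPy sketch of the folding (illustrative only; the commit itself delegates the fusion to Paddle's ImperativeQuantAware):

    import numpy as np

    def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
        # w: (out_c, in_c, kh, kw) conv weights; b: (out_c,) conv bias
        # gamma/beta/mean/var: (out_c,) BatchNorm parameters
        scale = gamma / np.sqrt(var + eps)        # per-output-channel scale
        w_fused = w * scale.reshape(-1, 1, 1, 1)  # scale each output filter
        b_fused = beta + (b - mean) * scale       # fold the BN shift into the bias
        return w_fused, b_fused
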
@@ -80,6 +80,9 @@ QAT
     # Operator types to quantize
     'quantizable_layer_type': ['Conv2D', 'Linear'],
+    # Whether to fuse Conv + BN
+    'fuse_conv_bn': False,
 }
 ..
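
For illustration, enabling the new option from user code would look like the sketch below. QAT and the config dict follow this repository's dygraph quantization API; the MobileNet model is only a stand-in for any network with Conv2D + BatchNorm pairs, and keys omitted from the dict should fall back to the _quant_config_default values in the next hunk:

    import paddle
    from paddleslim import QAT

    model = paddle.vision.models.mobilenet_v1()  # placeholder network with Conv+BN blocks

    quant_config = {
        'quantizable_layer_type': ['Conv2D', 'Linear'],
        'fuse_conv_bn': True,  # the option added by this commit; needs Paddle > 2.3
    }

    quanter = QAT(config=quant_config)
    quant_model = quanter.quantize(model)  # Conv+BN pairs are fused before fake-quant ops are inserted
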
@@ -56,6 +56,8 @@ _quant_config_default = {
     'moving_rate': 0.9,
     # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
     'quantizable_layer_type': ['Conv2D', 'Linear'],
+    # whether to fuse conv and bn before QAT
+    'fuse_conv_bn': False,
 }
@@ -141,7 +143,8 @@ class PACT(paddle.nn.Layer):
 class QAT(object):
     """
-    Quant Aware Training(QAT): Add the fake quant logic for given quantizable layers, namely add the quant_dequant computational logic both for activation inputs and weight inputs.
+    Quant Aware Training(QAT): Add the fake quant logic for given quantizable layers,
+    namely add the quant_dequant computational logic both for activation inputs and weight inputs.
     """

     def __init__(self,
@@ -190,22 +193,42 @@ class QAT(object):
         self.act_preprocess = PACT if self.config[
             'activation_preprocess_type'] == 'PACT' else None
-        self.weight_preprocess = weight_preprocess if weight_preprocess is not None else self.weight_preprocess
-        self.act_preprocess = act_preprocess if act_preprocess is not None else self.act_preprocess
+        self.weight_preprocess = weight_preprocess if weight_preprocess is not None \
+            else self.weight_preprocess
+        self.act_preprocess = act_preprocess if act_preprocess is not None \
+            else self.act_preprocess
         self.weight_quantize = weight_quantize
         self.act_quantize = act_quantize

-        self.imperative_qat = ImperativeQuantAware(
-            weight_bits=self.config['weight_bits'],
-            activation_bits=self.config['activation_bits'],
-            weight_quantize_type=self.config['weight_quantize_type'],
-            activation_quantize_type=self.config['activation_quantize_type'],
-            moving_rate=self.config['moving_rate'],
-            quantizable_layer_type=self.config['quantizable_layer_type'],
-            weight_preprocess_layer=self.weight_preprocess,
-            act_preprocess_layer=self.act_preprocess,
-            weight_quantize_layer=self.weight_quantize,
-            act_quantize_layer=self.act_quantize)
+        # TODO: remove the try-except once the required Paddle version is stable
+        try:
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'],
+                fuse_conv_bn=self.config[
+                    'fuse_conv_bn'],  # supported by Paddle > 2.3
+                weight_preprocess_layer=self.weight_preprocess,
+                act_preprocess_layer=self.act_preprocess,
+                weight_quantize_layer=self.weight_quantize,
+                act_quantize_layer=self.act_quantize)
+        except TypeError:
+            # Older Paddle: ImperativeQuantAware has no fuse_conv_bn argument.
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'],
+                weight_preprocess_layer=self.weight_preprocess,
+                act_preprocess_layer=self.act_preprocess,
+                weight_quantize_layer=self.weight_quantize,
+                act_quantize_layer=self.act_quantize)
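
The try-except above is doing capability detection: on Paddle versions that predate fuse_conv_bn, the extra keyword makes ImperativeQuantAware's constructor raise TypeError, and the fallback builds the quantizer without it. A more explicit check, shown here only as a sketch and not what the commit does (the import path is the one PaddleSlim used at the time; verify it against your Paddle version), could query the constructor's signature:

    import inspect
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    kwargs = {'weight_bits': 8, 'activation_bits': 8}
    # Pass fuse_conv_bn only when the installed Paddle actually accepts it.
    if 'fuse_conv_bn' in inspect.signature(ImperativeQuantAware.__init__).parameters:
        kwargs['fuse_conv_bn'] = True
    imperative_qat = ImperativeQuantAware(**kwargs)
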
     def quantize(self, model, inplace=True):
         """
@@ -224,11 +247,13 @@ class QAT(object):
             self._model = copy.deepcopy(model)

         if inplace:
-            self.imperative_qat.quantize(model)
-            quant_model = model
+            quantize_model = self.imperative_qat.quantize(model)
+            quant_model = quantize_model if quantize_model is not None else model
         else:
             quant_model = copy.deepcopy(model)
-            self.imperative_qat.quantize(quant_model)
+            quantize_model = self.imperative_qat.quantize(quant_model)
+            if quantize_model is not None:
+                quant_model = quantize_model

         return quant_model
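
The None checks in quantize() keep both API generations working: where ImperativeQuantAware.quantize() mutates the model in place and returns None, the original object is kept, and where it returns a (possibly rebuilt, fused) model, that return value wins. Callers are unaffected either way; a usage sketch, continuing the example above:

    quanter = QAT(config={'fuse_conv_bn': True})  # unspecified keys take the defaults
    quant_model = quanter.quantize(model, inplace=False)  # the original `model` is left untouched
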