Unverified commit 40e1cb01, authored by Guanghua Yu and committed by GitHub

support fuse Conv and BN before QAT (#1146)

Parent ce482365
@@ -80,6 +80,9 @@ QAT
     # Operator types to be quantized
     'quantizable_layer_type': ['Conv2D', 'Linear'],
+    # Whether to fuse Conv + BN before QAT
+    'fuse_conv_bn': False,
 }
 ..
@@ -56,6 +56,8 @@ _quant_config_default = {
     'moving_rate': 0.9,
     # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
     'quantizable_layer_type': ['Conv2D', 'Linear'],
+    # whether to fuse conv and bn before QAT
+    'fuse_conv_bn': False,
 }
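To make the new key concrete, here is a minimal usage sketch. It assumes PaddleSlim's dygraph QAT wrapper (paddleslim.QAT) and a Paddle release that accepts fuse_conv_bn (roughly >= 2.3); the MobileNetV1 model and the explicit quantize types are illustrative placeholders, and only 'fuse_conv_bn' is the option added by this commit.

import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim import QAT

quant_config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'quantizable_layer_type': ['Conv2D', 'Linear'],
    # New key: fold each BatchNorm into its preceding Conv2D before the
    # fake-quant ops are inserted, so training sees the fused weights that
    # will actually be deployed.
    'fuse_conv_bn': True,
}

model = mobilenet_v1(pretrained=False)
quanter = QAT(config=quant_config)
quant_model = quanter.quantize(model)  # keep the return value (see below)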
@@ -141,7 +143,8 @@ class PACT(paddle.nn.Layer):
 class QAT(object):
     """
-    Quant Aware Training(QAT): Add the fake quant logic for given quantizable layers, namely add the quant_dequant computational logic both for activation inputs and weight inputs.
+    Quant Aware Training(QAT): Add the fake quant logic for given quantizable layers,
+    namely add the quant_dequant computational logic both for activation inputs and weight inputs.
     """

     def __init__(self,
@@ -190,22 +193,42 @@ class QAT(object):
         self.act_preprocess = PACT if self.config[
             'activation_preprocess_type'] == 'PACT' else None
-        self.weight_preprocess = weight_preprocess if weight_preprocess is not None else self.weight_preprocess
-        self.act_preprocess = act_preprocess if act_preprocess is not None else self.act_preprocess
+        self.weight_preprocess = weight_preprocess if weight_preprocess is not None \
+            else self.weight_preprocess
+        self.act_preprocess = act_preprocess if act_preprocess is not None \
+            else self.act_preprocess
         self.weight_quantize = weight_quantize
         self.act_quantize = act_quantize

-        self.imperative_qat = ImperativeQuantAware(
-            weight_bits=self.config['weight_bits'],
-            activation_bits=self.config['activation_bits'],
-            weight_quantize_type=self.config['weight_quantize_type'],
-            activation_quantize_type=self.config['activation_quantize_type'],
-            moving_rate=self.config['moving_rate'],
-            quantizable_layer_type=self.config['quantizable_layer_type'],
-            weight_preprocess_layer=self.weight_preprocess,
-            act_preprocess_layer=self.act_preprocess,
-            weight_quantize_layer=self.weight_quantize,
-            act_quantize_layer=self.act_quantize)
+        # TODO: remove try-except when the version is stable
+        try:
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'],
+                fuse_conv_bn=self.config[
+                    'fuse_conv_bn'],  # support Paddle > 2.3
+                weight_preprocess_layer=self.weight_preprocess,
+                act_preprocess_layer=self.act_preprocess,
+                weight_quantize_layer=self.weight_quantize,
+                act_quantize_layer=self.act_quantize)
+        except:
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'],
+                weight_preprocess_layer=self.weight_preprocess,
+                act_preprocess_layer=self.act_preprocess,
+                weight_quantize_layer=self.weight_quantize,
+                act_quantize_layer=self.act_quantize)

     def quantize(self, model, inplace=True):
         """
@@ -224,11 +247,13 @@ class QAT(object):
         self._model = copy.deepcopy(model)

         if inplace:
-            self.imperative_qat.quantize(model)
-            quant_model = model
+            quantize_model = self.imperative_qat.quantize(model)
+            quant_model = quantize_model if quantize_model is not None else model
         else:
             quant_model = copy.deepcopy(model)
-            self.imperative_qat.quantize(quant_model)
+            quantize_model = self.imperative_qat.quantize(quant_model)
+            if quantize_model is not None:
+                quant_model = quantize_model
         return quant_model
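Because Conv+BN fusion may replace layers, quantize() now prefers whatever ImperativeQuantAware.quantize returns over the object that was passed in, so callers should always keep the returned model. A hedged calling sketch (the ResNet model is a placeholder):

import paddle
from paddle.vision.models import resnet18
from paddleslim import QAT

quanter = QAT(config={'fuse_conv_bn': True})
float_model = resnet18(pretrained=False)

# inplace=False leaves float_model untouched and returns an independent copy;
# use the return value, since fusion may hand back a new object.
quant_model = quanter.quantize(float_model, inplace=False)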