Unverified commit d9fa688a, authored by Guanghua Yu, committed by GitHub

update dygraph new format QAT API (#1464)

Parent: 9814a312
@@ -165,6 +165,8 @@ def compress(args):
         'moving_rate': 0.9,
         # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
         'quantizable_layer_type': ['Conv2D', 'Linear'],
+        # Whether to export the quantized model with format of ONNX.
+        'onnx_format': args.onnx_format,
     }
     if args.use_pact:
@@ -360,8 +362,7 @@ def compress(args):
             input_spec=[
                 paddle.static.InputSpec(
                     shape=[None, 3, 224, 224], dtype='float32')
-            ],
-            onnx_format=args.onnx_format)
+            ])


 def main():
...
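With this change, `onnx_format` is read from the quantization config instead of being passed to `save_quantized_model`. Below is a minimal caller-side sketch of the updated flow, assuming PaddleSlim's dygraph `QAT` wrapper as shown in this diff; the MobileNetV1 model, output path, and config values are illustrative placeholders.

```python
import paddle
from paddleslim.dygraph.quant import QAT

# 'onnx_format' now lives in the quant config (new in this patch).
quant_config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'quantizable_layer_type': ['Conv2D', 'Linear'],
    'onnx_format': True,  # export the new, ONNX-compatible quant format
}

net = paddle.vision.models.mobilenet_v1(pretrained=False)  # placeholder model
quanter = QAT(config=quant_config)
quant_net = quanter.quantize(net)

# ... QAT training loop would go here ...

# save_quantized_model no longer takes an onnx_format argument.
quanter.save_quantized_model(
    quant_net,
    './quant_model/model',  # hypothetical output path
    input_spec=[
        paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')
    ])
```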
@@ -58,6 +58,8 @@ _quant_config_default = {
     'quantizable_layer_type': ['Conv2D', 'Linear'],
     # whether fuse conv and bn before QAT
     'fuse_conv_bn': False,
+    # Whether to export the quantized model with format of ONNX. Default is False.
+    'onnx_format': False,
 }
@@ -215,7 +217,9 @@ class QAT(object):
                 weight_preprocess_layer=self.weight_preprocess,
                 act_preprocess_layer=self.act_preprocess,
                 weight_quantize_layer=self.weight_quantize,
-                act_quantize_layer=self.act_quantize)
+                act_quantize_layer=self.act_quantize,
+                onnx_format=self.config['onnx_format'],  # support Paddle >= 2.4
+            )
         except:
             self.imperative_qat = ImperativeQuantAware(
                 weight_bits=self.config['weight_bits'],
@@ -257,11 +261,7 @@ class QAT(object):
         return quant_model

-    def save_quantized_model(self,
-                             model,
-                             path,
-                             input_spec=None,
-                             onnx_format=False):
+    def save_quantized_model(self, model, path, input_spec=None):
         """
         Save the quantized inference model.
@@ -287,18 +287,28 @@ class QAT(object):
         model.eval()
         self.imperative_qat.save_quantized_model(
-            layer=model,
-            path=path,
-            input_spec=input_spec,
-            onnx_format=onnx_format)
+            layer=model, path=path, input_spec=input_spec)

     def _remove_preprocess(self, model):
         state_dict = model.state_dict()
-        self.imperative_qat = ImperativeQuantAware(
-            weight_bits=self.config['weight_bits'],
-            activation_bits=self.config['activation_bits'],
-            weight_quantize_type=self.config['weight_quantize_type'],
-            activation_quantize_type=self.config['activation_quantize_type'],
-            moving_rate=self.config['moving_rate'],
-            quantizable_layer_type=self.config['quantizable_layer_type'])
+        try:
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'],
+                onnx_format=self.config['onnx_format'],  # support Paddle >= 2.4
+            )
+        except:
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'])
...
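The duplicated constructor calls above guard against older Paddle releases: `ImperativeQuantAware` only accepts `onnx_format` from Paddle 2.4 on, so the call is retried without it when it fails. A generic sketch of that fallback pattern, using a hypothetical `build_quant_aware` helper and catching `TypeError` instead of the bare `except` used in the patch:

```python
def build_quant_aware(cls, kwargs, onnx_format):
    """Construct `cls`, passing onnx_format only when the installed version
    supports it (illustrative helper, not PaddleSlim API)."""
    try:
        # Paddle >= 2.4: the constructor understands onnx_format.
        return cls(onnx_format=onnx_format, **kwargs)
    except TypeError:
        # Older Paddle: retry without the unsupported keyword argument.
        return cls(**kwargs)
```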