Unverified commit d9fa688a, authored by G Guanghua Yu, committed by GitHub

update dygraph new format QAT API (#1464)

Parent 9814a312
......@@ -165,6 +165,8 @@ def compress(args):
'moving_rate': 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
'quantizable_layer_type': ['Conv2D', 'Linear'],
# Whether to export the quantized model in ONNX format.
'onnx_format': args.onnx_format,
}
if args.use_pact:
......@@ -360,8 +362,7 @@ def compress(args):
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
],
onnx_format=args.onnx_format)
])
def main():
......
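The two hunks above update the demo script: `onnx_format` is now read from the command-line arguments into the quant config, and is no longer passed to `save_quantized_model`. A minimal sketch of how such a flag could be wired up, assuming a plain `argparse` setup (the demo's actual argument parsing is not shown in this diff and may differ):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag definition; only args.onnx_format is referenced in the patch.
parser.add_argument(
    '--onnx_format',
    action='store_true',
    help='Export the quantized model in ONNX-compatible (new) format.')
args = parser.parse_args()

quant_config = {
    'quantizable_layer_type': ['Conv2D', 'Linear'],
    # Forwarded through the quant config instead of save_quantized_model().
    'onnx_format': args.onnx_format,
}
```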
......@@ -58,6 +58,8 @@ _quant_config_default = {
'quantizable_layer_type': ['Conv2D', 'Linear'],
# whether to fuse conv and bn before QAT
'fuse_conv_bn': False,
# Whether to export the quantized model in ONNX format. Default is False.
'onnx_format': False,
}
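With `onnx_format` added to `_quant_config_default`, users enable the new export format through the config passed to `QAT` rather than at save time. A minimal end-to-end sketch, assuming `QAT` is imported from the top-level `paddleslim` package and using a stock Paddle vision model as a stand-in:

```python
import paddle
from paddleslim import QAT  # dygraph QAT wrapper; path assumed, also exposed under paddleslim.dygraph.quant

quant_config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'quantizable_layer_type': ['Conv2D', 'Linear'],
    'onnx_format': True,  # request the ONNX-style quantized format (effective on Paddle >= 2.4)
}

quanter = QAT(config=quant_config)
model = paddle.vision.models.mobilenet_v1(pretrained=False)
quant_model = quanter.quantize(model)

# ... run quantization-aware training on quant_model ...

# onnx_format no longer appears here; it is picked up from the config above.
quanter.save_quantized_model(
    quant_model,
    './quant_mobilenet_v1',
    input_spec=[
        paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')
    ])
```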
......@@ -215,7 +217,9 @@ class QAT(object):
weight_preprocess_layer=self.weight_preprocess,
act_preprocess_layer=self.act_preprocess,
weight_quantize_layer=self.weight_quantize,
act_quantize_layer=self.act_quantize)
act_quantize_layer=self.act_quantize,
onnx_format=self.config['onnx_format'], # support Paddle >= 2.4
)
except:
self.imperative_qat = ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
......@@ -257,11 +261,7 @@ class QAT(object):
return quant_model
def save_quantized_model(self,
model,
path,
input_spec=None,
onnx_format=False):
def save_quantized_model(self, model, path, input_spec=None):
"""
Save the quantized inference model.
......@@ -287,20 +287,30 @@ class QAT(object):
model.eval()
self.imperative_qat.save_quantized_model(
layer=model,
path=path,
input_spec=input_spec,
onnx_format=onnx_format)
layer=model, path=path, input_spec=input_spec)
def _remove_preprocess(self, model):
state_dict = model.state_dict()
self.imperative_qat = ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
activation_quantize_type=self.config['activation_quantize_type'],
moving_rate=self.config['moving_rate'],
quantizable_layer_type=self.config['quantizable_layer_type'])
try:
self.imperative_qat = ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
activation_quantize_type=self.config[
'activation_quantize_type'],
moving_rate=self.config['moving_rate'],
quantizable_layer_type=self.config['quantizable_layer_type'],
onnx_format=self.config['onnx_format'], # support Paddle >= 2.4
)
except:
self.imperative_qat = ImperativeQuantAware(
weight_bits=self.config['weight_bits'],
activation_bits=self.config['activation_bits'],
weight_quantize_type=self.config['weight_quantize_type'],
activation_quantize_type=self.config[
'activation_quantize_type'],
moving_rate=self.config['moving_rate'],
quantizable_layer_type=self.config['quantizable_layer_type'])
with paddle.utils.unique_name.guard():
if hasattr(model, "_layers"):
......
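Both constructions of `ImperativeQuantAware` in this patch use a try/except fallback so the same code runs on Paddle releases that predate the `onnx_format` keyword. A condensed sketch of that compatibility idiom, assuming the Paddle 2.x import path; the patch uses a bare `except`, while catching `TypeError` is shown here as a slightly narrower variant:

```python
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

kwargs = dict(
    weight_bits=8,
    activation_bits=8,
    weight_quantize_type='channel_wise_abs_max',
    activation_quantize_type='moving_average_abs_max',
    moving_rate=0.9,
    quantizable_layer_type=['Conv2D', 'Linear'],
)

try:
    # Paddle >= 2.4 accepts onnx_format and emits the new quantized format.
    imperative_qat = ImperativeQuantAware(onnx_format=True, **kwargs)
except TypeError:
    # Older Paddle: fall back to the legacy constructor and legacy format.
    imperative_qat = ImperativeQuantAware(**kwargs)
```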