diff --git a/demo/dygraph/quant/train.py b/demo/dygraph/quant/train.py
index 202b3eb40f44e9063fce45868e2cf318bdc73a3e..31fc67be9c89ce42e56f87d52fe332f83e951551 100644
--- a/demo/dygraph/quant/train.py
+++ b/demo/dygraph/quant/train.py
@@ -165,6 +165,8 @@ def compress(args):
         'moving_rate': 0.9,
         # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
         'quantizable_layer_type': ['Conv2D', 'Linear'],
+        # Whether to export the quantized model with format of ONNX.
+        'onnx_format': args.onnx_format,
     }
 
     if args.use_pact:
@@ -360,8 +362,7 @@ def compress(args):
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, 3, 224, 224], dtype='float32')
-        ],
-        onnx_format=args.onnx_format)
+        ])
 
 
 def main():
diff --git a/paddleslim/dygraph/quant/qat.py b/paddleslim/dygraph/quant/qat.py
index 34b1ae7ba82bcbd904f55523ae2362ff02816398..05b1b40d3b1eb37912a417421ba89e4bf357a0d1 100644
--- a/paddleslim/dygraph/quant/qat.py
+++ b/paddleslim/dygraph/quant/qat.py
@@ -58,6 +58,8 @@ _quant_config_default = {
     'quantizable_layer_type': ['Conv2D', 'Linear'],
     # whether fuse conv and bn before QAT
     'fuse_conv_bn': False,
+    # Whether to export the quantized model with format of ONNX. Default is False.
+    'onnx_format': False,
 }
 
 
@@ -215,7 +217,9 @@ class QAT(object):
                 weight_preprocess_layer=self.weight_preprocess,
                 act_preprocess_layer=self.act_preprocess,
                 weight_quantize_layer=self.weight_quantize,
-                act_quantize_layer=self.act_quantize)
+                act_quantize_layer=self.act_quantize,
+                onnx_format=self.config['onnx_format'],  # support Paddle >= 2.4
+            )
         except:
             self.imperative_qat = ImperativeQuantAware(
                 weight_bits=self.config['weight_bits'],
@@ -257,11 +261,7 @@ class QAT(object):
 
         return quant_model
 
-    def save_quantized_model(self,
-                             model,
-                             path,
-                             input_spec=None,
-                             onnx_format=False):
+    def save_quantized_model(self, model, path, input_spec=None):
         """
         Save the quantized inference model.
 
@@ -287,20 +287,30 @@ class QAT(object):
         model.eval()
 
         self.imperative_qat.save_quantized_model(
-            layer=model,
-            path=path,
-            input_spec=input_spec,
-            onnx_format=onnx_format)
+            layer=model, path=path, input_spec=input_spec)
 
     def _remove_preprocess(self, model):
         state_dict = model.state_dict()
-        self.imperative_qat = ImperativeQuantAware(
-            weight_bits=self.config['weight_bits'],
-            activation_bits=self.config['activation_bits'],
-            weight_quantize_type=self.config['weight_quantize_type'],
-            activation_quantize_type=self.config['activation_quantize_type'],
-            moving_rate=self.config['moving_rate'],
-            quantizable_layer_type=self.config['quantizable_layer_type'])
+        try:
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'],
+                onnx_format=self.config['onnx_format'],  # support Paddle >= 2.4
+            )
+        except:
+            self.imperative_qat = ImperativeQuantAware(
+                weight_bits=self.config['weight_bits'],
+                activation_bits=self.config['activation_bits'],
+                weight_quantize_type=self.config['weight_quantize_type'],
+                activation_quantize_type=self.config[
+                    'activation_quantize_type'],
+                moving_rate=self.config['moving_rate'],
+                quantizable_layer_type=self.config['quantizable_layer_type'])
 
         with paddle.utils.unique_name.guard():
             if hasattr(model, "_layers"):
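
A minimal usage sketch (not part of the patch): after this change, ONNX-format export is requested through the quant config instead of being passed to `save_quantized_model()`. Only the `'onnx_format'` key and the new `save_quantized_model()` signature come from the diff above; the import path, example model, and other config values are assumptions for illustration.

```python
import paddle
# Import path is an assumption; QAT may also be exposed at the package root.
from paddleslim.dygraph.quant import QAT

# Unspecified keys fall back to _quant_config_default shown in the patch.
quant_config = {
    'quantizable_layer_type': ['Conv2D', 'Linear'],
    # New key: export the quantized model with ONNX format (per the patch,
    # this takes effect on Paddle >= 2.4; default is False).
    'onnx_format': True,
}

# Example model for illustration only.
model = paddle.vision.models.mobilenet_v1(pretrained=False)

quanter = QAT(config=quant_config)
model = quanter.quantize(model)

# ... quant-aware training loop would go here ...

# onnx_format is no longer an argument here; it is read from the config.
quanter.save_quantized_model(
    model,
    './quant_model/model',
    input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 224, 224], dtype='float32')
    ])
```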