From a494889c91f2025bfaee2a6951a724da4be77590 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Mon, 14 Feb 2022 16:35:51 +0800 Subject: [PATCH] fix pact train.py (#5472) --- tutorials/mobilenetv3_prod/Step6/train.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tutorials/mobilenetv3_prod/Step6/train.py b/tutorials/mobilenetv3_prod/Step6/train.py index 1bc10553..951a49c9 100644 --- a/tutorials/mobilenetv3_prod/Step6/train.py +++ b/tutorials/mobilenetv3_prod/Step6/train.py @@ -302,16 +302,16 @@ def main(args): paddle.save(optimizer.state_dict(), os.path.join(args.output_dir, 'best.pdopt')) - if args.pact_quant: - input_spec = [ - InputSpec( - shape=[None, 3, 224, 224], dtype='float32') - ] - quanter.save_quantized_model( - model, - os.path.join(args.output_dir, "qat_inference"), - input_spec=input_spec) - print("QAT inference model saved in {args.output_dir}") + if args.pact_quant: + input_spec = [ + InputSpec( + shape=[None, 3, 224, 224], dtype='float32') + ] + quanter.save_quantized_model( + model, + os.path.join(args.output_dir, "qat_inference"), + input_spec=input_spec) + print("QAT inference model saved in {args.output_dir}") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) -- GitLab