From a45431c99a775782b7fe5633f313d36ff582e797 Mon Sep 17 00:00:00 2001
From: Bai Yifan
Date: Mon, 10 Aug 2020 19:29:44 +0800
Subject: [PATCH] support pact demo load checkpoints (#414)

---
 demo/quant/pact_quant_aware/README.md |  4 ++--
 demo/quant/pact_quant_aware/train.py  | 32 +++++++++++++++++++++------
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/demo/quant/pact_quant_aware/README.md b/demo/quant/pact_quant_aware/README.md
index c494d42c..cf64d7e2 100644
--- a/demo/quant/pact_quant_aware/README.md
+++ b/demo/quant/pact_quant_aware/README.md
@@ -159,7 +159,7 @@ compiled_train_prog = compiled_train_prog.with_data_parallel(
 普通量化:
 
 ```
-python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact False
+python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --num_epochs 30 --lr 0.0001 --use_pact False
 
 ```
 
@@ -179,7 +179,7 @@ python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/Mob
 使用PACT量化训练
 
 ```
-python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --checkpoint_dir ./output/MobileNetV3_large_x1_0 --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
+python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
 ```
 
 输出结果为
diff --git a/demo/quant/pact_quant_aware/train.py b/demo/quant/pact_quant_aware/train.py
index 9911f944..812c4a04 100644
--- a/demo/quant/pact_quant_aware/train.py
+++ b/demo/quant/pact_quant_aware/train.py
@@ -53,8 +53,12 @@ add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'")
 add_arg('log_period', int, 10, "Log period in batches.")
-add_arg('checkpoint_dir', str, "output",
-        "checkpoint save dir")
+add_arg('checkpoint_dir', str, None,
+        "checkpoint dir")
+add_arg('checkpoint_epoch', int, None,
+        "checkpoint epoch")
+add_arg('output_dir', str, "output/MobileNetV3_large_x1_0",
+        "model save dir")
 add_arg('use_pact', bool, True, "Whether to use PACT or not.")
@@ -244,6 +248,7 @@ def compress(args):
             compiled_train_prog,
             feed=train_feeder.feed(data),
             fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
+        end_time = time.time()
         loss_n = np.mean(loss_n)
         acc_top1_n = np.mean(acc_top1_n)
@@ -279,24 +284,37 @@ def compress(args):
     # train loop
     best_acc1 = 0.0
     best_epoch = 0
-    for i in range(args.num_epochs):
+
+    start_epoch = 0
+    if args.checkpoint_dir is not None:
+        ckpt_path = args.checkpoint_dir
+        assert args.checkpoint_epoch is not None, "checkpoint_epoch must be set"
+        start_epoch = args.checkpoint_epoch
+        fluid.io.load_persistables(
+            exe, dirname=args.checkpoint_dir, main_program=val_program)
+        start_step = start_epoch * int(
+            math.ceil(float(args.total_images) / args.batch_size))
+        v = fluid.global_scope().find_var('@LR_DECAY_COUNTER@').get_tensor()
+        v.set(np.array([start_step]).astype(np.float32), place)
+
+    for i in range(start_epoch, args.num_epochs):
         train(i, compiled_train_prog)
         acc1 = test(i, val_program)
         fluid.io.save_persistables(
             exe,
-            dirname=os.path.join(args.checkpoint_dir, str(i)),
+            dirname=os.path.join(args.output_dir, str(i)),
             main_program=val_program)
         if acc1 > best_acc1:
             best_acc1 = acc1
             best_epoch = i
             fluid.io.save_persistables(
                 exe,
-                dirname=os.path.join(args.checkpoint_dir, 'best_model'),
+                dirname=os.path.join(args.output_dir, 'best_model'),
                 main_program=val_program)
-    if os.path.exists(os.path.join(args.checkpoint_dir, 'best_model')):
+    if os.path.exists(os.path.join(args.output_dir, 'best_model')):
         fluid.io.load_persistables(
             exe,
-            dirname=os.path.join(args.checkpoint_dir, 'best_model'),
+            dirname=os.path.join(args.output_dir, 'best_model'),
             main_program=val_program)
     # 3. Freeze the graph after training by adjusting the quantize
     #    operators' order for the inference.
--
GitLab
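The patch repurposes `--checkpoint_dir` as a resume option (paired with the new `--checkpoint_epoch`) and adds `--output_dir` as the directory where per-epoch checkpoints are written. A minimal sketch of resuming PACT training with these flags is shown below; the checkpoint path and epoch number are placeholders that assume a previous run already saved `./output/MobileNetV3_large_x1_0/5` at the end of epoch 5, and are not values taken from the patch itself.

```
# Hypothetical resume of PACT quantization-aware training from epoch 5;
# the checkpoint directory and epoch value are illustrative placeholders.
python train.py \
    --model MobileNetV3_large_x1_0 \
    --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained \
    --checkpoint_dir ./output/MobileNetV3_large_x1_0/5 \
    --checkpoint_epoch 5 \
    --output_dir ./output/MobileNetV3_large_x1_0 \
    --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 \
    --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
```

Passing the per-epoch subdirectory matches the loading code added above, which calls `fluid.io.load_persistables` on `args.checkpoint_dir` directly and then rewinds `@LR_DECAY_COUNTER@` to `start_epoch * ceil(total_images / batch_size)` so the piecewise learning-rate schedule continues from the right step.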