diff --git a/slim/quantization/train.py b/slim/quantization/train.py index 2fa8884f5028ed12effaee3106ba6d5a26ee02fb..f8a7d6faefc1c01ae1b7538facaee4cb7b2fad2b 100644 --- a/slim/quantization/train.py +++ b/slim/quantization/train.py @@ -22,6 +22,7 @@ import time import numpy as np import datetime from collections import deque +import shutil from paddle import fluid @@ -42,6 +43,21 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) +def save_checkpoint(exe, prog, path, train_prog): + if os.path.isdir(path): + shutil.rmtree(path) + logger.info('Save model to {}.'.format(path)) + fluid.io.save_persistables(exe, path, main_program=prog) + + v = train_prog.global_block().var('@LR_DECAY_COUNTER@') + fluid.io.save_vars(exe, dirname=path, vars=[v]) + + +def load_global_step(exe, prog, path): + v = prog.global_block().var('@LR_DECAY_COUNTER@') + fluid.io.load_vars(exe, path, prog, [v]) + + def main(): env = os.environ FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env @@ -176,9 +192,9 @@ def main(): cfg.pretrain_weights, ignore_params=ignore_params) # insert quantize op in train_prog, return type is CompiledProgram - train_prog = quant_aware(train_prog, place, config, for_test=False) + train_prog_quant = quant_aware(train_prog, place, config, for_test=False) - compiled_train_prog = train_prog.with_data_parallel( + compiled_train_prog = train_prog_quant.with_data_parallel( loss_name=loss.name, build_strategy=build_strategy, exec_strategy=exec_strategy) @@ -192,6 +208,7 @@ def main(): start_iter = 0 if FLAGS.resume_checkpoint: checkpoint.load_checkpoint(exe, eval_prog, FLAGS.resume_checkpoint) + load_global_step(exe, train_prog, FLAGS.resume_checkpoint) start_iter = checkpoint.global_step() train_reader = create_reader(cfg.TrainReader, @@ -237,7 +254,8 @@ def main(): if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \ and (not FLAGS.dist or trainer_id == 0): save_name = str(it) if it != cfg.max_iters - 1 else "model_final" - checkpoint.save(exe, eval_prog, os.path.join(save_dir, save_name)) + save_checkpoint(exe, eval_prog, + os.path.join(save_dir, save_name), train_prog) if FLAGS.eval: # evaluation @@ -254,8 +272,9 @@ def main(): if box_ap_stats[0] > best_box_ap_list[0]: best_box_ap_list[0] = box_ap_stats[0] best_box_ap_list[1] = it - checkpoint.save(exe, eval_prog, - os.path.join(save_dir, "best_model")) + save_checkpoint(exe, eval_prog, + os.path.join(save_dir, "best_model"), + train_prog) logger.info("Best test box ap: {}, in iter: {}".format( best_box_ap_list[0], best_box_ap_list[1]))