未验证 提交 3a9c4213 编写于 作者: L Liufang Sang 提交者: GitHub

fix save checkpoint in quantization (#257)

* fix save checkpoint in quantization
* fix details
上级 d99e2045
...@@ -22,6 +22,7 @@ import time ...@@ -22,6 +22,7 @@ import time
import numpy as np import numpy as np
import datetime import datetime
from collections import deque from collections import deque
import shutil
from paddle import fluid from paddle import fluid
...@@ -42,6 +43,21 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) ...@@ -42,6 +43,21 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def save_checkpoint(exe, prog, path, train_prog):
    """Save the persistables of ``prog`` and the LR decay counter into ``path``.

    Args:
        exe: executor passed through to the ``fluid.io`` save calls.
        prog: program whose persistable variables are written.
        path: target directory; if a directory already exists there it is
            removed before anything is written.
        train_prog: program whose global block is queried for the
            ``@LR_DECAY_COUNTER@`` variable.
    """
    # Remove an existing directory at ``path`` before saving into it.
    if os.path.isdir(path):
        shutil.rmtree(path)

    logger.info('Save model to {}.'.format(path))
    fluid.io.save_persistables(exe, dirname=path, main_program=prog)

    # The counter is fetched from train_prog and saved in a separate call.
    # NOTE(review): presumably prog does not carry this variable itself,
    # which is why it is not covered by save_persistables above -- confirm.
    lr_decay_counter = train_prog.global_block().var('@LR_DECAY_COUNTER@')
    fluid.io.save_vars(exe, dirname=path, vars=[lr_decay_counter])
def load_global_step(exe, prog, path):
    """Load the ``@LR_DECAY_COUNTER@`` variable of ``prog`` from ``path``.

    Args:
        exe: executor passed through to ``fluid.io.load_vars``.
        prog: program whose global block holds the counter variable; also
            passed as the main program of the load call.
        path: directory the variable is read from.
    """
    lr_decay_counter = prog.global_block().var('@LR_DECAY_COUNTER@')
    fluid.io.load_vars(
        exe, dirname=path, main_program=prog, vars=[lr_decay_counter])
def main(): def main():
env = os.environ env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
...@@ -176,9 +192,9 @@ def main(): ...@@ -176,9 +192,9 @@ def main():
cfg.pretrain_weights, cfg.pretrain_weights,
ignore_params=ignore_params) ignore_params=ignore_params)
# insert quantize op in train_prog, return type is CompiledProgram # insert quantize op in train_prog, return type is CompiledProgram
train_prog = quant_aware(train_prog, place, config, for_test=False) train_prog_quant = quant_aware(train_prog, place, config, for_test=False)
compiled_train_prog = train_prog.with_data_parallel( compiled_train_prog = train_prog_quant.with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
...@@ -192,6 +208,7 @@ def main(): ...@@ -192,6 +208,7 @@ def main():
start_iter = 0 start_iter = 0
if FLAGS.resume_checkpoint: if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, eval_prog, FLAGS.resume_checkpoint) checkpoint.load_checkpoint(exe, eval_prog, FLAGS.resume_checkpoint)
load_global_step(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step() start_iter = checkpoint.global_step()
train_reader = create_reader(cfg.TrainReader, train_reader = create_reader(cfg.TrainReader,
...@@ -237,7 +254,8 @@ def main(): ...@@ -237,7 +254,8 @@ def main():
if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \ if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
and (not FLAGS.dist or trainer_id == 0): and (not FLAGS.dist or trainer_id == 0):
save_name = str(it) if it != cfg.max_iters - 1 else "model_final" save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
checkpoint.save(exe, eval_prog, os.path.join(save_dir, save_name)) save_checkpoint(exe, eval_prog,
os.path.join(save_dir, save_name), train_prog)
if FLAGS.eval: if FLAGS.eval:
# evaluation # evaluation
...@@ -254,8 +272,9 @@ def main(): ...@@ -254,8 +272,9 @@ def main():
if box_ap_stats[0] > best_box_ap_list[0]: if box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0] best_box_ap_list[0] = box_ap_stats[0]
best_box_ap_list[1] = it best_box_ap_list[1] = it
checkpoint.save(exe, eval_prog, save_checkpoint(exe, eval_prog,
os.path.join(save_dir, "best_model")) os.path.join(save_dir, "best_model"),
train_prog)
logger.info("Best test box ap: {}, in iter: {}".format( logger.info("Best test box ap: {}, in iter: {}".format(
best_box_ap_list[0], best_box_ap_list[1])) best_box_ap_list[0], best_box_ap_list[1]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册