未验证 提交 3a9c4213 编写于 作者: L Liufang Sang 提交者: GitHub

fix save checkpoint in quantization (#257)

* fix save checkpoint in quantization
* fix details
上级 d99e2045
......@@ -22,6 +22,7 @@ import time
import numpy as np
import datetime
from collections import deque
import shutil
from paddle import fluid
......@@ -42,6 +43,21 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def save_checkpoint(exe, prog, path, train_prog):
if os.path.isdir(path):
shutil.rmtree(path)
logger.info('Save model to {}.'.format(path))
fluid.io.save_persistables(exe, path, main_program=prog)
v = train_prog.global_block().var('@LR_DECAY_COUNTER@')
fluid.io.save_vars(exe, dirname=path, vars=[v])
def load_global_step(exe, prog, path):
v = prog.global_block().var('@LR_DECAY_COUNTER@')
fluid.io.load_vars(exe, path, prog, [v])
def main():
env = os.environ
FLAGS.dist = 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env
......@@ -176,9 +192,9 @@ def main():
cfg.pretrain_weights,
ignore_params=ignore_params)
# insert quantize op in train_prog, return type is CompiledProgram
train_prog = quant_aware(train_prog, place, config, for_test=False)
train_prog_quant = quant_aware(train_prog, place, config, for_test=False)
compiled_train_prog = train_prog.with_data_parallel(
compiled_train_prog = train_prog_quant.with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
......@@ -192,6 +208,7 @@ def main():
start_iter = 0
if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, eval_prog, FLAGS.resume_checkpoint)
load_global_step(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
train_reader = create_reader(cfg.TrainReader,
......@@ -237,7 +254,8 @@ def main():
if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
and (not FLAGS.dist or trainer_id == 0):
save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
checkpoint.save(exe, eval_prog, os.path.join(save_dir, save_name))
save_checkpoint(exe, eval_prog,
os.path.join(save_dir, save_name), train_prog)
if FLAGS.eval:
# evaluation
......@@ -254,8 +272,9 @@ def main():
if box_ap_stats[0] > best_box_ap_list[0]:
best_box_ap_list[0] = box_ap_stats[0]
best_box_ap_list[1] = it
checkpoint.save(exe, eval_prog,
os.path.join(save_dir, "best_model"))
save_checkpoint(exe, eval_prog,
os.path.join(save_dir, "best_model"),
train_prog)
logger.info("Best test box ap: {}, in iter: {}".format(
best_box_ap_list[0], best_box_ap_list[1]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册