未验证 提交 65b6b3f4 编写于 作者: D dyning 提交者: GitHub

Merge pull request #85 from WuHaobo/save

only eval and save at trainer 0
...@@ -62,9 +62,7 @@ def main(args): ...@@ -62,9 +62,7 @@ def main(args):
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
# best_top1_acc_list[0]: top1 acc best_top1_acc_list = (0.0, -1) # (top1_acc, epoch_id)
# best_top1_acc_list[1]: epoch id
best_top1_acc_list = [0.0, 0]
train_dataloader, train_fetchs = program.build( train_dataloader, train_fetchs = program.build(
config, train_prog, startup_prog, is_train=True) config, train_prog, startup_prog, is_train=True)
...@@ -97,24 +95,25 @@ def main(args): ...@@ -97,24 +95,25 @@ def main(args):
# 1. train with train dataset # 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train') epoch_id, 'train')
# 2. validate with validate dataset if int(os.environ.get("PADDLE_TRAINERS_ID", 0)) == 0:
if config.validate and epoch_id % config.valid_interval == 0: # 2. validate with validate dataset
top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog, if config.validate and epoch_id % config.valid_interval == 0:
valid_fetchs, epoch_id, 'valid') top1_acc = program.run(valid_dataloader, exe,
if top1_acc > best_top1_acc_list[0]: compiled_valid_prog, valid_fetchs,
best_top1_acc_list[0] = top1_acc epoch_id, 'valid')
best_top1_acc_list[1] = epoch_id if top1_acc > best_top1_acc_list[0]:
logger.info("Best top1 acc: {}, in epoch: {}".format( best_top1_acc_list = (top1_acc, epoch_id)
best_top1_acc_list[0], best_top1_acc_list[1])) logger.info("Best top1 acc: {}, in epoch: {}".format(
*best_top1_acc_list))
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(train_prog, model_path, "best_model")
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir, model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"]) config.ARCHITECTURE["name"])
save_model(train_prog, model_path, "best_model") save_model(train_prog, model_path, epoch_id)
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(train_prog, model_path, epoch_id)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册