From 399afde437a15d862195274814521d4e46ad416c Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 6 Feb 2018 11:55:56 +0000 Subject: [PATCH] Enable model saving during training --- fluid/DeepASR/train.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py index 0472714a..91fd5161 100644 --- a/fluid/DeepASR/train.py +++ b/fluid/DeepASR/train.py @@ -95,6 +95,11 @@ def parse_args(): type=str, default='data/val_label.lst', help='label list path for validation.') + parser.add_argument( + '--model_save_dir', + type=str, + default='./checkpoints', + help='directory to save model.') args = parser.parse_args() return args @@ -109,8 +114,8 @@ def print_arguments(args): def train(args): """train in loop.""" - _, avg_cost, accuracy = stacked_lstmp_model(args.hidden_dim, args.proj_dim, - args.stacked_num, args.parallel) + prediction, avg_cost, accuracy = stacked_lstmp_model( + args.hidden_dim, args.proj_dim, args.stacked_num, args.parallel) adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) adam_optimizer.minimize(avg_cost) @@ -192,10 +197,16 @@ def train(args): else: sys.stdout.write('.') sys.stdout.flush() - + # run test val_cost, val_acc = test(exe) pass_end_time = time.time() time_consumed = pass_end_time - pass_start_time + # save model + if args.model_save_dir is not None: + model_path = os.path.join( + args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model") + fluid.io.save_inference_model(model_path, ["feature"], + [prediction], exe) print("\nPass %d, time consumed: %f s, val cost: %f, val acc: %f\n" % (pass_id, time_consumed, val_cost, val_acc)) @@ -205,4 +216,8 @@ if __name__ == '__main__': args = parse_args() print_arguments(args) + if args.model_save_dir is not None and \ + not os.path.exists(args.model_save_dir): + os.mkdir(args.model_save_dir) + train(args) -- GitLab