From 8c937cbd3be54bbb28c33f1600f3b7413b091f59 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Fri, 9 Feb 2018 19:54:47 +0800 Subject: [PATCH] Enable checkpoint saving and training resuming --- fluid/DeepASR/train.py | 61 +++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py index 2fdc6c0b..8612b355 100644 --- a/fluid/DeepASR/train.py +++ b/fluid/DeepASR/train.py @@ -34,17 +34,17 @@ def parse_args(): '--stacked_num', type=int, default=5, - help='Number of lstm layers to stack. (default: %(default)d)') + help='Number of lstmp layers to stack. (default: %(default)d)') parser.add_argument( '--proj_dim', type=int, default=512, - help='Project size of lstm unit. (default: %(default)d)') + help='Project size of lstmp unit. (default: %(default)d)') parser.add_argument( '--hidden_dim', type=int, default=1024, - help='Hidden size of lstm unit. (default: %(default)d)') + help='Hidden size of lstmp unit. (default: %(default)d)') parser.add_argument( '--pass_num', type=int, @@ -95,11 +95,23 @@ def parse_args(): default='data/val_label.lst', help='The label list path for validation. (default: %(default)s)') parser.add_argument( - '--model_save_dir', + '--init_model_path', + type=str, + default=None, + help="The model path which the training resumes from. If None, train " + "the model from scratch. (default: %(default)s)") + parser.add_argument( + '--checkpoints', type=str, default='./checkpoints', - help="The directory for saving model. Do not save model if set to " - "''. (default: %(default)s)") + help="The directory for saving checkpoints. Do not save checkpoints " + "if set to ''. (default: %(default)s)") + parser.add_argument( + '--infer_models', + type=str, + default='./infer_models', + help="The directory for saving inference models. Do not save inference " + "models if set to ''. (default: %(default)s)") args = parser.parse_args() return args @@ -115,6 +127,15 @@ def train(args): """train in loop. """ + # paths check + if args.init_model_path is not None and \ + not os.path.exists(args.init_model_path): + raise IOError("Invalid initial model path!") + if args.checkpoints != '' and not os.path.exists(args.checkpoints): + os.mkdir(args.checkpoints) + if args.infer_models != '' and not os.path.exists(args.infer_models): + os.mkdir(args.infer_models) + prediction, avg_cost, accuracy = stacked_lstmp_model( hidden_dim=args.hidden_dim, proj_dim=args.proj_dim, @@ -134,6 +155,10 @@ def train(args): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) + # resume training if initial model provided. + if args.init_model_path is not None: + fluid.io.load_persistables(exe, args.init_model_path) + ltrans = [ trans_add_delta.TransAddDelta(2, 2), trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var), @@ -183,7 +208,7 @@ def train(args): for batch_id, batch_data in enumerate( train_data_reader.batch_iterator(args.batch_size, args.minimum_batch_size)): - # load_data + # lo ad_data (features, labels, lod) = batch_data feature_t.set(features, place) feature_t.set_lod([lod]) @@ -200,15 +225,28 @@ def train(args): print("\nBatch %d, train cost: %f, train acc: %f" % (batch_id, lodtensor_to_ndarray(cost)[0], lodtensor_to_ndarray(acc)[0])) + # save the latest checkpoints + if args.checkpoints != '': + model_path = os.path.join(args.checkpoints, + "deep_asr.latest.checkpoint") + fluid.io.save_persistables(exe, model_path) else: sys.stdout.write('.') sys.stdout.flush() # run test val_cost, val_acc = test(exe) - # save model - if args.model_save_dir != '': + + # save checkpoint per pass + if args.checkpoints != '': model_path = os.path.join( - args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model") + args.checkpoints, + "deep_asr.pass_" + str(pass_id) + ".checkpoint") + fluid.io.save_persistables(exe, model_path) + # save inference model + if args.infer_models != '': + model_path = os.path.join( + args.infer_models, + "deep_asr.pass_" + str(pass_id) + ".infer.model") fluid.io.save_inference_model(model_path, ["feature"], [prediction], exe) # cal pass time @@ -223,7 +261,4 @@ if __name__ == '__main__': args = parse_args() print_arguments(args) - if args.model_save_dir != '' and not os.path.exists(args.model_save_dir): - os.mkdir(args.model_save_dir) - train(args) -- GitLab