提交 399afde4 编写于 作者: Y Yibing Liu

Enable model saving during training

上级 c309bd3e
...@@ -95,6 +95,11 @@ def parse_args(): ...@@ -95,6 +95,11 @@ def parse_args():
type=str, type=str,
default='data/val_label.lst', default='data/val_label.lst',
help='label list path for validation.') help='label list path for validation.')
parser.add_argument(
'--model_save_dir',
type=str,
default='./checkpoints',
help='directory to save model.')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -109,8 +114,8 @@ def print_arguments(args): ...@@ -109,8 +114,8 @@ def print_arguments(args):
def train(args): def train(args):
"""train in loop.""" """train in loop."""
_, avg_cost, accuracy = stacked_lstmp_model(args.hidden_dim, args.proj_dim, prediction, avg_cost, accuracy = stacked_lstmp_model(
args.stacked_num, args.parallel) args.hidden_dim, args.proj_dim, args.stacked_num, args.parallel)
adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
adam_optimizer.minimize(avg_cost) adam_optimizer.minimize(avg_cost)
...@@ -192,10 +197,16 @@ def train(args): ...@@ -192,10 +197,16 @@ def train(args):
else: else:
sys.stdout.write('.') sys.stdout.write('.')
sys.stdout.flush() sys.stdout.flush()
# run test
val_cost, val_acc = test(exe) val_cost, val_acc = test(exe)
pass_end_time = time.time() pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time time_consumed = pass_end_time - pass_start_time
# save model
if args.model_save_dir is not None:
model_path = os.path.join(
args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model")
fluid.io.save_inference_model(model_path, ["feature"],
[prediction], exe)
print("\nPass %d, time consumed: %f s, val cost: %f, val acc: %f\n" % print("\nPass %d, time consumed: %f s, val cost: %f, val acc: %f\n" %
(pass_id, time_consumed, val_cost, val_acc)) (pass_id, time_consumed, val_cost, val_acc))
...@@ -205,4 +216,8 @@ if __name__ == '__main__': ...@@ -205,4 +216,8 @@ if __name__ == '__main__':
args = parse_args() args = parse_args()
print_arguments(args) print_arguments(args)
if args.model_save_dir is not None and \
not os.path.exists(args.model_save_dir):
os.mkdir(args.model_save_dir)
train(args) train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册