Unverified commit 6625f543, authored by Yibing Liu, committed by GitHub

Merge pull request #663 from kuke/model_init

Enable checkpoints saving and training resuming
@@ -42,10 +42,11 @@ def parse_args():
         default='data/infer_label.lst',
         help='The label list path for inference. (default: %(default)s)')
     parser.add_argument(
-        '--model_save_path',
+        '--infer_model_path',
         type=str,
-        default='./checkpoints/deep_asr.pass_0.model/',
-        help='The directory for saving model. (default: %(default)s)')
+        default='./infer_models/deep_asr.pass_0.infer.model/',
+        help='The directory for loading inference model. '
+        '(default: %(default)s)')
     args = parser.parse_args()
     return args
@@ -68,15 +69,15 @@ def infer(args):
     """ Gets one batch of feature data and predicts labels for each sample.
     """
-    if not os.path.exists(args.model_save_path):
-        raise IOError("Invalid model path!")
+    if not os.path.exists(args.infer_model_path):
+        raise IOError("Invalid inference model path!")
 
     place = fluid.CUDAPlace(0) if args.device == 'GPU' else fluid.CPUPlace()
     exe = fluid.Executor(place)
 
     # load model
     [infer_program, feed_dict,
-     fetch_targets] = fluid.io.load_inference_model(args.model_save_path, exe)
+     fetch_targets] = fluid.io.load_inference_model(args.infer_model_path, exe)
 
     ltrans = [
         trans_add_delta.TransAddDelta(2, 2),
...
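For context on the renamed flag: the directory it points at is one produced by fluid.io.save_inference_model during training (see the train() changes below). A minimal sketch of how such a directory is consumed, assuming a model has already been exported; the path and the feature batch are placeholders, and the real script feeds LoDTensor batches built by the transforms listed above rather than a dense random array.

import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Placeholder: any directory previously written by fluid.io.save_inference_model.
model_dir = './infer_models/deep_asr.pass_0.infer.model/'

# Returns the pruned inference program, the names of the variables to feed,
# and the variables to fetch.
[infer_program, feed_names,
 fetch_targets] = fluid.io.load_inference_model(model_dir, exe)

# Placeholder batch for illustration only; shape and dtype must match the
# exported "feature" input, and the real script feeds LoDTensor batches of
# variable-length utterances.
feature = np.random.rand(8, 120).astype('float32')

results = exe.run(infer_program,
                  feed={feed_names[0]: feature},
                  fetch_list=fetch_targets)
print(results[0])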
@@ -125,8 +125,9 @@ def profile(args):
         class_num=1749,
         parallel=args.parallel)
 
-    adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
-    adam_optimizer.minimize(avg_cost)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=args.learning_rate, momentum=0.9)
+    optimizer.minimize(avg_cost)
 
     place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
     exe = fluid.Executor(place)
...
@@ -34,17 +34,17 @@ def parse_args():
         '--stacked_num',
         type=int,
         default=5,
-        help='Number of lstm layers to stack. (default: %(default)d)')
+        help='Number of lstmp layers to stack. (default: %(default)d)')
     parser.add_argument(
         '--proj_dim',
         type=int,
         default=512,
-        help='Project size of lstm unit. (default: %(default)d)')
+        help='Project size of lstmp unit. (default: %(default)d)')
     parser.add_argument(
         '--hidden_dim',
         type=int,
         default=1024,
-        help='Hidden size of lstm unit. (default: %(default)d)')
+        help='Hidden size of lstmp unit. (default: %(default)d)')
     parser.add_argument(
         '--pass_num',
         type=int,
@@ -95,11 +95,23 @@ def parse_args():
         default='data/val_label.lst',
         help='The label list path for validation. (default: %(default)s)')
     parser.add_argument(
-        '--model_save_dir',
+        '--init_model_path',
+        type=str,
+        default=None,
+        help="The model (checkpoint) path which the training resumes from. "
+        "If None, train the model from scratch. (default: %(default)s)")
+    parser.add_argument(
+        '--checkpoints',
         type=str,
         default='./checkpoints',
-        help="The directory for saving model. Do not save model if set to "
-        "''. (default: %(default)s)")
+        help="The directory for saving checkpoints. Do not save checkpoints "
+        "if set to ''. (default: %(default)s)")
+    parser.add_argument(
+        '--infer_models',
+        type=str,
+        default='./infer_models',
+        help="The directory for saving inference models. Do not save inference "
+        "models if set to ''. (default: %(default)s)")
     args = parser.parse_args()
     return args
@@ -115,6 +127,15 @@ def train(args):
     """train in loop.
     """
+    # paths check
+    if args.init_model_path is not None and \
+            not os.path.exists(args.init_model_path):
+        raise IOError("Invalid initial model path!")
+    if args.checkpoints != '' and not os.path.exists(args.checkpoints):
+        os.mkdir(args.checkpoints)
+    if args.infer_models != '' and not os.path.exists(args.infer_models):
+        os.mkdir(args.infer_models)
+
     prediction, avg_cost, accuracy = stacked_lstmp_model(
         hidden_dim=args.hidden_dim,
         proj_dim=args.proj_dim,
@@ -122,8 +143,9 @@ def train(args):
         class_num=1749,
         parallel=args.parallel)
 
-    adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
-    adam_optimizer.minimize(avg_cost)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=args.learning_rate, momentum=0.9)
+    optimizer.minimize(avg_cost)
 
     # program for test
     test_program = fluid.default_main_program().clone()
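The optimizer swap (Adam replaced by plain SGD with momentum) is identical in profile() and train(). A self-contained toy sketch of the pattern; the one-layer network, sizes, and learning rate below are made up for illustration and stand in for stacked_lstmp_model and the script's real defaults.

import numpy as np
import paddle.fluid as fluid

# Toy network standing in for stacked_lstmp_model: one fc layer plus a
# cross-entropy loss. Sizes are arbitrary.
feature = fluid.layers.data(name='feature', shape=[120], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=feature, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)

# The change in this PR: momentum SGD instead of Adam.
optimizer = fluid.optimizer.Momentum(learning_rate=0.002, momentum=0.9)
optimizer.minimize(avg_cost)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# One dummy mini-batch just to exercise the graph.
feat = np.random.rand(8, 120).astype('float32')
lbl = np.random.randint(0, 10, (8, 1)).astype('int64')
out = exe.run(fluid.default_main_program(),
              feed={'feature': feat, 'label': lbl},
              fetch_list=[avg_cost])
print(out[0])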
@@ -134,6 +156,10 @@ def train(args):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
+    # resume training if initial model provided.
+    if args.init_model_path is not None:
+        fluid.io.load_persistables(exe, args.init_model_path)
+
     ltrans = [
         trans_add_delta.TransAddDelta(2, 2),
         trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
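Resuming works because fluid.io.save_persistables dumps every persistable variable of the program (parameters plus optimizer state) to a directory, and fluid.io.load_persistables writes them back over the freshly initialized values. A self-contained round trip on a toy program; the single-layer network and the checkpoint path are made up for illustration.

import os
import numpy as np
import paddle.fluid as fluid

# Tiny program standing in for the real model: one fc layer.
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.fc(input=x, size=2)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

ckpt_dir = './toy_checkpoint'  # made-up path
if not os.path.exists(ckpt_dir):
    os.mkdir(ckpt_dir)

# What train() does periodically: dump all persistable variables.
fluid.io.save_persistables(exe, ckpt_dir)

# What train() does at startup when --init_model_path is given:
# overwrite the freshly initialized variables with the saved ones.
fluid.io.load_persistables(exe, ckpt_dir)

out = exe.run(fluid.default_main_program(),
              feed={'x': np.ones((1, 4), dtype='float32')},
              fetch_list=[y])
print(out[0])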
@@ -200,15 +226,28 @@ def train(args):
                 print("\nBatch %d, train cost: %f, train acc: %f" %
                       (batch_id, lodtensor_to_ndarray(cost)[0],
                        lodtensor_to_ndarray(acc)[0]))
+                # save the latest checkpoint
+                if args.checkpoints != '':
+                    model_path = os.path.join(args.checkpoints,
+                                              "deep_asr.latest.checkpoint")
+                    fluid.io.save_persistables(exe, model_path)
             else:
                 sys.stdout.write('.')
                 sys.stdout.flush()
         # run test
         val_cost, val_acc = test(exe)
-        # save model
-        if args.model_save_dir != '':
+
+        # save checkpoint per pass
+        if args.checkpoints != '':
             model_path = os.path.join(
-                args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model")
+                args.checkpoints,
+                "deep_asr.pass_" + str(pass_id) + ".checkpoint")
+            fluid.io.save_persistables(exe, model_path)
+        # save inference model
+        if args.infer_models != '':
+            model_path = os.path.join(
+                args.infer_models,
+                "deep_asr.pass_" + str(pass_id) + ".infer.model")
             fluid.io.save_inference_model(model_path, ["feature"],
                                           [prediction], exe)
         # cal pass time
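The per-pass block above writes two different artifacts: a checkpoint (all persistable variables, enough to resume training via --init_model_path) and an inference model (the program pruned to the "feature" -> prediction path plus the parameters it needs, consumed by the inference script). A compact sketch of the two calls on the same kind of toy network as in the earlier sketch; the directory names mirror the defaults but are only illustrative.

import os
import paddle.fluid as fluid

# Toy network again; stands in for stacked_lstmp_model.
feature = fluid.layers.data(name='feature', shape=[120], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=feature, size=10, act='softmax')
avg_cost = fluid.layers.mean(
    x=fluid.layers.cross_entropy(input=prediction, label=label))
fluid.optimizer.Momentum(learning_rate=0.002, momentum=0.9).minimize(avg_cost)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

for d in ('./checkpoints', './infer_models'):
    if not os.path.exists(d):
        os.mkdir(d)

pass_id = 0
# Checkpoint: every persistable variable (parameters plus optimizer state),
# so a later run can pick up with fluid.io.load_persistables.
ckpt_path = os.path.join('./checkpoints',
                         "deep_asr.pass_" + str(pass_id) + ".checkpoint")
fluid.io.save_persistables(exe, ckpt_path)

# Inference model: prunes the program to feed "feature" and fetch the
# prediction, and saves only the parameters that path needs.
infer_path = os.path.join('./infer_models',
                          "deep_asr.pass_" + str(pass_id) + ".infer.model")
fluid.io.save_inference_model(infer_path, ["feature"], [prediction], exe)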
@@ -223,7 +262,4 @@ if __name__ == '__main__':
     args = parse_args()
     print_arguments(args)
-    if args.model_save_dir != '' and not os.path.exists(args.model_save_dir):
-        os.mkdir(args.model_save_dir)
     train(args)