提交 992e3be8 编写于 作者: X Xinghai Sun 提交者: GitHub

Merge pull request #72 from pkuyym/fix-2372

Add model resuming for DeepSpeech2 trainer.
......@@ -11,6 +11,7 @@ import sys
from model import deep_speech2
from audio_data_utils import DataGenerator
import numpy as np
import os
#TODO: add WER metric
......@@ -78,6 +79,13 @@ parser.add_argument(
default='data/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--init_model_path",
default=None,
type=str,
help="If set None, the training will start from scratch. "
"Otherwise, the training will resume from "
"the existing model of this path. (default: %(default)s)")
args = parser.parse_args()
......@@ -114,8 +122,14 @@ def train():
rnn_size=args.rnn_layer_size,
is_inference=False)
# create parameters and optimizer
# create/load parameters and optimizer
if args.init_model_path is None:
parameters = paddle.parameters.create(cost)
else:
if not os.path.isfile(args.init_model_path):
raise IOError("Invalid model!")
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.init_model_path))
optimizer = paddle.optimizer.Adam(
learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
trainer = paddle.trainer.SGD(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册