diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py
index e6a7d076bbfe36e837767657e87a75f73cd2b6d2..89ab23c685888d86956a1c2b798447a9e79c8a47 100644
--- a/deep_speech_2/train.py
+++ b/deep_speech_2/train.py
@@ -11,6 +11,7 @@ import sys
 from model import deep_speech2
 from audio_data_utils import DataGenerator
 import numpy as np
+import os
 
 #TODO: add WER metric
 
@@ -78,6 +79,13 @@ parser.add_argument(
     default='data/eng_vocab.txt',
     type=str,
     help="Vocabulary filepath. (default: %(default)s)")
+parser.add_argument(
+    "--init_model_path",
+    default=None,
+    type=str,
+    help="If set None, the training will start from scratch. "
+    "Otherwise, the training will resume from "
+    "the existing model of this path. (default: %(default)s)")
 args = parser.parse_args()
 
 
@@ -114,8 +122,16 @@ def train():
         rnn_size=args.rnn_layer_size,
         is_inference=False)
 
-    # create parameters and optimizer
-    parameters = paddle.parameters.create(cost)
+    # Create fresh parameters, or load them from the tar archive given by
+    # --init_model_path to resume training; then build the optimizer.
+    if args.init_model_path is None:
+        parameters = paddle.parameters.create(cost)
+    else:
+        if not os.path.isfile(args.init_model_path):
+            raise IOError("Invalid model!")
+        # Context manager closes the gzip handle once from_tar returns.
+        with gzip.open(args.init_model_path) as model_file:
+            parameters = paddle.parameters.Parameters.from_tar(model_file)
     optimizer = paddle.optimizer.Adam(
         learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
     trainer = paddle.trainer.SGD(