train.py 6.1 KB
Newer Older
X
Xinghai Sun 已提交
1 2 3 4
"""
   Trainer for a simplified version of the Baidu DeepSpeech2 model.
"""

5
import paddle.v2 as paddle
6
import distutils.util
7
import argparse
8
import gzip
9
import time
X
Xinghai Sun 已提交
10 11
import sys
from model import deep_speech2
12 13
from audio_data_utils import DataGenerator
import numpy as np
14
import os
X
Xinghai Sun 已提交
15 16

# TODO: add a word error rate (WER) metric.
17 18

parser = argparse.ArgumentParser(
X
Xinghai Sun 已提交
19
    description='Simplified version of DeepSpeech2 trainer.')
20
parser.add_argument(
21
    "--batch_size", default=32, type=int, help="Minibatch size.")
22
parser.add_argument(
23 24 25 26
    "--num_passes",
    default=20,
    type=int,
    help="Training pass number. (default: %(default)s)")
27
parser.add_argument(
28 29 30 31
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
32
parser.add_argument(
33 34 35 36
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
37
parser.add_argument(
38 39 40 41
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
42
parser.add_argument(
43 44 45 46
    "--adam_learning_rate",
    default=5e-4,
    type=float,
    help="Learning rate for ADAM Optimizer. (default: %(default)s)")
47
parser.add_argument(
48 49 50 51
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
52
parser.add_argument(
53 54 55 56 57 58 59 60 61 62 63
    "--use_sortagrad",
    default=False,
    type=distutils.util.strtobool,
    help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
    default=4,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--normalizer_manifest_path",
64
    default='data/manifest.libri.train-clean-100',
65 66 67 68
    type=str,
    help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
    "--train_manifest_path",
69
    default='data/manifest.libri.train-clean-100',
70 71 72 73
    type=str,
    help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
    "--dev_manifest_path",
74
    default='data/manifest.libri.dev-clean',
75 76
    type=str,
    help="Manifest path for validation. (default: %(default)s)")
77 78 79 80 81
parser.add_argument(
    "--vocab_filepath",
    default='data/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
82 83
parser.add_argument(
    "--init_model_path",
Y
yangyaming 已提交
84
    default=None,
85
    type=str,
Y
yangyaming 已提交
86 87 88
    help="If set None, the training will start from scratch. "
    "Otherwise, the training will resume from "
    "the existing model of this path. (default: %(default)s)")
89 90 91 92
args = parser.parse_args()


def train():
    """
    DeepSpeech2 training.

    Builds the DeepSpeech2 cost network and creates its parameters, either
    from scratch or from the gzipped tar at ``args.init_model_path``. It then
    trains with Adam on ``args.train_manifest_path``, reporting the
    validation cost on ``args.dev_manifest_path`` after every pass.

    Parameters are checkpointed to ``params.tar.gz`` every 50 batches. They
    are saved again at the end of every pass, so the final updates of a pass
    are not lost. All settings come from the module-level ``args`` namespace.

    :raises IOError: If ``args.init_model_path`` is set but is not a file.
    """

    # initialize data generator
    def data_generator():
        return DataGenerator(
            vocab_filepath=args.vocab_filepath,
            normalizer_manifest_path=args.normalizer_manifest_path,
            normalizer_num_samples=200,
            max_duration=20.0,
            min_duration=0.0,
            stride_ms=10,
            window_ms=20)

    train_generator = data_generator()
    test_generator = data_generator()

    # create network config
    dict_size = train_generator.vocabulary_size()
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only a placeholder value; the real shape
    # of the input batch data is set at each batch.
    audio_data = paddle.layer.data(
        name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
    cost = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=dict_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        is_inference=False)

    # create/load parameters and optimizer
    if args.init_model_path is None:
        parameters = paddle.parameters.create(cost)
    else:
        if not os.path.isfile(args.init_model_path):
            raise IOError("Invalid model!")
        # Close the model file after loading instead of leaking the handle.
        # NOTE(review): assumes Parameters.from_tar reads everything eagerly
        # (it returns a populated Parameters object) — confirm.
        with gzip.open(args.init_model_path) as model_file:
            parameters = paddle.parameters.Parameters.from_tar(model_file)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # prepare data reader
    train_batch_reader = train_generator.batch_reader_creator(
        manifest_path=args.train_manifest_path,
        batch_size=args.batch_size,
        # Honor the --use_sortagrad flag rather than forcing sortagrad on.
        sortagrad=args.use_sortagrad,
        shuffle=True)
    test_batch_reader = test_generator.batch_reader_creator(
        manifest_path=args.dev_manifest_path,
        batch_size=args.batch_size,
        shuffle=False)
    feeding = train_generator.data_name_feeding()

    def save_parameters():
        """Write the current parameters to params.tar.gz (overwriting it)."""
        with gzip.open("params.tar.gz", 'w') as f:
            parameters.to_tar(f)

    # Running statistics for the current pass. They live in a dict in this
    # enclosing scope rather than in module globals, because Python 2 has no
    # ``nonlocal``. They are initialized here so the handler never touches
    # undefined names.
    stats = {'start_time': time.time(), 'cost_sum': 0.0, 'cost_counter': 0}

    # create event handler
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            stats['cost_sum'] += event.cost
            stats['cost_counter'] += 1
            if event.batch_id % 50 == 0:
                # Report the mean cost since the last report, then checkpoint.
                print("\nPass: %d, Batch: %d, TrainCost: %f" % (
                    event.pass_id, event.batch_id,
                    stats['cost_sum'] / stats['cost_counter']))
                stats['cost_sum'], stats['cost_counter'] = 0.0, 0
                save_parameters()
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.BeginPass):
            stats['start_time'] = time.time()
            stats['cost_sum'], stats['cost_counter'] = 0.0, 0
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_batch_reader, feeding=feeding)
            print("\n------- Time: %d sec,  Pass: %d, ValidationCost: %s" % (
                time.time() - stats['start_time'], event.pass_id, result.cost))
            # Persist end-of-pass parameters. The periodic checkpoint above
            # does not cover batches after the last multiple of 50.
            save_parameters()

    # run train
    trainer.train(
        reader=train_batch_reader,
        event_handler=event_handler,
        num_passes=args.num_passes,
        feeding=feeding)


def main():
    """Configure the PaddlePaddle runtime from the parsed flags, then train."""
    paddle.init(trainer_count=args.trainer_count, use_gpu=args.use_gpu)
    train()


if __name__ == '__main__':
    # Start training only when executed as a script, not when imported.
    main()