"""Trainer for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import argparse
import gzip
import time
import distutils.util
import multiprocessing
import paddle.v2 as paddle
from model import deep_speech2
from data_utils.data import DataGenerator
import utils

# Command-line configuration. ``args`` is parsed at import time and is read
# as a module-level global by train() and main().
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--batch_size",
    default=256,
    type=int,
    help="Minibatch size. (default: %(default)s)")
parser.add_argument(
    "--num_passes",
    default=200,
    type=int,
    help="Training pass number. (default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--adam_learning_rate",
    default=5e-4,
    type=float,
    help="Learning rate for ADAM Optimizer. (default: %(default)s)")
# strtobool accepts strings such as "true"/"false"/"1"/"0" on the command line.
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--use_sortagrad",
    default=True,
    type=distutils.util.strtobool,
    help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument(
    "--max_duration",
    default=27.0,
    type=float,
    help="Audios with duration larger than this will be discarded. "
    "(default: %(default)s)")
parser.add_argument(
    "--min_duration",
    default=0.0,
    type=float,
    help="Audios with duration smaller than this will be discarded. "
    "(default: %(default)s)")
# The help text must list the real option names; the default value itself
# ('batch_shuffle_clipped') is one of them.
parser.add_argument(
    "--shuffle_method",
    default='batch_shuffle_clipped',
    type=str,
    help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
    "'batch_shuffle_clipped'. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
    default=8,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--num_threads_data",
    default=multiprocessing.cpu_count(),
    type=int,
    help="Number of cpu threads for preprocessing data. "
    "(default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Filepath of normalizer's mean & std. (default: %(default)s)")
parser.add_argument(
    "--train_manifest_path",
    default='datasets/manifest.train',
    type=str,
    help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
    "--dev_manifest_path",
    default='datasets/manifest.dev',
    type=str,
    help="Manifest path for validation. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
    "--init_model_path",
    default=None,
    type=str,
    help="If set None, the training will start from scratch. "
    "Otherwise, the training will resume from "
    "the existing model of this path. (default: %(default)s)")
# A JSON list of augmentor specs, passed verbatim to DataGenerator.
parser.add_argument(
    "--augmentation_config",
    default='[{"type": "shift", '
    '"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
    '"prob": 1.0}]',
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
args = parser.parse_args()


def train():
    """DeepSpeech2 training.

    Reads its configuration from the module-level ``args``. It builds:

    * the training and validation data generators;
    * the DeepSpeech2 cost layer;
    * the parameters, created fresh or loaded from ``args.init_model_path``;
    * a ``paddle.trainer.SGD`` driven by Adam.

    It then trains for ``args.num_passes`` passes.

    Checkpoints are written to ``checkpoints/params.latest.tar.gz`` every
    100 batches. After each pass is evaluated on the dev manifest, they are
    also written to ``checkpoints/params.pass-<pass_id>.tar.gz``.

    Raises:
        IOError: if ``args.init_model_path`` is set but is not a file.
    """
    checkpoint_dir = "checkpoints"

    # initialize data generator
    def data_generator():
        return DataGenerator(
            vocab_filepath=args.vocab_filepath,
            mean_std_filepath=args.mean_std_filepath,
            augmentation_config=args.augmentation_config,
            max_duration=args.max_duration,
            min_duration=args.min_duration,
            specgram_type=args.specgram_type,
            num_threads=args.num_threads_data)

    train_generator = data_generator()
    test_generator = data_generator()

    # create network config
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only a placeholder value and the real shape
    # of input batch data will be induced during training.
    audio_data = paddle.layer.data(
        name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(
            train_generator.vocab_size))
    cost = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=train_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        is_inference=False)

    # create/load parameters and optimizer
    if args.init_model_path is None:
        parameters = paddle.parameters.create(cost)
    else:
        if not os.path.isfile(args.init_model_path):
            raise IOError("Invalid model!")
        # Close the gzip handle once the parameters have been read from it.
        with gzip.open(args.init_model_path) as model_file:
            parameters = paddle.parameters.Parameters.from_tar(model_file)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # prepare data reader
    train_batch_reader = train_generator.batch_reader_creator(
        manifest_path=args.train_manifest_path,
        batch_size=args.batch_size,
        min_batch_size=args.trainer_count,
        # SortaGrad is only applied when training from scratch.
        sortagrad=args.use_sortagrad if args.init_model_path is None else False,
        shuffle_method=args.shuffle_method)
    test_batch_reader = test_generator.batch_reader_creator(
        manifest_path=args.dev_manifest_path,
        batch_size=args.batch_size,
        # NOTE(review): the original note here read "must be 1, but will have
        # errors" -- the intent is unclear; confirm before changing.
        min_batch_size=1,
        sortagrad=False,
        shuffle_method=None)

    # Checkpoints are saved into this directory during training; create it up
    # front so the first save does not fail with IOError when it is missing.
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Per-run statistics shared with event_handler. A mutable dict is used
    # instead of ``global`` names (Python 2 has no ``nonlocal``), which keeps
    # the state local to this call of train().
    stats = {"start_time": time.time(), "cost_sum": 0.0, "cost_counter": 0}

    # create event handler
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            stats["cost_sum"] += event.cost
            stats["cost_counter"] += 1
            if (event.batch_id + 1) % 100 == 0:
                print("\nPass: %d, Batch: %d, TrainCost: %f" % (
                    event.pass_id, event.batch_id + 1,
                    stats["cost_sum"] / stats["cost_counter"]))
                stats["cost_sum"], stats["cost_counter"] = 0.0, 0
                with gzip.open("checkpoints/params.latest.tar.gz", 'wb') as f:
                    parameters.to_tar(f)
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.BeginPass):
            stats["start_time"] = time.time()
            stats["cost_sum"], stats["cost_counter"] = 0.0, 0
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(
                reader=test_batch_reader, feeding=test_generator.feeding)
            print("\n------- Time: %d sec,  Pass: %d, ValidationCost: %s" %
                  (time.time() - stats["start_time"], event.pass_id,
                   result.cost))
            with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
                           'wb') as f:
                parameters.to_tar(f)

    # run train
    trainer.train(
        reader=train_batch_reader,
        event_handler=event_handler,
        num_passes=args.num_passes,
        feeding=train_generator.feeding)


def main():
    """Print the parsed arguments, initialize PaddlePaddle, then train."""
    utils.print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
    train()


if __name__ == '__main__':
    main()