"""
   Trainer for a simplified version of the Baidu DeepSpeech2 model.
"""

import argparse
import distutils.util
import gzip
import sys
import time

import numpy as np
import paddle.v2 as paddle

from audio_data_utils import DataGenerator
from model import deep_speech2

# TODO: add WER metric

parser = argparse.ArgumentParser(
    description='Simplified version of DeepSpeech2 trainer.')
# NOTE: this is the total minibatch size; train() divides it by
# --trainer_count to get the per-trainer batch size.
parser.add_argument(
    "--batch_size",
    default=32,
    type=int,
    help="Minibatch size. (default: %(default)s)")
parser.add_argument(
    "--num_passes",
    default=20,
    type=int,
    help="Training pass number. (default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--adam_learning_rate",
    default=5e-4,
    type=float,
    help="Learning rate for ADAM Optimizer. (default: %(default)s)")
# Boolean flags are parsed from strings such as "true"/"false" or "1"/"0" by
# distutils.util.strtobool. argparse only applies ``type`` to string
# defaults, so the bool defaults below are used unchanged.
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--use_sortagrad",
    default=False,
    type=distutils.util.strtobool,
    help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
    default=4,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--normalizer_manifest_path",
    default='./manifest.libri.train-clean-100',
    type=str,
    help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
    "--train_manifest_path",
    default='./manifest.libri.train-clean-100',
    type=str,
    help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
    "--dev_manifest_path",
    default='./manifest.libri.dev-clean',
    type=str,
    help="Manifest path for validation. (default: %(default)s)")
args = parser.parse_args()


def train():
    """
    DeepSpeech2 training.

    Builds the data generator, the network, its parameters and an Adam
    optimizer from the module-level command-line ``args``, then trains for
    ``args.num_passes`` passes. When ``args.use_sortagrad`` is set, the first
    pass reads batches sorted by duration without shuffling; all other passes
    read shuffled batches.

    Whenever the batch index is a multiple of 50, the running train cost is
    printed and the parameters are written to ``params.tar.gz`` (overwriting
    the previous file). The dev manifest is evaluated at the end of every
    pass.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than
            ``args.trainer_count``, which would give each trainer an empty
            batch.
    """
    # Each utterance is flattened to a fixed-size FEATURE_HEIGHT x MAX_FRAMES
    # input; the data layer and the batch readers must agree on these sizes.
    # NOTE(review): 161 presumably is the number of spectrogram frequency bins
    # for the 20 ms window configured below -- confirm against DataGenerator.
    feature_height = 161
    max_frames = 2000  # readers pad every utterance to this width

    # The total batch is split evenly across trainers.
    per_trainer_batch_size = args.batch_size // args.trainer_count
    if per_trainer_batch_size < 1:
        raise ValueError(
            "batch_size (%d) must be at least trainer_count (%d)" %
            (args.batch_size, args.trainer_count))

    # initialize data generator
    data_generator = DataGenerator(
        vocab_filepath='eng_vocab.txt',
        normalizer_manifest_path=args.normalizer_manifest_path,
        normalizer_num_samples=200,
        max_duration=20.0,
        min_duration=0.0,
        stride_ms=10,
        window_ms=20)

    # create network config
    dict_size = data_generator.vocabulary_size()
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        height=feature_height,
        width=max_frames,
        type=paddle.data_type.dense_vector(feature_height * max_frames))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
    cost, _ = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=dict_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size)

    # create parameters and optimizer
    parameters = paddle.parameters.create(cost)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # prepare data readers
    def create_reader(manifest_path, sort_by_duration, shuffle):
        """Return a padded, flattened batch reader over ``manifest_path``."""
        return data_generator.batch_reader_creator(
            manifest_path=manifest_path,
            batch_size=per_trainer_batch_size,
            padding_to=max_frames,
            flatten=True,
            sort_by_duration=sort_by_duration,
            shuffle=shuffle)

    train_batch_reader_sortagrad = create_reader(
        args.train_manifest_path, sort_by_duration=True, shuffle=False)
    train_batch_reader_nosortagrad = create_reader(
        args.train_manifest_path, sort_by_duration=False, shuffle=True)
    test_batch_reader = create_reader(
        args.dev_manifest_path, sort_by_duration=False, shuffle=False)
    feeding = data_generator.data_name_feeding()

    # Per-pass bookkeeping shared with event_handler. A dict replaces the
    # former module-level ``global`` variables (Python 2 has no
    # ``nonlocal``), so no state leaks into the module namespace.
    stats = {'start_time': time.time(), 'cost_sum': 0.0, 'cost_counter': 0}

    # create event handler
    def event_handler(event):
        """Log train/validation cost and checkpoint parameters."""
        if isinstance(event, paddle.event.EndIteration):
            stats['cost_sum'] += event.cost
            stats['cost_counter'] += 1
            if event.batch_id % 50 == 0:
                print("\nPass: %d, Batch: %d, TrainCost: %f" %
                      (event.pass_id, event.batch_id,
                       stats['cost_sum'] / stats['cost_counter']))
                stats['cost_sum'], stats['cost_counter'] = 0.0, 0
                # Overwrite the checkpoint with the latest parameters.
                with gzip.open("params.tar.gz", 'w') as f:
                    parameters.to_tar(f)
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.BeginPass):
            stats['start_time'] = time.time()
            stats['cost_sum'], stats['cost_counter'] = 0.0, 0
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_batch_reader, feeding=feeding)
            print("\n------- Time: %d sec,  Pass: %d, ValidationCost: %s" %
                  (time.time() - stats['start_time'], event.pass_id,
                   result.cost))

    # run train
    # Use a local counter instead of mutating the global ``args``.
    remaining_passes = args.num_passes
    if args.use_sortagrad and remaining_passes > 0:
        # first pass with sortagrad (duration-sorted, unshuffled batches)
        trainer.train(
            reader=train_batch_reader_sortagrad,
            event_handler=event_handler,
            num_passes=1,
            feeding=feeding)
        remaining_passes -= 1
    # other passes without sortagrad
    # NOTE(review): pass ids presumably restart at 0 in this second
    # trainer.train() call, so logged pass numbers are offset by one when
    # sortagrad is used -- confirm.
    if remaining_passes > 0:
        trainer.train(
            reader=train_batch_reader_nosortagrad,
            event_handler=event_handler,
            num_passes=remaining_passes,
            feeding=feeding)


def main():
    """Initialize PaddlePaddle from the command-line flags and run training."""
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
    train()


if __name__ == '__main__':
    main()