"""
   Trainer for a simplified version of Baidu DeepSpeech2 model.
"""

import paddle.v2 as paddle
import argparse
import gzip
import time
import distutils.util
import sys
from model import deep_speech2
from audio_data_utils import DataGenerator
import numpy as np

#TODO: add WER metric

parser = argparse.ArgumentParser(
    description='Simplified version of DeepSpeech2 trainer.')
parser.add_argument(
    "--batch_size", default=32, type=int, help="Minibatch size.")
parser.add_argument(
    "--trainer",
    default=1,
    type=int,
    help="Number of trainers to split each minibatch across.")
parser.add_argument(
    "--num_passes", default=20, type=int, help="Training pass number.")
parser.add_argument(
    "--num_conv_layers", default=3, type=int, help="Convolution layer number.")
parser.add_argument(
    "--num_rnn_layers", default=5, type=int, help="RNN layer number.")
parser.add_argument(
    "--rnn_layer_size", default=512, type=int, help="RNN layer cell number.")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not.")
parser.add_argument(
    "--use_sortagrad",
    default=False,
    type=distutils.util.strtobool,
    help="Use sortagrad or not.")
parser.add_argument(
    "--trainer_count", default=8, type=int, help="Trainer count.")
args = parser.parse_args()


def train():
    """
    DeepSpeech2 training.
    """
    # create data readers
    data_generator = DataGenerator(
        vocab_filepath='eng_vocab.txt',
        normalizer_manifest_path='./libri.manifest.train',
        normalizer_num_samples=200,
        max_duration=20.0,
        min_duration=0.0,
        stride_ms=10,
        window_ms=20)
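    # reader for the optional SortaGrad pass: batches are sorted by utterance
    # duration (shortest first) and not shuffled, so the first pass sees the
    # shortest utterances first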
    train_batch_reader_sortagrad = data_generator.batch_reader_creator(
        manifest_path='./libri.manifest.dev.small',
        batch_size=args.batch_size // args.trainer,
        padding_to=2000,
        flatten=True,
        sort_by_duration=True,
        shuffle=False)
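    # reader for the regular passes: shuffled, without duration sorting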
    train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
        manifest_path='./libri.manifest.dev.small',
        batch_size=args.batch_size // args.trainer,
        padding_to=2000,
        flatten=True,
        sort_by_duration=False,
        shuffle=True)
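    # reader used for evaluation at the end of every pass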
    test_batch_reader = data_generator.batch_reader_creator(
        manifest_path='./libri.manifest.test',
        batch_size=args.batch_size // args.trainer,
        padding_to=2000,
        flatten=True,
        sort_by_duration=False,
        shuffle=False)
    feeding = data_generator.data_name_feeding()

    # create network config
    dict_size = data_generator.vocabulary_size()
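    # each input sample is a 161 x 2000 spectrogram (height 161, padded to
    # 2000 time frames by the readers above), flattened to a 322000-dim vector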
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        height=161,
        width=2000,
        type=paddle.data_type.dense_vector(322000))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
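    # build the DeepSpeech2 network configuration; only the returned cost is
    # used for training (the second return value is unused here)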
    cost, _ = deep_speech2(
        audio_data=audio_data,
        text_data=text_data,
        dict_size=dict_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size)

    # create parameters and optimizer
    parameters = paddle.parameters.create(cost)
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-5, gradient_clipping_threshold=400)
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # create event handler
    def event_handler(event):
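        # start_time is set at the beginning of each pass and read at the end
        # of the pass to report the elapsed wall-clock time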
        global start_time
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 10 == 0:
                print "\nPass: %d, Batch: %d, TrainCost: %f" % (
                    event.pass_id, event.batch_id, event.cost)
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.BeginPass):
            start_time = time.time()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_batch_reader, feeding=feeding)
            print "\n------- Time: %d,  Pass: %d, TestCost: %s" % (
                time.time() - start_time, event.pass_id, result.cost)
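            # snapshot the latest parameters after every pass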
            with gzip.open("params.tar.gz", 'w') as f:
                parameters.to_tar(f)

    # run train
    # first pass with sortagrad
    if args.use_sortagrad:
        trainer.train(
            reader=train_batch_reader_sortagrad,
            event_handler=event_handler,
            num_passes=1,
            feeding=feeding)
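        # the SortaGrad pass counts towards the total number of passes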
        args.num_passes -= 1
    # other passes without sortagrad
    trainer.train(
        reader=train_batch_reader_nosortagrad,
        event_handler=event_handler,
        num_passes=args.num_passes,
        feeding=feeding)


def main():
    paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count)
    train()


if __name__ == '__main__':
    main()