train.py 6.2 KB
Newer Older
1
"""Trainer for DeepSpeech2 model."""
2 3 4 5
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

6
import argparse
7
import distutils.util
8
import multiprocessing
9
import paddle.v2 as paddle
10
from model import DeepSpeech2Model
11
from data_utils.data import DataGenerator
12
import utils
X
Xinghai Sun 已提交
13

14
def _read_text_file(path):
    """Return the whole content of the text file at ``path``.

    The file is opened in a ``with`` block so its handle is closed as soon as
    the content is read, instead of being left for the garbage collector.
    """
    with open(path, 'r') as config_file:
        return config_file.read()


# Command-line configuration for DeepSpeech2 training. The options are parsed
# once, at import time, into the module-level ``args`` that train() and main()
# read from. Every help string uses argparse's ``%(default)s`` interpolation,
# so ``--help`` shows the effective default value.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--batch_size",
    default=256,
    type=int,
    help="Minibatch size. (default: %(default)s)")
parser.add_argument(
    "--num_passes",
    default=200,
    type=int,
    help="Training pass number. (default: %(default)s)")
parser.add_argument(
    "--num_iterations_print",
    default=100,
    type=int,
    help="Number of iterations for every train cost printing. "
    "(default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=1024,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--use_gru",
    default=False,
    type=distutils.util.strtobool,
    help="Use GRU or simple RNN. (default: %(default)s)")
parser.add_argument(
    "--adam_learning_rate",
    default=5e-4,
    type=float,
    help="Learning rate for ADAM Optimizer. (default: %(default)s)")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--use_sortagrad",
    default=True,
    type=distutils.util.strtobool,
    help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument(
    "--max_duration",
    default=27.0,
    type=float,
    help="Audios with duration larger than this will be discarded. "
    "(default: %(default)s)")
parser.add_argument(
    "--min_duration",
    default=0.0,
    type=float,
    help="Audios with duration smaller than this will be discarded. "
    "(default: %(default)s)")
parser.add_argument(
    "--shuffle_method",
    default='batch_shuffle_clipped',
    type=str,
    help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
    "'batch_shuffle_clipped'. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
    default=8,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--num_threads_data",
    # cpu_count() // 2 is 0 on a single-core machine; zero data threads is
    # presumably invalid for DataGenerator, so always keep at least one.
    default=max(1, multiprocessing.cpu_count() // 2),
    type=int,
    help="Number of cpu threads for preprocessing data. "
    "(default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Filepath of normalizer's mean & std. (default: %(default)s)")
parser.add_argument(
    "--train_manifest_path",
    default='datasets/manifest.train',
    type=str,
    help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
    "--dev_manifest_path",
    default='datasets/manifest.dev',
    type=str,
    help="Manifest path for validation. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
    "--init_model_path",
    default=None,
    type=str,
    help="If set None, the training will start from scratch. "
    "Otherwise, the training will resume from "
    "the existing model of this path. (default: %(default)s)")
parser.add_argument(
    "--output_model_dir",
    default="./checkpoints",
    type=str,
    help="Directory for saving models. (default: %(default)s)")
parser.add_argument(
    "--augmentation_config",
    # The default is the literal JSON text of the config file. It is read
    # eagerly at import time, and the handle is closed immediately.
    default=_read_text_file('conf/augmentation.config'),
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
parser.add_argument(
    "--is_local",
    default=True,
    type=distutils.util.strtobool,
    help="Set to false if running with pserver in paddlecloud. "
    "(default: %(default)s)")
args = parser.parse_args()


def train():
    """Train a DeepSpeech2 model using the settings in the module-level ``args``.

    Builds a training and a validation ``DataGenerator``, wraps each in a
    batch reader, constructs ``DeepSpeech2Model`` (optionally from the
    pretrained weights at ``args.init_model_path``), and runs its ``train``
    method.
    """
    # Settings that the training and validation generators have in common.
    common_generator_opts = dict(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        specgram_type=args.specgram_type,
        num_threads=args.num_threads_data)

    train_generator = DataGenerator(
        augmentation_config=args.augmentation_config,
        max_duration=args.max_duration,
        min_duration=args.min_duration,
        **common_generator_opts)
    # Validation uses an empty augmentation config ("{}") and leaves the
    # duration bounds at DataGenerator's own defaults.
    dev_generator = DataGenerator(
        augmentation_config="{}", **common_generator_opts)

    # Sortagrad is requested only for training from scratch. When resuming
    # from an existing model it is switched off.
    if args.init_model_path is None:
        sortagrad_enabled = args.use_sortagrad
    else:
        sortagrad_enabled = False

    train_batch_reader = train_generator.batch_reader_creator(
        manifest_path=args.train_manifest_path,
        batch_size=args.batch_size,
        min_batch_size=args.trainer_count,
        sortagrad=sortagrad_enabled,
        shuffle_method=args.shuffle_method)
    dev_batch_reader = dev_generator.batch_reader_creator(
        manifest_path=args.dev_manifest_path,
        batch_size=args.batch_size,
        # NOTE(review): the original comment read "must be 1, but will have
        # errors." Its exact meaning is unclear; confirm before changing.
        min_batch_size=1,
        sortagrad=False,
        shuffle_method=None)

    ds2_model = DeepSpeech2Model(
        vocab_size=train_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.init_model_path)
    ds2_model.train(
        train_batch_reader=train_batch_reader,
        dev_batch_reader=dev_batch_reader,
        feeding_dict=train_generator.feeding,
        learning_rate=args.adam_learning_rate,
        # Hard-coded clipping value handed to the model's train(); it has
        # no command-line flag.
        gradient_clipping=400,
        num_passes=args.num_passes,
        num_iterations_print=args.num_iterations_print,
        output_model_dir=args.output_model_dir,
        is_local=args.is_local)


def main():
    """Script entry point: echo the arguments, initialize Paddle, then train."""
    utils.print_arguments(args)
    # PaddlePaddle must be initialized before train() builds the model.
    paddle_init_opts = {
        'use_gpu': args.use_gpu,
        'trainer_count': args.trainer_count,
    }
    paddle.init(**paddle_init_opts)
    train()


if __name__ == '__main__':
    main()