# train.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import numpy as np
import argparse
import time

import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model


def parse_args(argv=None):
    """Build the command-line parser for training and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the process command line, which is the
            behaviour of the original zero-argument call.

    Returns:
        The argparse namespace holding one attribute per option below.
    """
    parser = argparse.ArgumentParser("Training for stacked LSTMP model.")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='The sequence number of a batch data. (default: %(default)d)')
    parser.add_argument(
        '--minimum_batch_size',
        type=int,
        default=1,
        help='The minimum sequence number of a batch data. '
        '(default: %(default)d)')
    parser.add_argument(
        '--frame_dim',
        type=int,
        default=120 * 11,
        help='Frame dimension of feature data. (default: %(default)d)')
    parser.add_argument(
        '--stacked_num',
        type=int,
        default=5,
        help='Number of lstmp layers to stack. (default: %(default)d)')
    parser.add_argument(
        '--proj_dim',
        type=int,
        default=512,
        help='Project size of lstmp unit. (default: %(default)d)')
    parser.add_argument(
        '--hidden_dim',
        type=int,
        default=1024,
        help='Hidden size of lstmp unit. (default: %(default)d)')
    parser.add_argument(
        '--class_num',
        type=int,
        default=1749,
        help='Number of classes in label. (default: %(default)d)')
    parser.add_argument(
        '--pass_num',
        type=int,
        default=100,
        help='Epoch number to train. (default: %(default)d)')
    parser.add_argument(
        '--print_per_batches',
        type=int,
        default=100,
        help='Interval to print training accuracy. (default: %(default)d)')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.00016,
        help='Learning rate used to train. (default: %(default)f)')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help='The device type. (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', help='If set, run in parallel.')
    parser.add_argument(
        '--mean_var',
        type=str,
        default='data/global_mean_var_search26kHr',
        help="The path for feature's global mean and variance. "
        "(default: %(default)s)")
    parser.add_argument(
        '--train_feature_lst',
        type=str,
        default='data/feature.lst',
        help='The feature list path for training. (default: %(default)s)')
    parser.add_argument(
        '--train_label_lst',
        type=str,
        default='data/label.lst',
        help='The label list path for training. (default: %(default)s)')
    parser.add_argument(
        '--val_feature_lst',
        type=str,
        default='data/val_feature.lst',
        help='The feature list path for validation. (default: %(default)s)')
    parser.add_argument(
        '--val_label_lst',
        type=str,
        default='data/val_label.lst',
        help='The label list path for validation. (default: %(default)s)')
    parser.add_argument(
        '--init_model_path',
        type=str,
        default=None,
        help="The model (checkpoint) path which the training resumes from. "
        "If None, train the model from scratch. (default: %(default)s)")
    parser.add_argument(
        '--checkpoints',
        type=str,
        default='./checkpoints',
        help="The directory for saving checkpoints. Do not save checkpoints "
        "if set to ''. (default: %(default)s)")
    parser.add_argument(
        '--infer_models',
        type=str,
        default='./infer_models',
        help="The directory for saving inference models. Do not save inference "
        "models if set to ''. (default: %(default)s)")
    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args


def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    Args:
        args: An object whose ``vars()`` yields the configuration, e.g. the
            namespace returned by argparse.
    """
    print('-----------  Configuration Arguments -----------')
    # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def train(args):
    """Train the stacked LSTMP model, validating and saving once per pass.

    Args:
        args: Namespace of settings as returned by ``parse_args``.

    Raises:
        IOError: If ``args.init_model_path`` is set but does not exist.
    """

    # paths check
    if args.init_model_path is not None and \
            not os.path.exists(args.init_model_path):
        raise IOError("Invalid initial model path!")
    # An empty string disables the corresponding output. makedirs is used
    # instead of mkdir so that nested output paths can be created as well.
    for out_dir in (args.checkpoints, args.infer_models):
        if out_dir != '' and not os.path.exists(out_dir):
            os.makedirs(out_dir)

    prediction, avg_cost, accuracy = stacked_lstmp_model(
        frame_dim=args.frame_dim,
        hidden_dim=args.hidden_dim,
        proj_dim=args.proj_dim,
        stacked_num=args.stacked_num,
        class_num=args.class_num,
        parallel=args.parallel)

    # program for test; cloned before the optimizer is attached below.
    test_program = fluid.default_main_program().clone()

    optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
    optimizer.minimize(avg_cost)

    place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # resume training if initial model provided.
    if args.init_model_path is not None:
        fluid.io.load_persistables(exe, args.init_model_path)

    # Feature transformers shared by the training and validation readers.
    ltrans = [
        trans_add_delta.TransAddDelta(2, 2),
        trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
        trans_splice.TransSplice()
    ]

    # These two tensors are reused for every batch, in training and validation.
    feature_t = fluid.LoDTensor()
    label_t = fluid.LoDTensor()

    def make_feed(features, labels, lod):
        """Load one batch into the shared tensors and return the feed dict."""
        feature_t.set(features, place)
        feature_t.set_lod([lod])
        label_t.set(labels, place)
        label_t.set_lod([lod])
        return {"feature": feature_t, "label": label_t}

    # validation
    def test(exe):
        """Return (mean cost, mean accuracy) on the validation set.

        Returns (-1.0, -1.0) when the validation lists are missing or the
        reader produces no batch.
        """
        # If test data not found, return invalid cost and accuracy
        if not (os.path.exists(args.val_feature_lst) and
                os.path.exists(args.val_label_lst)):
            return -1.0, -1.0
        # test data reader
        test_data_reader = reader.AsyncDataReader(args.val_feature_lst,
                                                  args.val_label_lst)
        test_data_reader.set_transformers(ltrans)
        test_costs, test_accs = [], []
        for batch_data in test_data_reader.batch_iterator(
                args.batch_size, args.minimum_batch_size):
            # load_data
            (features, labels, lod, _) = batch_data
            cost, acc = exe.run(test_program,
                                feed=make_feed(features, labels, lod),
                                fetch_list=[avg_cost, accuracy],
                                return_numpy=False)
            test_costs.append(lodtensor_to_ndarray(cost)[0])
            test_accs.append(lodtensor_to_ndarray(acc)[0])
        # np.mean of an empty list is NaN (with a warning); report the same
        # "invalid" sentinel used above instead.
        if not test_costs:
            return -1.0, -1.0
        return np.mean(test_costs), np.mean(test_accs)

    # train data reader
    # NOTE(review): the meaning of the trailing -1 argument is defined by
    # AsyncDataReader and is not visible from this file -- confirm.
    train_data_reader = reader.AsyncDataReader(args.train_feature_lst,
                                               args.train_label_lst, -1)

    train_data_reader.set_transformers(ltrans)
    # train
    # range (not xrange) so the loop runs on both Python 2 and Python 3.
    for pass_id in range(args.pass_num):
        pass_start_time = time.time()
        for batch_id, batch_data in enumerate(
                train_data_reader.batch_iterator(args.batch_size,
                                                 args.minimum_batch_size)):
            # load_data
            (features, labels, lod, _) = batch_data
            feed = make_feed(features, labels, lod)

            # Fetch cost/accuracy only on the batches that get reported;
            # batch 0 is never reported.
            to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
            outs = exe.run(fluid.default_main_program(),
                           feed=feed,
                           fetch_list=[avg_cost, accuracy] if to_print else [],
                           return_numpy=False)

            if to_print:
                print("\nBatch %d, train cost: %f, train acc: %f" %
                      (batch_id, lodtensor_to_ndarray(outs[0])[0],
                       lodtensor_to_ndarray(outs[1])[0]))
                # save the latest checkpoint
                if args.checkpoints != '':
                    model_path = os.path.join(args.checkpoints,
                                              "deep_asr.latest.checkpoint")
                    fluid.io.save_persistables(exe, model_path)
            else:
                # Progress indicator for the batches that are not reported.
                sys.stdout.write('.')
                sys.stdout.flush()
        # run test
        val_cost, val_acc = test(exe)

        # save checkpoint per pass
        if args.checkpoints != '':
            model_path = os.path.join(
                args.checkpoints,
                "deep_asr.pass_" + str(pass_id) + ".checkpoint")
            fluid.io.save_persistables(exe, model_path)
        # save inference model
        if args.infer_models != '':
            model_path = os.path.join(
                args.infer_models,
                "deep_asr.pass_" + str(pass_id) + ".infer.model")
            fluid.io.save_inference_model(model_path, ["feature"],
                                          [prediction], exe)
        # cal pass time
        pass_end_time = time.time()
        time_consumed = pass_end_time - pass_start_time
        # print info at pass end
        print("\nPass %d, time consumed: %f s, val cost: %f, val acc: %f\n" %
              (pass_id, time_consumed, val_cost, val_acc))


if __name__ == '__main__':
    # Script entry point: parse the command line, echo the resulting
    # configuration, then run training with it.
    args = parse_args()
    print_arguments(args)

    train(args)