# train.py: training script for the stacked LSTMP model.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import numpy as np
import argparse
import time

import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader
from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model


def parse_args(argv=None):
    """Define and parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option below.
    """
    # The text is passed as `description`; the first positional parameter of
    # ArgumentParser is `prog`, which would put it in the usage line instead.
    parser = argparse.ArgumentParser(
        description="Training for stacked LSTMP model.")

    # Batching.
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='The sequence number of a batch data. (default: %(default)d)')
    parser.add_argument(
        '--minimum_batch_size',
        type=int,
        default=1,
        help='The minimum sequence number of a batch data. '
        '(default: %(default)d)')

    # Model structure.
    parser.add_argument(
        '--stacked_num',
        type=int,
        default=5,
        help='Number of lstmp layers to stack. (default: %(default)d)')
    parser.add_argument(
        '--proj_dim',
        type=int,
        default=512,
        help='Project size of lstmp unit. (default: %(default)d)')
    parser.add_argument(
        '--hidden_dim',
        type=int,
        default=1024,
        help='Hidden size of lstmp unit. (default: %(default)d)')

    # Training schedule and device.
    parser.add_argument(
        '--pass_num',
        type=int,
        default=100,
        help='Epoch number to train. (default: %(default)d)')
    parser.add_argument(
        '--print_per_batches',
        type=int,
        default=100,
        help='Interval to print training accuracy. (default: %(default)d)')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.002,
        help='Learning rate used to train. (default: %(default)f)')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help='The device type. (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', help='If set, run in parallel.')

    # Input data.
    parser.add_argument(
        '--mean_var',
        type=str,
        default='data/global_mean_var_search26kHr',
        help="The path for feature's global mean and variance. "
        "(default: %(default)s)")
    parser.add_argument(
        '--train_feature_lst',
        type=str,
        default='data/feature.lst',
        help='The feature list path for training. (default: %(default)s)')
    parser.add_argument(
        '--train_label_lst',
        type=str,
        default='data/label.lst',
        help='The label list path for training. (default: %(default)s)')
    parser.add_argument(
        '--val_feature_lst',
        type=str,
        default='data/val_feature.lst',
        help='The feature list path for validation. (default: %(default)s)')
    parser.add_argument(
        '--val_label_lst',
        type=str,
        default='data/val_label.lst',
        help='The label list path for validation. (default: %(default)s)')

    # Model loading and saving.
    parser.add_argument(
        '--init_model_path',
        type=str,
        default=None,
        help="The model (checkpoint) path which the training resumes from. "
        "If None, train the model from scratch. (default: %(default)s)")
    parser.add_argument(
        '--checkpoints',
        type=str,
        default='./checkpoints',
        help="The directory for saving checkpoints. Do not save checkpoints "
        "if set to ''. (default: %(default)s)")
    parser.add_argument(
        '--infer_models',
        type=str,
        default='./infer_models',
        help="The directory for saving inference models. Do not save inference "
        "models if set to ''. (default: %(default)s)")

    return parser.parse_args(argv)


def print_arguments(args):
    """Print every parsed option as a ``name: value`` line, sorted by name.

    Args:
        args: Object whose attributes (as exposed by ``vars()``) are printed,
            e.g. the namespace returned by argparse.
    """
    print('-----------  Configuration Arguments -----------')
    # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def train(args):
    """Train the stacked LSTMP model, validating and saving after each pass.

    Builds the network and a Momentum optimizer, optionally restores
    persistables from ``args.init_model_path``, then runs ``args.pass_num``
    passes over the training data. After each pass the model is evaluated on
    the validation lists and, unless disabled with an empty path, a checkpoint
    and an inference model are written.

    Args:
        args: Parsed command-line options (see parse_args()).

    Raises:
        IOError: If ``args.init_model_path`` is set but does not exist.
    """

    # paths check
    if args.init_model_path is not None and \
            not os.path.exists(args.init_model_path):
        raise IOError("Invalid initial model path!")
    # makedirs instead of mkdir so that nested output directories also work.
    if args.checkpoints != '' and not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)
    if args.infer_models != '' and not os.path.exists(args.infer_models):
        os.makedirs(args.infer_models)

    # NOTE(review): class_num=1749 is hard-coded; presumably it has to match
    # the label set of the training data -- confirm before changing data.
    prediction, avg_cost, accuracy = stacked_lstmp_model(
        hidden_dim=args.hidden_dim,
        proj_dim=args.proj_dim,
        stacked_num=args.stacked_num,
        class_num=1749,
        parallel=args.parallel)

    optimizer = fluid.optimizer.Momentum(
        learning_rate=args.learning_rate, momentum=0.9)
    optimizer.minimize(avg_cost)

    # program for test
    test_program = fluid.default_main_program().clone()
    with fluid.program_guard(test_program):
        test_program = fluid.io.get_inference_program([avg_cost, accuracy])

    place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # resume training if initial model provided.
    if args.init_model_path is not None:
        fluid.io.load_persistables(exe, args.init_model_path)

    # Feature transformers shared by the training and validation readers.
    ltrans = [
        trans_add_delta.TransAddDelta(2, 2),
        trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
        trans_splice.TransSplice()
    ]

    # These two tensors are reused for every batch (training and validation).
    feature_t = fluid.LoDTensor()
    label_t = fluid.LoDTensor()

    def make_feed(batch_data):
        """Load one batch into the shared tensors and return the feed dict."""
        (features, labels, lod) = batch_data
        feature_t.set(features, place)
        feature_t.set_lod([lod])
        label_t.set(labels, place)
        label_t.set_lod([lod])
        return {"feature": feature_t, "label": label_t}

    # validation
    def test(exe):
        """Return (mean cost, mean accuracy) over the validation batches."""
        # If test data not found, return invalid cost and accuracy
        if not (os.path.exists(args.val_feature_lst) and
                os.path.exists(args.val_label_lst)):
            return -1.0, -1.0
        # test data reader
        test_data_reader = reader.DataReader(args.val_feature_lst,
                                             args.val_label_lst)
        test_data_reader.set_transformers(ltrans)
        test_costs, test_accs = [], []
        for batch_data in test_data_reader.batch_iterator(
                args.batch_size, args.minimum_batch_size):
            cost, acc = exe.run(test_program,
                                feed=make_feed(batch_data),
                                fetch_list=[avg_cost, accuracy],
                                return_numpy=False)
            test_costs.append(lodtensor_to_ndarray(cost)[0])
            test_accs.append(lodtensor_to_ndarray(acc)[0])
        return np.mean(test_costs), np.mean(test_accs)

    # train data reader
    # NOTE(review): the meaning of the third DataReader argument (-1) is not
    # visible in this file -- check data_utils.data_reader.
    train_data_reader = reader.DataReader(args.train_feature_lst,
                                          args.train_label_lst, -1)
    train_data_reader.set_transformers(ltrans)

    # train
    # range() instead of xrange() so the loop also runs on Python 3.
    for pass_id in range(args.pass_num):
        pass_start_time = time.time()
        for batch_id, batch_data in enumerate(
                train_data_reader.batch_iterator(args.batch_size,
                                                 args.minimum_batch_size)):
            feed = make_feed(batch_data)

            # Cost and accuracy are fetched only on the batches that print.
            to_print = batch_id > 0 and (batch_id % args.print_per_batches == 0)
            outs = exe.run(fluid.default_main_program(),
                           feed=feed,
                           fetch_list=[avg_cost, accuracy] if to_print else [],
                           return_numpy=False)

            if to_print:
                print("\nBatch %d, train cost: %f, train acc: %f" %
                      (batch_id, lodtensor_to_ndarray(outs[0])[0],
                       lodtensor_to_ndarray(outs[1])[0]))
                # save the latest checkpoint
                if args.checkpoints != '':
                    model_path = os.path.join(args.checkpoints,
                                              "deep_asr.latest.checkpoint")
                    fluid.io.save_persistables(exe, model_path)
            else:
                # One dot per silent batch as a progress indicator.
                sys.stdout.write('.')
                sys.stdout.flush()

        # run test
        val_cost, val_acc = test(exe)

        # save checkpoint per pass
        if args.checkpoints != '':
            model_path = os.path.join(
                args.checkpoints,
                "deep_asr.pass_" + str(pass_id) + ".checkpoint")
            fluid.io.save_persistables(exe, model_path)
        # save inference model
        if args.infer_models != '':
            model_path = os.path.join(
                args.infer_models,
                "deep_asr.pass_" + str(pass_id) + ".infer.model")
            fluid.io.save_inference_model(model_path, ["feature"],
                                          [prediction], exe)

        # cal pass time (includes validation and model saving above)
        pass_end_time = time.time()
        time_consumed = pass_end_time - pass_start_time
        # print info at pass end
        print("\nPass %d, time consumed: %f s, val cost: %f, val acc: %f\n" %
              (pass_id, time_consumed, val_cost, val_acc))


# Script entry point: parse the options, echo them, then run training.
if __name__ == '__main__':
    args = parse_args()
    print_arguments(args)

    train(args)