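"""Training script for the stacked LSTMP acoustic model on PaddlePaddle Fluid.

It builds the network, feeds batches of transformed acoustic features and
their labels, evaluates on a validation set after every pass, and optionally
saves an inference model per pass.
"""
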
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import numpy as np
import argparse
import time

import paddle.v2.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader
from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model


def parse_args():
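    """Parse and return the command-line arguments."""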
    parser = argparse.ArgumentParser("Training for stacked LSTMP model.")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='The number of sequences in one batch. (default: %(default)d)')
    parser.add_argument(
        '--minimum_batch_size',
        type=int,
        default=1,
        help='The minimum number of sequences in one batch. '
        '(default: %(default)d)')
    parser.add_argument(
        '--stacked_num',
        type=int,
        default=5,
        help='Number of LSTMP layers to stack. (default: %(default)d)')
    parser.add_argument(
        '--proj_dim',
        type=int,
        default=512,
        help='Projection size of the LSTMP unit. (default: %(default)d)')
    parser.add_argument(
        '--hidden_dim',
        type=int,
        default=1024,
        help='Hidden size of the LSTMP unit. (default: %(default)d)')
    parser.add_argument(
        '--pass_num',
        type=int,
        default=100,
        help='Number of epochs (passes) to train. (default: %(default)d)')
    parser.add_argument(
        '--print_per_batches',
        type=int,
        default=100,
        help='Interval (in batches) to print training cost and accuracy. '
        '(default: %(default)d)')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.002,
        help='Learning rate used to train. (default: %(default)f)')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help='The device type. (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', help='If set, run in parallel.')
    parser.add_argument(
        '--mean_var',
        type=str,
        default='data/global_mean_var_search26kHr',
        help="The path to the global mean and variance file of the features. "
        "(default: %(default)s)")
    parser.add_argument(
        '--train_feature_lst',
        type=str,
        default='data/feature.lst',
        help='The feature list path for training. (default: %(default)s)')
    parser.add_argument(
        '--train_label_lst',
        type=str,
        default='data/label.lst',
        help='The label list path for training. (default: %(default)s)')
    parser.add_argument(
        '--val_feature_lst',
        type=str,
        default='data/val_feature.lst',
        help='The feature list path for validation. (default: %(default)s)')
    parser.add_argument(
        '--val_label_lst',
        type=str,
        default='data/val_label.lst',
        help='The label list path for validation. (default: %(default)s)')
    parser.add_argument(
        '--model_save_dir',
        type=str,
        default='./checkpoints',
        help="The directory for saving the trained model. The model is not "
        "saved if set to ''. (default: %(default)s)")
    args = parser.parse_args()
    return args


def print_arguments(args):
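    """Print the parsed arguments in a readable block."""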
    print('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def train(args):
    """train in loop.
    """

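    # Build the stacked LSTMP network; it returns the prediction op, the
    # average cost to minimize, and a batch accuracy metric.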
    prediction, avg_cost, accuracy = stacked_lstmp_model(
        hidden_dim=args.hidden_dim,
        proj_dim=args.proj_dim,
        stacked_num=args.stacked_num,
        class_num=1749,
        parallel=args.parallel)

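    # Minimize the average cost with the Adam optimizer.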
    adam_optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
    adam_optimizer.minimize(avg_cost)

    # Program for validation: a clone of the main program pruned to the
    # evaluation targets (cost and accuracy).
    test_program = fluid.default_main_program().clone()
    with fluid.program_guard(test_program):
        test_program = fluid.io.get_inference_program([avg_cost, accuracy])

    place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

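    # Feature transformers applied to each utterance in order: append delta
    # and delta-delta features, apply global mean-variance normalization,
    # then splice neighboring frames for context.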
    ltrans = [
        trans_add_delta.TransAddDelta(2, 2),
        trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
        trans_splice.TransSplice()
    ]

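    # LoDTensor holders reused for feeding variable-length batches.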
    feature_t = fluid.LoDTensor()
    label_t = fluid.LoDTensor()

    # validation
    def test(exe):
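        """Evaluate on the validation set; return mean cost and accuracy."""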
        # If test data not found, return invalid cost and accuracy
        if not (os.path.exists(args.val_feature_lst) and
                os.path.exists(args.val_label_lst)):
            return -1.0, -1.0
        # test data reader
        test_data_reader = reader.DataReader(args.val_feature_lst,
                                             args.val_label_lst)
        test_data_reader.set_transformers(ltrans)
        test_costs, test_accs = [], []
        for batch_id, batch_data in enumerate(
                test_data_reader.batch_iterator(args.batch_size,
                                                args.minimum_batch_size)):
            # load_data
            (features, labels, lod) = batch_data
            feature_t.set(features, place)
            feature_t.set_lod([lod])
            label_t.set(labels, place)
            label_t.set_lod([lod])

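            # Evaluate the pruned test program on this validation batch.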
            cost, acc = exe.run(test_program,
                                feed={"feature": feature_t,
                                      "label": label_t},
                                fetch_list=[avg_cost, accuracy],
                                return_numpy=False)
            test_costs.append(lodtensor_to_ndarray(cost)[0])
            test_accs.append(lodtensor_to_ndarray(acc)[0])
        return np.mean(test_costs), np.mean(test_accs)

    # train data reader
    train_data_reader = reader.DataReader(args.train_feature_lst,
                                          args.train_label_lst, -1)
    train_data_reader.set_transformers(ltrans)
    # train
    for pass_id in range(args.pass_num):
        pass_start_time = time.time()
        for batch_id, batch_data in enumerate(
                train_data_reader.batch_iterator(args.batch_size,
                                                 args.minimum_batch_size)):
            # load_data
            (features, labels, lod) = batch_data
            feature_t.set(features, place)
            feature_t.set_lod([lod])
            label_t.set(labels, place)
            label_t.set_lod([lod])

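            # One training iteration: forward, backward and parameter update.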
            cost, acc = exe.run(fluid.default_main_program(),
                                feed={"feature": feature_t,
                                      "label": label_t},
                                fetch_list=[avg_cost, accuracy],
                                return_numpy=False)

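            # Report cost and accuracy every print_per_batches batches;
            # print a progress dot otherwise.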
            if batch_id > 0 and (batch_id % args.print_per_batches == 0):
                print("\nBatch %d, train cost: %f, train acc: %f" %
                      (batch_id, lodtensor_to_ndarray(cost)[0],
                       lodtensor_to_ndarray(acc)[0]))
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        # run test
        val_cost, val_acc = test(exe)
        # save model
        if args.model_save_dir != '':
            model_path = os.path.join(
                args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model")
            fluid.io.save_inference_model(model_path, ["feature"],
                                          [prediction], exe)
        # calculate pass time
        pass_end_time = time.time()
        time_consumed = pass_end_time - pass_start_time
        # print info at pass end
        print("\nPass %d, time consumed: %f s, val cost: %f, val acc: %f\n" %
              (pass_id, time_consumed, val_cost, val_acc))


if __name__ == '__main__':
    args = parse_args()
    print_arguments(args)

    if args.model_save_dir != '' and not os.path.exists(args.model_save_dir):
        os.mkdir(args.model_save_dir)

    train(args)