infer_by_ckpt.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import numpy as np
import argparse
import time

import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.async_data_reader as reader
from decoder.post_decode_faster import Decoder
from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model
from data_utils.util import split_infer_result


def parse_args():
    parser = argparse.ArgumentParser("Run inference by using checkpoint.")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='The number of sequences in a batch of data. (default: %(default)d)')
    parser.add_argument(
        '--minimum_batch_size',
        type=int,
        default=1,
        help='The minimum number of sequences in a batch of data. '
        '(default: %(default)d)')
    parser.add_argument(
        '--frame_dim',
        type=int,
        default=120 * 11,
        help='Frame dimension of feature data. (default: %(default)d)')
    parser.add_argument(
        '--stacked_num',
        type=int,
        default=5,
        help='Number of lstmp layers to stack. (default: %(default)d)')
    parser.add_argument(
        '--proj_dim',
        type=int,
        default=512,
        help='Projection size of the lstmp unit. (default: %(default)d)')
    parser.add_argument(
        '--hidden_dim',
        type=int,
        default=1024,
        help='Hidden size of lstmp unit. (default: %(default)d)')
    parser.add_argument(
        '--class_num',
        type=int,
        default=1749,
        help='Number of classes in label. (default: %(default)d)')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.00016,
        help='Learning rate used to train. (default: %(default)f)')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help='The device type. (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', help='If set, run in parallel.')
    parser.add_argument(
        '--mean_var',
        type=str,
        default='data/global_mean_var_search26kHr',
        help="The path for feature's global mean and variance. "
        "(default: %(default)s)")
    parser.add_argument(
        '--infer_feature_lst',
        type=str,
        default='data/infer_feature.lst',
        help='The feature list path for inference. (default: %(default)s)')
    parser.add_argument(
        '--infer_label_lst',
        type=str,
        default='data/infer_label.lst',
        help='The label list path for inference. (default: %(default)s)')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./checkpoint',
        help="The checkpoint path to init model. (default: %(default)s)")
    parser.add_argument(
        '--vocabulary',
        type=str,
        default='./decoder/graph/words.txt',
        help="The path to vocabulary. (default: %(default)s)")
    parser.add_argument(
        '--graphs',
        type=str,
        default='./decoder/graph/TLG.fst',
        help="The path to TLG graphs for decoding. (default: %(default)s)")
    parser.add_argument(
        '--log_prior',
        type=str,
        default="./decoder/logprior",
        help="The log prior probs for training data. (default: %(default)s)")
    parser.add_argument(
        '--acoustic_scale',
        type=float,
        default=0.2,
        help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
    args = parser.parse_args()
    return args


def print_arguments(args):
    print('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def infer_from_ckpt(args):
    """Inference by using checkpoint."""

    if not os.path.exists(args.checkpoint):
        raise IOError("Invalid checkpoint path: %s" % args.checkpoint)

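    # Build the stacked LSTMP acoustic model; it returns the frame-level
    # prediction, the average cost and the accuracy variables of the network.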
    prediction, avg_cost, accuracy = stacked_lstmp_model(
        frame_dim=args.frame_dim,
        hidden_dim=args.hidden_dim,
        proj_dim=args.proj_dim,
        stacked_num=args.stacked_num,
        class_num=args.class_num,
        parallel=args.parallel)

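    # Clone the main program for inference before the optimizer appends
    # backward and parameter-update ops, so infer_program keeps only the
    # forward pass.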
    infer_program = fluid.default_main_program().clone()

    optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
    optimizer.minimize(avg_cost)

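    # Create an executor on the chosen device and initialize all parameters,
    # which are then overwritten by the checkpoint loaded below.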
    place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load checkpoint.
    fluid.io.load_persistables(exe, args.checkpoint)

    # init the decoder with the word vocabulary, TLG decoding graph and log priors
    decoder = Decoder(args.vocabulary, args.graphs, args.log_prior)

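    # Feature transformers applied to each utterance in order: add delta
    # features, apply global mean-variance normalization, then splice
    # neighbouring frames to form the network input.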
    ltrans = [
        trans_add_delta.TransAddDelta(2, 2),
        trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
        trans_splice.TransSplice()
    ]

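    # LoDTensors hold the concatenated frames of a batch together with the
    # per-sequence length (lod) information.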
    feature_t = fluid.LoDTensor()
    label_t = fluid.LoDTensor()

    # infer data reader
    infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
                                               args.infer_label_lst)
    infer_data_reader.set_transformers(ltrans)
    infer_costs, infer_accs = [], []
    for batch_id, batch_data in enumerate(
            infer_data_reader.batch_iterator(args.batch_size,
                                             args.minimum_batch_size)):
        # load batch data (features, labels and their sequence lod)
        (features, labels, lod) = batch_data
        feature_t.set(features, place)
        feature_t.set_lod([lod])
        label_t.set(labels, place)
        label_t.set_lod([lod])

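        # Run the forward pass on this batch; return_numpy=False keeps the
        # fetched results as LoDTensors so the sequence lod is preserved.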
        results = exe.run(infer_program,
                          feed={"feature": feature_t,
                                "label": label_t},
                          fetch_list=[prediction, avg_cost, accuracy],
                          return_numpy=False)
        infer_costs.append(lodtensor_to_ndarray(results[1])[0])
        infer_accs.append(lodtensor_to_ndarray(results[2])[0])

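        # Split the batched posteriors back into per-utterance samples and
        # decode each one with the external decoder.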
        probs, lod = lodtensor_to_ndarray(results[0])
        infer_batch = split_infer_result(probs, lod)
        for index, sample in enumerate(infer_batch):
            key = "utter#%d" % (batch_id * args.batch_size + index)
            print(key, ": ", decoder.decode(key, sample).encode("utf8"), "\n")

    print("Average inference cost: %f, accuracy: %f" %
          (np.mean(infer_costs), np.mean(infer_accs)))


if __name__ == '__main__':
    args = parse_args()
    print_arguments(args)

    infer_from_ckpt(args)