from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import numpy as np
import argparse
import time

import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.augmentor.trans_delay as trans_delay
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray, split_infer_result
from model_utils.model import stacked_lstmp_model
from post_latgen_faster_mapped import Decoder
from tools.error_rate import char_errors


def parse_args():
    parser = argparse.ArgumentParser("Run inference by using checkpoint.")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='The number of sequences in a batch. (default: %(default)d)')
    parser.add_argument(
        '--beam_size',
        type=int,
        default=11,
        help='The beam size for decoding. (default: %(default)d)')
    parser.add_argument(
        '--minimum_batch_size',
        type=int,
        default=1,
        help='The minimum number of sequences in a batch. '
        '(default: %(default)d)')
    parser.add_argument(
        '--frame_dim',
        type=int,
        default=80,
        help='Frame dimension of feature data. (default: %(default)d)')
    parser.add_argument(
        '--stacked_num',
        type=int,
        default=5,
        help='Number of lstmp layers to stack. (default: %(default)d)')
    parser.add_argument(
        '--proj_dim',
        type=int,
        default=512,
        help='Projection size of the LSTMP unit. (default: %(default)d)')
    parser.add_argument(
        '--hidden_dim',
        type=int,
        default=1024,
        help='Hidden size of the LSTMP unit. (default: %(default)d)')
    parser.add_argument(
        '--class_num',
        type=int,
        default=1749,
        help='Number of classes in label. (default: %(default)d)')
    parser.add_argument(
        '--num_threads',
        type=int,
        default=10,
        help='The number of threads for decoding. (default: %(default)d)')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help='The device type. (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', help='If set, run in parallel.')
    parser.add_argument(
        '--mean_var',
        type=str,
        default='data/global_mean_var',
        help="The path for feature's global mean and variance. "
        "(default: %(default)s)")
    parser.add_argument(
        '--infer_feature_lst',
        type=str,
        default='data/infer_feature.lst',
        help='The feature list path for inference. (default: %(default)s)')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./checkpoint',
        help="The checkpoint path to init model. (default: %(default)s)")
    parser.add_argument(
        '--trans_model',
        type=str,
        default='./graph/trans_model',
        help="The path to vocabulary. (default: %(default)s)")
    parser.add_argument(
        '--vocabulary',
        type=str,
        default='./graph/words.txt',
        help="The path to vocabulary. (default: %(default)s)")
    parser.add_argument(
        '--graphs',
        type=str,
        default='./graph/TLG.fst',
        help="The path to TLG graphs for decoding. (default: %(default)s)")
    parser.add_argument(
        '--log_prior',
        type=str,
        default="./logprior",
        help="The log prior probs for training data. (default: %(default)s)")
    parser.add_argument(
        '--acoustic_scale',
        type=float,
        default=0.2,
        help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
    parser.add_argument(
        '--post_matrix_path',
        type=str,
        default=None,
        help="The path to output post prob matrix. (default: %(default)s)")
    parser.add_argument(
        '--decode_to_path',
        type=str,
        default='./decoding_result.txt',
        required=True,
        help="The path to output the decoding result. (default: %(default)s)")
    args = parser.parse_args()
    return args


def print_arguments(args):
    print('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


class PostMatrixWriter:
    """ The writer for outputing the post probability matrix
    """

    def __init__(self, to_path):
        self._to_path = to_path
        with open(self._to_path, "w") as post_matrix:
            post_matrix.seek(0)
            post_matrix.truncate()

    def write(self, keys, probs):
        with open(self._to_path, "a") as post_matrix:
            if isinstance(keys, str):
                keys, probs = [keys], [probs]

            for key, prob in zip(keys, probs):
                post_matrix.write(key + " [\n")
                for i in range(prob.shape[0]):
                    for j in range(prob.shape[1]):
                        post_matrix.write(str(prob[i][j]) + " ")
                    post_matrix.write("\n")
                post_matrix.write("]\n")


class DecodingResultWriter:
    """ The writer for writing out decoding results
    """

    def __init__(self, to_path):
        self._to_path = to_path
        with open(self._to_path, "w") as decoding_result:
            decoding_result.seek(0)
            decoding_result.truncate()

    def write(self, results):
        with open(self._to_path, "a") as decoding_result:
            if isinstance(results, str):
                decoding_result.write(results.encode("utf8") + "\n")
            else:
                for result in results:
                    decoding_result.write(result.encode("utf8") + "\n")


def infer_from_ckpt(args):
    """Inference by using checkpoint."""

    if not os.path.exists(args.checkpoint):
        raise IOError("Invalid checkpoint!")

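    # Input layers: spliced acoustic features and frame-level labels, both as LoD tensors.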
    feature = fluid.data(
        name='feature',
        shape=[None, 3, 11, args.frame_dim],
        dtype='float32',
        lod_level=1)
    label = fluid.data(
        name='label', shape=[None, 1], dtype='int64', lod_level=1)

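    # Build the stacked LSTMP acoustic model.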
    prediction, avg_cost, accuracy = stacked_lstmp_model(
        feature=feature,
        label=label,
        hidden_dim=args.hidden_dim,
        proj_dim=args.proj_dim,
        stacked_num=args.stacked_num,
        class_num=args.class_num,
        parallel=args.parallel)

    infer_program = fluid.default_main_program().clone()

    # Optimizer (defined after cloning, so training ops stay out of the inference program).
    optimizer = fluid.optimizer.Adam(
        learning_rate=fluid.layers.exponential_decay(
            learning_rate=0.0001,
            decay_steps=1879,
            decay_rate=1 / 1.2,
            staircase=True))
    optimizer.minimize(avg_cost)

    place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load checkpoint.
    fluid.io.load_persistables(exe, args.checkpoint)

    # init decoder
    decoder = Decoder(args.trans_model, args.vocabulary, args.graphs,
                      args.log_prior, args.beam_size, args.acoustic_scale)

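    # On-the-fly feature transformations: add deltas, mean/variance normalization,
    # frame splicing, and output delay.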
    ltrans = [
        trans_add_delta.TransAddDelta(2, 2),
        trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
        trans_splice.TransSplice(5, 5), trans_delay.TransDelay(5)
    ]

    feature_t = fluid.LoDTensor()
    label_t = fluid.LoDTensor()

    # infer data reader
    infer_data_reader = reader.AsyncDataReader(
        args.infer_feature_lst, drop_frame_len=-1, split_sentence_threshold=-1)
    infer_data_reader.set_transformers(ltrans)

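    # Writers for the decoding results and, optionally, the posterior probability matrices.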
    decoding_result_writer = DecodingResultWriter(args.decode_to_path)
    post_matrix_writer = None if args.post_matrix_path is None \
                         else PostMatrixWriter(args.post_matrix_path)

    for batch_id, batch_data in enumerate(
            infer_data_reader.batch_iterator(args.batch_size,
                                             args.minimum_batch_size)):
        # Load one batch and rearrange features to the [N, 3, 11, frame_dim] input layout.
        (features, labels, lod, name_lst) = batch_data
        features = np.reshape(features, (-1, 11, 3, args.frame_dim))
        features = np.transpose(features, (0, 2, 1, 3))
        feature_t.set(features, place)
        feature_t.set_lod([lod])
        label_t.set(labels, place)
        label_t.set_lod([lod])

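        # Run inference and split the LoD output into per-utterance posterior matrices.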
        results = exe.run(infer_program,
                          feed={"feature": feature_t,
                                "label": label_t},
                          fetch_list=[prediction, avg_cost, accuracy],
                          return_numpy=False)

        probs, lod = lodtensor_to_ndarray(results[0])
        infer_batch = split_infer_result(probs, lod)

        print("Decoding batch %d ..." % batch_id)
        decoded = decoder.decode_batch(name_lst, infer_batch, args.num_threads)

        decoding_result_writer.write(decoded)

        if args.post_matrix_path is not None:
            post_matrix_writer.write(name_lst, infer_batch)


if __name__ == '__main__':
    args = parse_args()
    print_arguments(args)

    infer_from_ckpt(args)