Commit 6d290cec authored by Yibing Liu

Add the demo script for inference

Parent 2738ca10
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import paddle.v2.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader
from data_utils.util import lodtensor_to_ndarray


def parse_args():
    parser = argparse.ArgumentParser("Inference for stacked LSTMP model.")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=32,
        help='The number of sequences in a batch. (default: %(default)d)')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help='The device type. (default: %(default)s)')
    parser.add_argument(
        '--mean_var',
        type=str,
        default='data/global_mean_var_search26kHr',
        help='The path of the global mean and variance file.')
    parser.add_argument(
        '--infer_feature_lst',
        type=str,
        default='data/infer_feature.lst',
        help='The feature list path for inference.')
    parser.add_argument(
        '--infer_label_lst',
        type=str,
        default='data/infer_label.lst',
        help='The label list path for inference.')
    parser.add_argument(
        '--model_save_path',
        type=str,
        default='./checkpoints/deep_asr.pass_0.model/',
        help='The directory from which to load the trained model.')
    args = parser.parse_args()
    return args


def print_arguments(args):
    print('----------- Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def split_infer_result(infer_seq, lod):
    """Split a flat batch of predictions into per-sample sequences
    according to the LoD (level-of-detail) offsets.
    """
    infer_batch = []
    for i in range(0, len(lod[0]) - 1):
        infer_batch.append(infer_seq[lod[0][i]:lod[0][i + 1]])
    return infer_batch


def infer(args):
    """Get one batch of feature data and predict labels for each sample.
    """
    if args.model_save_path is None or \
            not os.path.exists(args.model_save_path):
        raise IOError("Invalid model path!")

    place = fluid.CUDAPlace(0) if args.device == 'GPU' else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load the saved inference program, the names of its feed targets
    # and its fetch targets from the model directory.
    [infer_program, feed_dicts,
     fetch_targets] = fluid.io.load_inference_model(args.model_save_path, exe)

    # The same feature transformers used in training must be applied
    # to the inference data.
    ltrans = [
        trans_add_delta.TransAddDelta(2, 2),
        trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
        trans_splice.TransSplice()
    ]

    infer_data_reader = reader.DataReader(args.infer_feature_lst,
                                          args.infer_label_lst)
    infer_data_reader.set_transformers(ltrans)

    # Take one batch and wrap the features into a LoDTensor.
    feature_t = fluid.LoDTensor()
    one_batch = next(infer_data_reader.batch_iterator(args.batch_size, 1))
    (features, labels, lod) = one_batch
    feature_t.set(features, place)
    feature_t.set_lod([lod])

    results = exe.run(infer_program,
                      feed={feed_dicts[0]: feature_t},
                      fetch_list=fetch_targets,
                      return_numpy=False)

    probs, lod = lodtensor_to_ndarray(results[0])
    preds = probs.argmax(axis=1)

    # Split the flat predictions back into per-sample sequences.
    infer_batch = split_infer_result(preds, lod)
    for index, sample in enumerate(infer_batch):
        print("result %d: " % index, sample, '\n')


if __name__ == '__main__':
    args = parse_args()
    print_arguments(args)
    infer(args)
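
To see how the LoD offsets drive split_infer_result, here is a minimal, self-contained sketch; the predictions and the offsets [0, 3, 5, 9] are made-up values for illustration, not output of the model above:

import numpy as np

def split_infer_result(infer_seq, lod):
    # Slice the flat prediction array at consecutive LoD offsets;
    # each slice is the label sequence of one utterance in the batch.
    infer_batch = []
    for i in range(0, len(lod[0]) - 1):
        infer_batch.append(infer_seq[lod[0][i]:lod[0][i + 1]])
    return infer_batch

# Hypothetical flat predictions for a batch of 3 utterances with
# lengths 3, 2 and 4; the LoD offsets are the cumulative lengths.
preds = np.array([7, 7, 2, 5, 5, 1, 1, 3, 3])
lod = [[0, 3, 5, 9]]

for i, sample in enumerate(split_infer_result(preds, lod)):
    print("result %d: " % i, sample)
# result 0:  [7 7 2]
# result 1:  [5 5]
# result 2:  [1 1 3 3]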