提交 7c6fa642 编写于 作者: Y Yibing Liu

Rename infer_probs_batch to infer_batch_probs

上级 dd2588c9
......@@ -166,7 +166,7 @@ def start_server():
# prepare ASR inference handler
def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "")
probs_split = ds2_model.infer_probs_batch(
probs_split = ds2_model.infer_batch_probs(
infer_data=[feature],
feeding_dict=data_generator.feeding)
......
......@@ -92,7 +92,7 @@ def infer():
if args.decoding_method == "ctc_greedy":
ds2_model.logger.info("start inference ...")
probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
feeding_dict=data_generator.feeding)
result_transcripts = ds2_model.decode_batch_greedy(
probs_split=probs_split,
......@@ -101,7 +101,7 @@ def infer():
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
vocab_list)
ds2_model.logger.info("start inference ...")
probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
feeding_dict=data_generator.feeding)
result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split,
......
......@@ -173,7 +173,7 @@ class DeepSpeech2Model(object):
# run inference
return self._loss_inferer.infer(input=infer_data)
def infer_probs_batch(self, infer_data, feeding_dict):
def infer_batch_probs(self, infer_data, feeding_dict):
"""Infer the prob matrices for a batch of speech utterances.
:param infer_data: List of utterances to infer, with each utterance
......
......@@ -97,7 +97,7 @@ def evaluate():
errors_sum, len_refs, num_ins = 0.0, 0, 0
ds2_model.logger.info("start evaluation ...")
for infer_data in batch_reader():
probs_split = ds2_model.infer_probs_batch(
probs_split = ds2_model.infer_batch_probs(
infer_data=infer_data,
feeding_dict=data_generator.feeding)
......
......@@ -120,7 +120,7 @@ def tune():
for infer_data in batch_reader():
if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
break
probs_split = ds2_model.infer_probs_batch(
probs_split = ds2_model.infer_batch_probs(
infer_data=infer_data,
feeding_dict=data_generator.feeding)
target_transcripts = [ data[1] for data in infer_data ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册