提交 7c6fa642 编写于 作者: Y Yibing Liu

Rename infer_probs_batch to infer_batch_probs

上级 dd2588c9
...@@ -166,7 +166,7 @@ def start_server(): ...@@ -166,7 +166,7 @@ def start_server():
# prepare ASR inference handler # prepare ASR inference handler
def file_to_transcript(filename): def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "") feature = data_generator.process_utterance(filename, "")
probs_split = ds2_model.infer_probs_batch( probs_split = ds2_model.infer_batch_probs(
infer_data=[feature], infer_data=[feature],
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
......
...@@ -92,7 +92,7 @@ def infer(): ...@@ -92,7 +92,7 @@ def infer():
if args.decoding_method == "ctc_greedy": if args.decoding_method == "ctc_greedy":
ds2_model.logger.info("start inference ...") ds2_model.logger.info("start inference ...")
probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
result_transcripts = ds2_model.decode_batch_greedy( result_transcripts = ds2_model.decode_batch_greedy(
probs_split=probs_split, probs_split=probs_split,
...@@ -101,7 +101,7 @@ def infer(): ...@@ -101,7 +101,7 @@ def infer():
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
vocab_list) vocab_list)
ds2_model.logger.info("start inference ...") ds2_model.logger.info("start inference ...")
probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
result_transcripts = ds2_model.decode_batch_beam_search( result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split, probs_split=probs_split,
......
...@@ -173,7 +173,7 @@ class DeepSpeech2Model(object): ...@@ -173,7 +173,7 @@ class DeepSpeech2Model(object):
# run inference # run inference
return self._loss_inferer.infer(input=infer_data) return self._loss_inferer.infer(input=infer_data)
def infer_probs_batch(self, infer_data, feeding_dict): def infer_batch_probs(self, infer_data, feeding_dict):
"""Infer the prob matrices for a batch of speech utterances. """Infer the prob matrices for a batch of speech utterances.
:param infer_data: List of utterances to infer, with each utterance :param infer_data: List of utterances to infer, with each utterance
......
...@@ -97,7 +97,7 @@ def evaluate(): ...@@ -97,7 +97,7 @@ def evaluate():
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
ds2_model.logger.info("start evaluation ...") ds2_model.logger.info("start evaluation ...")
for infer_data in batch_reader(): for infer_data in batch_reader():
probs_split = ds2_model.infer_probs_batch( probs_split = ds2_model.infer_batch_probs(
infer_data=infer_data, infer_data=infer_data,
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
......
...@@ -120,7 +120,7 @@ def tune(): ...@@ -120,7 +120,7 @@ def tune():
for infer_data in batch_reader(): for infer_data in batch_reader():
if (args.num_batches >= 0) and (cur_batch >= args.num_batches): if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
break break
probs_split = ds2_model.infer_probs_batch( probs_split = ds2_model.infer_batch_probs(
infer_data=infer_data, infer_data=infer_data,
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
target_transcripts = [ data[1] for data in infer_data ] target_transcripts = [ data[1] for data in infer_data ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册