提交 6c2cf40c 编写于 作者: Y Yibing Liu

Rename prefix 'infer_batch' to 'decode_batch'

上级 66a39088
...@@ -171,11 +171,11 @@ def start_server(): ...@@ -171,11 +171,11 @@ def start_server():
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy": if args.decoding_method == "ctc_greedy":
result_transcript = ds2_model.infer_batch_greedy( result_transcript = ds2_model.decode_batch_greedy(
probs_split=probs_split, probs_split=probs_split,
vocab_list=vocab_list) vocab_list=vocab_list)
else: else:
result_transcript = ds2_model.infer_batch_beam_search( result_transcript = ds2_model.decode_batch_beam_search(
probs_split=probs_split, probs_split=probs_split,
beam_alpha=args.alpha, beam_alpha=args.alpha,
beam_beta=args.beta, beam_beta=args.beta,
......
...@@ -98,11 +98,11 @@ def infer(): ...@@ -98,11 +98,11 @@ def infer():
probs_split = ds2_model.infer_probs_batch(infer_data=infer_data, probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy": if args.decoding_method == "ctc_greedy":
result_transcripts = ds2_model.infer_batch_greedy( result_transcripts = ds2_model.decode_batch_greedy(
probs_split=probs_split, probs_split=probs_split,
vocab_list=vocab_list) vocab_list=vocab_list)
else: else:
result_transcripts = ds2_model.infer_batch_beam_search( result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split, probs_split=probs_split,
beam_alpha=args.alpha, beam_alpha=args.alpha,
beam_beta=args.beta, beam_beta=args.beta,
......
...@@ -205,8 +205,9 @@ class DeepSpeech2Model(object): ...@@ -205,8 +205,9 @@ class DeepSpeech2Model(object):
] ]
return probs_split return probs_split
def infer_batch_greedy(self, probs_split, vocab_list): def decode_batch_greedy(self, probs_split, vocab_list):
""" """Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists :param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce. of prob vectors for one speech utterancce.
:param probs_split: List of matrix :param probs_split: List of matrix
...@@ -256,11 +257,10 @@ class DeepSpeech2Model(object): ...@@ -256,11 +257,10 @@ class DeepSpeech2Model(object):
self.logger.info("no language model provided, " self.logger.info("no language model provided, "
"decoding by pure beam search without scorer.") "decoding by pure beam search without scorer.")
def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta, def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n, beam_size, cutoff_prob, cutoff_top_n,
vocab_list, num_processes): vocab_list, num_processes):
"""Model inference. Infer the transcription for a batch of speech """Decode by beam search for a batch of probs matrix input.
utterances.
:param probs_split: List of 2-D probability matrix, and each consists :param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce. of prob vectors for one speech utterancce.
......
...@@ -102,11 +102,11 @@ def evaluate(): ...@@ -102,11 +102,11 @@ def evaluate():
feeding_dict=data_generator.feeding) feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy": if args.decoding_method == "ctc_greedy":
result_transcripts = ds2_model.infer_batch_greedy( result_transcripts = ds2_model.decode_batch_greedy(
probs_split=probs_split, probs_split=probs_split,
vocab_list=vocab_list) vocab_list=vocab_list)
else: else:
result_transcripts = ds2_model.infer_batch_beam_search( result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split, probs_split=probs_split,
beam_alpha=args.alpha, beam_alpha=args.alpha,
beam_beta=args.beta, beam_beta=args.beta,
......
...@@ -128,7 +128,7 @@ def tune(): ...@@ -128,7 +128,7 @@ def tune():
num_ins += len(target_transcripts) num_ins += len(target_transcripts)
# grid search # grid search
for index, (alpha, beta) in enumerate(params_grid): for index, (alpha, beta) in enumerate(params_grid):
result_transcripts = ds2_model.infer_batch_beam_search( result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split, probs_split=probs_split,
beam_alpha=alpha, beam_alpha=alpha,
beam_beta=beta, beam_beta=beta,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册