diff --git a/deploy/demo_server.py b/deploy/demo_server.py
index 53be16f77a7272564820af403d0be0cfee4e92dc..eca13dcea8d933bba8e29d20cce071cac692432d 100644
--- a/deploy/demo_server.py
+++ b/deploy/demo_server.py
@@ -171,11 +171,11 @@ def start_server():
             feeding_dict=data_generator.feeding)
 
         if args.decoding_method == "ctc_greedy":
-            result_transcript = ds2_model.infer_batch_greedy(
+            result_transcript = ds2_model.decode_batch_greedy(
                 probs_split=probs_split,
                 vocab_list=vocab_list)
         else:
-            result_transcript = ds2_model.infer_batch_beam_search(
+            result_transcript = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
                 beam_alpha=args.alpha,
                 beam_beta=args.beta,
diff --git a/infer.py b/infer.py
index 5dd9b406d1fc74d1a64fedb13248ed7bc8b9be50..ff45a5dc864cb87f3b4b88ce71467d8566c275a7 100644
--- a/infer.py
+++ b/infer.py
@@ -98,11 +98,11 @@ def infer():
     probs_split = ds2_model.infer_probs_batch(infer_data=infer_data,
                                               feeding_dict=data_generator.feeding)
     if args.decoding_method == "ctc_greedy":
-        result_transcripts = ds2_model.infer_batch_greedy(
+        result_transcripts = ds2_model.decode_batch_greedy(
             probs_split=probs_split,
             vocab_list=vocab_list)
     else:
-        result_transcripts = ds2_model.infer_batch_beam_search(
+        result_transcripts = ds2_model.decode_batch_beam_search(
             probs_split=probs_split,
             beam_alpha=args.alpha,
             beam_beta=args.beta,
diff --git a/model_utils/model.py b/model_utils/model.py
index 70ba7bb93c49a9bd363d23933cc477bf69e27257..a8283fae45119f1433a114f7fe084d513a855950 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -205,8 +205,9 @@ class DeepSpeech2Model(object):
         ]
         return probs_split
 
-    def infer_batch_greedy(self, probs_split, vocab_list):
-        """
+    def decode_batch_greedy(self, probs_split, vocab_list):
+        """Decode by best path for a batch of probs matrix input.
+
         :param probs_split: List of 2-D probability matrix, and each consists
                             of prob vectors for one speech utterancce.
         :param probs_split: List of matrix
@@ -256,11 +257,10 @@ class DeepSpeech2Model(object):
             self.logger.info("no language model provided, "
                              "decoding by pure beam search without scorer.")
 
-    def infer_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
-                                beam_size, cutoff_prob, cutoff_top_n,
-                                vocab_list, num_processes):
-        """Model inference. Infer the transcription for a batch of speech
-        utterances.
+    def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
+                                 beam_size, cutoff_prob, cutoff_top_n,
+                                 vocab_list, num_processes):
+        """Decode by beam search for a batch of probs matrix input.
 
         :param probs_split: List of 2-D probability matrix, and each consists
                             of prob vectors for one speech utterancce.
diff --git a/test.py b/test.py
index 24ce54a2be8c17c2a72d6bd1393c67304613be8f..a82893c03bb16610329f0d2838f12bb0c7bc4113 100644
--- a/test.py
+++ b/test.py
@@ -102,11 +102,11 @@ def evaluate():
             feeding_dict=data_generator.feeding)
 
         if args.decoding_method == "ctc_greedy":
-            result_transcripts = ds2_model.infer_batch_greedy(
+            result_transcripts = ds2_model.decode_batch_greedy(
                 probs_split=probs_split,
                 vocab_list=vocab_list)
         else:
-            result_transcripts = ds2_model.infer_batch_beam_search(
+            result_transcripts = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
                 beam_alpha=args.alpha,
                 beam_beta=args.beta,
diff --git a/tools/tune.py b/tools/tune.py
index 923e6c3c32a2909bb0fb105325965d1a3637ae87..d8e28c58a5e9045d420061eb276f814b8e4f39af 100644
--- a/tools/tune.py
+++ b/tools/tune.py
@@ -128,7 +128,7 @@ def tune():
         num_ins += len(target_transcripts)
         # grid search
         for index, (alpha, beta) in enumerate(params_grid):
-            result_transcripts = ds2_model.infer_batch_beam_search(
+            result_transcripts = ds2_model.decode_batch_beam_search(
                 probs_split=probs_split,
                 beam_alpha=alpha,
                 beam_beta=beta,
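
Reviewer note: the rename makes the two-phase structure of batch recognition explicit. infer_probs_batch runs the network and returns one probability matrix per utterance, while the renamed decode_batch_greedy / decode_batch_beam_search methods only turn those probabilities into transcripts, so the same probs_split can be decoded repeatedly (as tools/tune.py does during its grid search) without re-running the network. Below is a minimal sketch of the post-rename call pattern, assuming ds2_model, data_generator, infer_data, and vocab_list have already been set up as in infer.py; the numeric hyperparameter values are illustrative placeholders, not defaults taken from this repository.

# Sketch of the renamed two-step API (assumes the setup from infer.py;
# numeric values below are illustrative placeholders, not repo defaults).

# Step 1: network inference only, producing per-utterance probability matrices.
probs_split = ds2_model.infer_probs_batch(
    infer_data=infer_data,
    feeding_dict=data_generator.feeding)

# Step 2, option A: best-path (greedy) decoding of the probabilities.
greedy_transcripts = ds2_model.decode_batch_greedy(
    probs_split=probs_split,
    vocab_list=vocab_list)

# Step 2, option B: beam-search decoding over the same probabilities;
# alpha weights the external language model, beta weights word insertions.
beam_transcripts = ds2_model.decode_batch_beam_search(
    probs_split=probs_split,
    beam_alpha=2.5,         # placeholder LM weight
    beam_beta=0.3,          # placeholder word-insertion weight
    beam_size=500,          # placeholder beam width
    cutoff_prob=1.0,        # placeholder pruning threshold
    cutoff_top_n=40,        # placeholder per-step candidate cap
    vocab_list=vocab_list,
    num_processes=8)        # placeholder decoder parallelism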