提交 0deb2e6a 编写于 作者: Y Yibing Liu

Add beam search decoder using multiple processes

上级 bcd01f79
......@@ -2,11 +2,12 @@
CTC-like decoder utilities.
"""
import os
from itertools import groupby
import numpy as np
import copy
import kenlm
import os
import multiprocessing
def ctc_best_path_decode(probs_seq, vocabulary):
......@@ -187,3 +188,54 @@ def ctc_beam_search_decoder(probs_seq,
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result
def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  ext_scoring_func=None,
                                  blank_id=0,
                                  num_processes=None):
    '''
    Beam search decoder using multiple processes.

    Every element of probs_split is decoded by one call to
    ctc_beam_search_decoder, dispatched to a pool of worker processes.

    :param probs_split: Sequence of per-sample inputs; each element is
                        handed to ctc_beam_search_decoder as its first
                        argument (probs_seq).
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: External defined scoring function for
                            partially decoded sentence, e.g. word count
                            and language model. It is sent to the worker
                            processes together with the other arguments,
                            so it has to be picklable.
    :type ext_scoring_func: function
    :param blank_id: id of blank, default 0.
    :type blank_id: int
    :param num_processes: Number of processes, default None, equal to the
                          number of CPUs.
    :type num_processes: int
    :return: One ctc_beam_search_decoder result per element of
             probs_split, in the same order as the input.
    :rtype: list
    :raises ValueError: If num_processes is not positive.
    '''
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        # Submit one asynchronous decoding task per sample.
        pending = [
            pool.apply_async(
                ctc_beam_search_decoder,
                (probs_list, beam_size, vocabulary, ext_scoring_func,
                 blank_id))
            for probs_list in probs_split
        ]
        # No more tasks will be submitted; wait for the workers to finish.
        pool.close()
        pool.join()
    finally:
        # Harmless after a normal close()/join(); if submission or the wait
        # was interrupted, this makes sure the workers are not left behind.
        pool.terminate()

    # All tasks have completed, so get() returns immediately (or re-raises
    # the exception that occurred inside the corresponding worker).
    return [task.get() for task in pending]
......@@ -9,6 +9,7 @@ import gzip
from audio_data_utils import DataGenerator
from model import deep_speech2
from decoder import *
from error_rate import wer
parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.')
......@@ -59,9 +60,9 @@ parser.add_argument(
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search',
default='beam_search_nproc',
type=str,
help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)"
)
parser.add_argument(
"--beam_size",
......@@ -151,6 +152,7 @@ def infer():
## decode and print
# best path decode
wer_sum, wer_counter = 0, 0
if args.decode_method == "best_path":
for i, probs in enumerate(probs_split):
target_transcription = ''.join(
......@@ -159,12 +161,17 @@ def infer():
probs_seq=probs, vocabulary=vocab_list)
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target_transcription, best_path_transcription))
wer_cur = wer(target_transcription, best_path_transcription)
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f, average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search decode
elif args.decode_method == "beam_search":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
for i, probs in enumerate(probs_split):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=vocab_list,
......@@ -172,10 +179,40 @@ def infer():
ext_scoring_func=ext_scorer.evaluate,
blank_id=len(vocab_list))
print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search in multiple processes
elif args.decode_method == "beam_search_nproc":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=args.beam_size,
#ext_scoring_func=ext_scorer.evaluate,
ext_scoring_func=None,
blank_id=len(vocab_list))
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
else:
raise ValueError("Decoding method [%s] is not supported." % method)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册