提交 0deb2e6a 编写于 作者: Y Yibing Liu

add beam search decoder using multiprocesses

上级 bcd01f79
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
CTC-like decoder utilitis. CTC-like decoder utilitis.
""" """
import os
from itertools import groupby from itertools import groupby
import numpy as np import numpy as np
import copy import copy
import kenlm import kenlm
import os import multiprocessing
def ctc_best_path_decode(probs_seq, vocabulary): def ctc_best_path_decode(probs_seq, vocabulary):
...@@ -187,3 +188,54 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -187,3 +188,54 @@ def ctc_beam_search_decoder(probs_seq,
## output top beam_size decoding results ## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result return beam_result
def ctc_beam_search_decoder_nproc(probs_split,
beam_size,
vocabulary,
ext_scoring_func=None,
blank_id=0,
num_processes=None):
'''
Beam search decoder using multiple processes.
:param probs_seq: 3-D list with length num_time_steps, each element
is a 2-D list of probabilities can be used by
ctc_beam_search_decoder.
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param ext_scoring_func: External defined scoring function for
partially decoded sentence, e.g. word count
and language model.
:type external_scoring_function: function
:param blank_id: id of blank, default 0.
:type blank_id: int
:param num_processes: Number of processes, default None, equal to the
number of CPUs.
:type num_processes: int
:return: Decoding log probability and result string.
:rtype: list
'''
if num_processes is None:
num_processes = multiprocessing.cpu_count()
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, ext_scoring_func, blank_id)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close()
pool.join()
beam_search_results = []
for result in results:
beam_search_results.append(result.get())
return beam_search_results
...@@ -9,6 +9,7 @@ import gzip ...@@ -9,6 +9,7 @@ import gzip
from audio_data_utils import DataGenerator from audio_data_utils import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
from error_rate import wer
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.') description='Simplified version of DeepSpeech2 inference.')
...@@ -59,9 +60,9 @@ parser.add_argument( ...@@ -59,9 +60,9 @@ parser.add_argument(
help="Vocabulary filepath. (default: %(default)s)") help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_method", "--decode_method",
default='beam_search', default='beam_search_nproc',
type=str, type=str,
help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)"
) )
parser.add_argument( parser.add_argument(
"--beam_size", "--beam_size",
...@@ -151,6 +152,7 @@ def infer(): ...@@ -151,6 +152,7 @@ def infer():
## decode and print ## decode and print
# best path decode # best path decode
wer_sum, wer_counter = 0, 0
if args.decode_method == "best_path": if args.decode_method == "best_path":
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
target_transcription = ''.join( target_transcription = ''.join(
...@@ -159,12 +161,17 @@ def infer(): ...@@ -159,12 +161,17 @@ def infer():
probs_seq=probs, vocabulary=vocab_list) probs_seq=probs, vocabulary=vocab_list)
print("\nTarget Transcription: %s\nOutput Transcription: %s" % print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target_transcription, best_path_transcription)) (target_transcription, best_path_transcription))
wer_cur = wer(target_transcription, best_path_transcription)
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f, average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search decode # beam search decode
elif args.decode_method == "beam_search": elif args.decode_method == "beam_search":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
target_transcription = ''.join( target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]]) [vocab_list[index] for index in infer_data[i][1]])
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_result = ctc_beam_search_decoder( beam_search_result = ctc_beam_search_decoder(
probs_seq=probs, probs_seq=probs,
vocabulary=vocab_list, vocabulary=vocab_list,
...@@ -172,10 +179,40 @@ def infer(): ...@@ -172,10 +179,40 @@ def infer():
ext_scoring_func=ext_scorer.evaluate, ext_scoring_func=ext_scorer.evaluate,
blank_id=len(vocab_list)) blank_id=len(vocab_list))
print("\nTarget Transcription:\t%s" % target_transcription) print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search in multiple processes
elif args.decode_method == "beam_search_nproc":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=args.beam_size,
#ext_scoring_func=ext_scorer.evaluate,
ext_scoring_func=None,
blank_id=len(vocab_list))
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
print("\nTarget Transcription:\t%s" % target_transcription)
for index in range(args.num_results_per_sample): for index in range(args.num_results_per_sample):
result = beam_search_result[index] result = beam_search_result[index]
#output: index, log prob, beam result #output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1])) print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
else: else:
raise ValueError("Decoding method [%s] is not supported." % method) raise ValueError("Decoding method [%s] is not supported." % method)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册