提交 63a72c1e 编写于 作者: Y Yibing Liu

refine ctc_beam_search_decoder

上级 accaf924
...@@ -8,8 +8,8 @@ import numpy as np ...@@ -8,8 +8,8 @@ import numpy as np
import multiprocessing import multiprocessing
def ctc_best_path_decode(probs_seq, vocabulary): def ctc_best_path_decoder(probs_seq, vocabulary):
"""Best path decoding, also called argmax decoding or greedy decoding. """Best path decoder, also called argmax decoder or greedy decoder.
Path consisting of the most probable tokens are further post-processed to Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks. remove consecutive repetitions and all blanks.
...@@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary): ...@@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary):
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,
beam_size, beam_size,
vocabulary, vocabulary,
blank_id=0, blank_id,
cutoff_prob=1.0, cutoff_prob=1.0,
ext_scoring_func=None, ext_scoring_func=None,
nproc=False): nproc=False):
'''Beam search decoder for CTC-trained network, using beam search with width """Beam search decoder for CTC-trained network. It utilizes beam search
beam_size to find many paths to one label, return beam_size labels in to approximately select top best decoding labels and returning results
the descending order of probabilities. The implementation is based on Prefix in the descending order. The implementation is based on Prefix
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned. redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
:param probs_seq: 2-D list with length num_time_steps, each element one prefix may comes from different paths; 2) the if condition "if l^+ not
is a list of normalized probabilities over vocabulary in A_prev then" after probabilities' computation is deprecated for it is
and blank for one time step. hard to understand and seems unnecessary.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list :type probs_seq: 2-D list
:param beam_size: Width for beam search. :param beam_size: Width for beam search.
:type beam_size: int :type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param blank_id: ID of blank, default 0. :param blank_id: ID of blank.
:type blank_id: int :type blank_id: int
:param cutoff_prob: Cutoff probability in pruning, :param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning. default 1.0, no pruning.
:type cutoff_prob: float :type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for :param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count partially decoded sentence, e.g. word count
and language model. or language model.
:type external_scoring_function: function :type external_scoring_func: callable
:param nproc: Whether the decoder used in multiprocesses. :param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool :type nproc: bool
:return: Decoding log probabilities and result sentences in descending order. :return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list :rtype: list
''' """
# dimension check # dimension check
for prob_list in probs_seq: for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1: if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatched with vocabulary") raise ValueError("The shape of prob_seq does not match with the "
num_time_steps = len(probs_seq) "shape of the vocabulary.")
# blank_id check # blank_id check
probs_dim = len(probs_seq[0]) if not blank_id < len(probs_seq[0]):
if not blank_id < probs_dim:
raise ValueError("blank_id shouldn't be greater than probs dimension") raise ValueError("blank_id shouldn't be greater than probs dimension")
# If the decoder called in the multiprocesses, then use the global scorer # If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_nproc(). # instantiated in ctc_beam_search_decoder_batch().
if nproc is True: if nproc is True:
global ext_nproc_scorer global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer ext_scoring_func = ext_nproc_scorer
## initialize ## initialize
# the set containing selected prefixes # prefix_set_prev: the set containing selected prefixes
prefix_set_prev = {'\t': 1.0} # probs_b_prev: prefixes' probability ending with blank in previous step
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} # probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev, probs_b_prev, probs_nb_prev = {
'\t': 1.0
}, {
'\t': 1.0
}, {
'\t': 0.0
}
## extend prefix in loop ## extend prefix in loop
for time_step in xrange(num_time_steps): for time_step in xrange(len(probs_seq)):
# the set containing candidate prefixes # prefix_set_next: the set containing candidate prefixes
prefix_set_next = {} # probs_b_cur: prefixes' probability ending with blank in current step
probs_b_cur, probs_nb_cur = {}, {} # probs_nb_cur: prefixes' probability ending with non-blank in current step
prob = probs_seq[time_step] prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx) cutoff_len = len(prob_idx)
#If pruning is enabled #If pruning is enabled
if (cutoff_prob < 1.0): if cutoff_prob < 1.0:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True) prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len = 0 cutoff_len, cum_prob = 0, 0.0
cum_prob = 0.0
for i in xrange(len(prob_idx)): for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1] cum_prob += prob_idx[i][1]
cutoff_len += 1 cutoff_len += 1
...@@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq,
prefix_set_prev = dict(prefix_set_prev) prefix_set_prev = dict(prefix_set_prev)
beam_result = [] beam_result = []
for (seq, prob) in prefix_set_prev.items(): for seq, prob in prefix_set_prev.items():
if prob > 0.0 and len(seq) > 1: if prob > 0.0 and len(seq) > 1:
result = seq[1:] result = seq[1:]
# score last word by external scorer # score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '): if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result) prob = prob * ext_scoring_func(result)
log_prob = np.log(prob) log_prob = np.log(prob)
beam_result.append([log_prob, result]) beam_result.append((log_prob, result))
## output top beam_size decoding results ## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result return beam_result
def ctc_beam_search_decoder_nproc(probs_split, def ctc_beam_search_decoder_batch(probs_split,
beam_size, beam_size,
vocabulary, vocabulary,
blank_id=0, blank_id,
num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
ext_scoring_func=None, ext_scoring_func=None):
num_processes=None): """CTC beam search decoder using multiple processes.
'''Beam search decoder using multiple processes.
:param probs_seq: 3-D list with length batch_size, each element :param probs_seq: 3-D list with each element as an instance of 2-D list
is a 2-D list of probabilities can be used by of probabilities used by ctc_beam_search_decoder().
ctc_beam_search_decoder.
:type probs_seq: 3-D list :type probs_seq: 3-D list
:param beam_size: Width for beam search. :param beam_size: Width for beam search.
:type beam_size: int :type beam_size: int
:param vocabulary: Vocabulary list. :param vocabulary: Vocabulary list.
:type vocabulary: list :type vocabulary: list
:param blank_id: ID of blank, default 0. :param blank_id: ID of blank.
:type blank_id: int :type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning, :param cutoff_prob: Cutoff probability in pruning,
default 0, no pruning. default 1.0, no pruning.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type cutoff_prob: float :type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for :param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count partially decoded sentence, e.g. word count
and language model. or language model.
:type external_scoring_function: function :type external_scoring_function: callable
:param num_processes: Number of processes, default None, equal to the :return: List of tuples of log probability and sentence as decoding
number of CPUs. results, in descending order of the probability.
:type num_processes: int
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list :rtype: list
''' """
if num_processes is None:
num_processes = multiprocessing.cpu_count()
if not num_processes > 0: if not num_processes > 0:
raise ValueError("Number of processes must be positive!") raise ValueError("Number of processes must be positive!")
...@@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split, ...@@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split,
pool.close() pool.close()
pool.join() pool.join()
beam_search_results = [] beam_search_results = [result.get() for result in results]
for result in results:
beam_search_results.append(result.get())
return beam_search_results return beam_search_results
...@@ -3,22 +3,22 @@ from __future__ import absolute_import ...@@ -3,22 +3,22 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle.v2 as paddle
import distutils.util import distutils.util
import argparse import argparse
import gzip import gzip
import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
from scorer import Scorer from lm.lm_scorer import LmScorer
from error_rate import wer from error_rate import wer
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
"--num_samples", "--batch_size",
default=100, default=100,
type=int, type=int,
help="Number of samples for evaluation. (default: %(default)s)") help="Minibatch size for evaluation. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--num_conv_layers", "--num_conv_layers",
default=2, default=2,
...@@ -39,6 +39,16 @@ parser.add_argument( ...@@ -39,6 +39,16 @@ parser.add_argument(
default=True, default=True,
type=distutils.util.strtobool, type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)") help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--mean_std_filepath", "--mean_std_filepath",
default='mean_std.npz', default='mean_std.npz',
...@@ -46,10 +56,10 @@ parser.add_argument( ...@@ -46,10 +56,10 @@ parser.add_argument(
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_method", "--decode_method",
default='beam_search_nproc', default='beam_search',
type=str, type=str,
help="Method for ctc decoding, best_path, " help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
"beam_search or beam_search_nproc. (default: %(default)s)") )
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="data/en.00.UNKNOWN.klm", default="data/en.00.UNKNOWN.klm",
...@@ -76,11 +86,6 @@ parser.add_argument( ...@@ -76,11 +86,6 @@ parser.add_argument(
default=500, default=500,
type=int, type=int,
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--normalizer_manifest_path",
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_manifest_path", "--decode_manifest_path",
default='data/manifest.libri.test-clean', default='data/manifest.libri.test-clean',
...@@ -88,7 +93,7 @@ parser.add_argument( ...@@ -88,7 +93,7 @@ parser.add_argument(
help="Manifest path for decoding. (default: %(default)s)") help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--model_filepath", "--model_filepath",
default='./params.tar.gz', default='checkpoints/params.latest.tar.gz',
type=str, type=str,
help="Model filepath. (default: %(default)s)") help="Model filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -101,12 +106,12 @@ args = parser.parse_args() ...@@ -101,12 +106,12 @@ args = parser.parse_args()
def evaluate(): def evaluate():
"""Evaluate on whole test data for DeepSpeech2.""" """Evaluate on whole test data for DeepSpeech2."""
# initialize data generator # initialize data generator
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath, mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}') augmentation_config='{}',
num_threads=args.num_threads_data)
# create network config # create network config
# paddle.data_type.dense_array is used for variable batch input. # paddle.data_type.dense_array is used for variable batch input.
...@@ -133,7 +138,7 @@ def evaluate(): ...@@ -133,7 +138,7 @@ def evaluate():
# prepare infer data # prepare infer data
batch_reader = data_generator.batch_reader_creator( batch_reader = data_generator.batch_reader_creator(
manifest_path=args.decode_manifest_path, manifest_path=args.decode_manifest_path,
batch_size=args.num_samples, batch_size=args.batch_size,
sortagrad=False, sortagrad=False,
shuffle_method=None) shuffle_method=None)
...@@ -142,9 +147,8 @@ def evaluate(): ...@@ -142,9 +147,8 @@ def evaluate():
output_layer=output_probs, parameters=parameters) output_layer=output_probs, parameters=parameters)
# initialize external scorer for beam search decoding # initialize external scorer for beam search decoding
if args.decode_method == 'beam_search' or \ if args.decode_method == 'beam_search':
args.decode_method == 'beam_search_nproc': ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
wer_counter, wer_sum = 0, 0.0 wer_counter, wer_sum = 0, 0.0
for infer_data in batch_reader(): for infer_data in batch_reader():
...@@ -155,56 +159,39 @@ def evaluate(): ...@@ -155,56 +159,39 @@ def evaluate():
infer_results[i * num_steps:(i + 1) * num_steps] infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(infer_data)) for i in xrange(0, len(infer_data))
] ]
# target transcription
target_transcription = [
''.join([
data_generator.vocab_list[index] for index in infer_data[i][1]
]) for i, probs in enumerate(probs_split)
]
# decode and print # decode and print
# best path decode # best path decode
if args.decode_method == "best_path": if args.decode_method == "best_path":
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
output_transcription = ctc_best_path_decode( output_transcription = ctc_best_path_decoder(
probs_seq=probs, vocabulary=data_generator.vocab_list) probs_seq=probs, vocabulary=data_generator.vocab_list)
target_transcription = ''.join([ wer_sum += wer(target_transcription[i], output_transcription)
data_generator.vocab_list[index]
for index in infer_data[i][1]
])
wer_sum += wer(target_transcription, output_transcription)
wer_counter += 1 wer_counter += 1
# beam search decode in single process # beam search decode
elif args.decode_method == "beam_search": elif args.decode_method == "beam_search":
for i, probs in enumerate(probs_split):
target_transcription = ''.join([
data_generator.vocab_list[index]
for index in infer_data[i][1]
])
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
ext_scoring_func=ext_scorer,
cutoff_prob=args.cutoff_prob, )
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
# beam search using multiple processes # beam search using multiple processes
elif args.decode_method == "beam_search_nproc": beam_search_results = ctc_beam_search_decoder_batch(
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split, probs_split=probs_split,
vocabulary=data_generator.vocab_list, vocabulary=data_generator.vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list), blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
ext_scoring_func=ext_scorer, ext_scoring_func=ext_scorer,
cutoff_prob=args.cutoff_prob, ) cutoff_prob=args.cutoff_prob, )
for i, beam_search_result in enumerate(beam_search_nproc_results): for i, beam_search_result in enumerate(beam_search_results):
target_transcription = ''.join([ wer_sum += wer(target_transcription[i],
data_generator.vocab_list[index] beam_search_result[0][1])
for index in infer_data[i][1]
])
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1 wer_counter += 1
else: else:
raise ValueError("Decoding method [%s] is not supported." % raise ValueError("Decoding method [%s] is not supported." %
decode_method) decode_method)
print("Cur WER = %f" % (wer_sum / wer_counter))
print("Final WER = %f" % (wer_sum / wer_counter)) print("Final WER = %f" % (wer_sum / wer_counter))
......
...@@ -11,14 +11,14 @@ import paddle.v2 as paddle ...@@ -11,14 +11,14 @@ import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
from scorer import Scorer from lm.lm_scorer import LmScorer
from error_rate import wer from error_rate import wer
import utils import utils
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument( parser.add_argument(
"--num_samples", "--num_samples",
default=100, default=10,
type=int, type=int,
help="Number of samples for inference. (default: %(default)s)") help="Number of samples for inference. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -46,6 +46,11 @@ parser.add_argument( ...@@ -46,6 +46,11 @@ parser.add_argument(
default=multiprocessing.cpu_count(), default=multiprocessing.cpu_count(),
type=int, type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)") help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--mean_std_filepath", "--mean_std_filepath",
default='mean_std.npz', default='mean_std.npz',
...@@ -53,12 +58,12 @@ parser.add_argument( ...@@ -53,12 +58,12 @@ parser.add_argument(
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_manifest_path", "--decode_manifest_path",
default='data/manifest.libri.test-100sample', default='datasets/manifest.test',
type=str, type=str,
help="Manifest path for decoding. (default: %(default)s)") help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--model_filepath", "--model_filepath",
default='checkpoints/params.latest.tar.gz', default='checkpoints/params.tar.gz.41',
type=str, type=str,
help="Model filepath. (default: %(default)s)") help="Model filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -68,12 +73,10 @@ parser.add_argument( ...@@ -68,12 +73,10 @@ parser.add_argument(
help="Vocabulary filepath. (default: %(default)s)") help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_method", "--decode_method",
default='beam_search_nproc', default='beam_search',
type=str, type=str,
help="Method for ctc decoding:" help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)"
" best_path," )
" beam_search, "
" or beam_search_nproc. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--beam_size", "--beam_size",
default=500, default=500,
...@@ -86,7 +89,7 @@ parser.add_argument( ...@@ -86,7 +89,7 @@ parser.add_argument(
help="Number of output per sample in beam search. (default: %(default)d)") help="Number of output per sample in beam search. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="data/en.00.UNKNOWN.klm", default="lm/data/en.00.UNKNOWN.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -143,6 +146,7 @@ def infer(): ...@@ -143,6 +146,7 @@ def infer():
batch_reader = data_generator.batch_reader_creator( batch_reader = data_generator.batch_reader_creator(
manifest_path=args.decode_manifest_path, manifest_path=args.decode_manifest_path,
batch_size=args.num_samples, batch_size=args.num_samples,
min_batch_size=1,
sortagrad=False, sortagrad=False,
shuffle_method=None) shuffle_method=None)
infer_data = batch_reader().next() infer_data = batch_reader().next()
...@@ -156,68 +160,45 @@ def infer(): ...@@ -156,68 +160,45 @@ def infer():
for i in xrange(len(infer_data)) for i in xrange(len(infer_data))
] ]
# targe transcription
target_transcription = [
''.join(
[data_generator.vocab_list[index] for index in infer_data[i][1]])
for i, probs in enumerate(probs_split)
]
## decode and print ## decode and print
# best path decode # best path decode
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
if args.decode_method == "best_path": if args.decode_method == "best_path":
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
target_transcription = ''.join([ best_path_transcription = ctc_best_path_decoder(
data_generator.vocab_list[index] for index in infer_data[i][1]
])
best_path_transcription = ctc_best_path_decode(
probs_seq=probs, vocabulary=data_generator.vocab_list) probs_seq=probs, vocabulary=data_generator.vocab_list)
print("\nTarget Transcription: %s\nOutput Transcription: %s" % print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target_transcription, best_path_transcription)) (target_transcription[i], best_path_transcription))
wer_cur = wer(target_transcription, best_path_transcription) wer_cur = wer(target_transcription[i], best_path_transcription)
wer_sum += wer_cur wer_sum += wer_cur
wer_counter += 1 wer_counter += 1
print("cur wer = %f, average wer = %f" % print("cur wer = %f, average wer = %f" %
(wer_cur, wer_sum / wer_counter)) (wer_cur, wer_sum / wer_counter))
# beam search decode # beam search decode
elif args.decode_method == "beam_search": elif args.decode_method == "beam_search":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path)
for i, probs in enumerate(probs_split): beam_search_batch_results = ctc_beam_search_decoder_batch(
target_transcription = ''.join([
data_generator.vocab_list[index] for index in infer_data[i][1]
])
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
print("\nTarget Transcription:\t%s" % target_transcription)
for index in xrange(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
elif args.decode_method == "beam_search_nproc":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split, probs_split=probs_split,
vocabulary=data_generator.vocab_list, vocabulary=data_generator.vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list), blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, ) ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results): for i, beam_search_result in enumerate(beam_search_batch_results):
target_transcription = ''.join([ print("\nTarget Transcription:\t%s" % target_transcription[i])
data_generator.vocab_list[index] for index in infer_data[i][1]
])
print("\nTarget Transcription:\t%s" % target_transcription)
for index in xrange(args.num_results_per_sample): for index in xrange(args.num_results_per_sample):
result = beam_search_result[index] result = beam_search_result[index]
#output: index, log prob, beam result #output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1])) print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1]) wer_cur = wer(target_transcription[i], beam_search_result[0][1])
wer_sum += wer_cur wer_sum += wer_cur
wer_counter += 1 wer_counter += 1
print("cur wer = %f , average wer = %f" % print("cur wer = %f , average wer = %f" %
......
...@@ -8,13 +8,16 @@ import kenlm ...@@ -8,13 +8,16 @@ import kenlm
import numpy as np import numpy as np
class Scorer(object): class LmScorer(object):
"""External defined scorer to evaluate a sentence in beam search """External scorer to evaluate a prefix or whole sentence in
decoding, consisting of language model and word count. beam search decoding, including the score from n-gram language
model and word count.
:param alpha: Parameter associated with language model. :param alpha: Parameter associated with language model. Don't use
language model when alpha = 0.
:type alpha: float :type alpha: float
:param beta: Parameter associated with word count. :param beta: Parameter associated with word count. Don't use word
count when beta = 0.
:type beta: float :type beta: float
:model_path: Path to load language model. :model_path: Path to load language model.
:type model_path: basestring :type model_path: basestring
...@@ -28,14 +31,14 @@ class Scorer(object): ...@@ -28,14 +31,14 @@ class Scorer(object):
self._language_model = kenlm.LanguageModel(model_path) self._language_model = kenlm.LanguageModel(model_path)
# n-gram language model scoring # n-gram language model scoring
def language_model_score(self, sentence): def _language_model_score(self, sentence):
#log10 prob of last word #log10 prob of last word
log_cond_prob = list( log_cond_prob = list(
self._language_model.full_scores(sentence, eos=False))[-1][0] self._language_model.full_scores(sentence, eos=False))[-1][0]
return np.power(10, log_cond_prob) return np.power(10, log_cond_prob)
# word insertion term # word insertion term
def word_count(self, sentence): def _word_count(self, sentence):
words = sentence.strip().split(' ') words = sentence.strip().split(' ')
return len(words) return len(words)
...@@ -51,8 +54,8 @@ class Scorer(object): ...@@ -51,8 +54,8 @@ class Scorer(object):
:return: Evaluation score, in the decimal or log. :return: Evaluation score, in the decimal or log.
:rtype: float :rtype: float
""" """
lm = self.language_model_score(sentence) lm = self._language_model_score(sentence)
word_cnt = self.word_count(sentence) word_cnt = self._word_count(sentence)
if log == False: if log == False:
score = np.power(lm, self._alpha) \ score = np.power(lm, self._alpha) \
* np.power(word_cnt, self._beta) * np.power(word_cnt, self._beta)
......
echo "Downloading language model."
wget -c ftp://xxx/xxx/en.00.UNKNOWN.klm -P ./data
SoundFile==0.9.0.post1 SoundFile==0.9.0.post1
wget==3.2 wget==3.2
scipy==0.13.1 scipy==0.13.1
https://github.com/kpu/kenlm/archive/master.zip
...@@ -53,11 +53,11 @@ class TestDecoders(unittest.TestCase): ...@@ -53,11 +53,11 @@ class TestDecoders(unittest.TestCase):
self.beam_search_result = ['acdc', "b'a"] self.beam_search_result = ['acdc', "b'a"]
def test_best_path_decoder_1(self): def test_best_path_decoder_1(self):
bst_result = ctc_best_path_decode(self.probs_seq1, self.vocab_list) bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list)
self.assertEqual(bst_result, self.best_path_result[0]) self.assertEqual(bst_result, self.best_path_result[0])
def test_best_path_decoder_2(self): def test_best_path_decoder_2(self):
bst_result = ctc_best_path_decode(self.probs_seq2, self.vocab_list) bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list)
self.assertEqual(bst_result, self.best_path_result[1]) self.assertEqual(bst_result, self.best_path_result[1])
def test_beam_search_decoder_1(self): def test_beam_search_decoder_1(self):
...@@ -77,7 +77,7 @@ class TestDecoders(unittest.TestCase): ...@@ -77,7 +77,7 @@ class TestDecoders(unittest.TestCase):
self.assertEqual(beam_result[0][1], self.beam_search_result[1]) self.assertEqual(beam_result[0][1], self.beam_search_result[1])
def test_beam_search_nproc_decoder(self): def test_beam_search_nproc_decoder(self):
beam_results = ctc_beam_search_decoder_nproc( beam_results = ctc_beam_search_decoder_batch(
probs_split=[self.probs_seq1, self.probs_seq2], probs_split=[self.probs_seq1, self.probs_seq2],
beam_size=self.beam_size, beam_size=self.beam_size,
vocabulary=self.vocab_list, vocabulary=self.vocab_list,
......
...@@ -3,14 +3,14 @@ from __future__ import absolute_import ...@@ -3,14 +3,14 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle.v2 as paddle
import distutils.util import distutils.util
import argparse import argparse
import gzip import gzip
import paddle.v2 as paddle
from data_utils.data import DataGenerator from data_utils.data import DataGenerator
from model import deep_speech2 from model import deep_speech2
from decoder import * from decoder import *
from scorer import Scorer from lm.lm_scorer import LmScorer
from error_rate import wer from error_rate import wer
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -39,24 +39,29 @@ parser.add_argument( ...@@ -39,24 +39,29 @@ parser.add_argument(
default=True, default=True,
type=distutils.util.strtobool, type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)") help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--num_processes_beam_search",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--mean_std_filepath", "--mean_std_filepath",
default='mean_std.npz', default='mean_std.npz',
type=str, type=str,
help="Manifest path for normalizer. (default: %(default)s)") help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--normalizer_manifest_path",
default='data/manifest.libri.train-clean-100',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--decode_manifest_path", "--decode_manifest_path",
default='data/manifest.libri.test-100sample', default='datasets/manifest.test',
type=str, type=str,
help="Manifest path for decoding. (default: %(default)s)") help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--model_filepath", "--model_filepath",
default='./params.tar.gz', default='checkpoints/params.latest.tar.gz',
type=str, type=str,
help="Model filepath. (default: %(default)s)") help="Model filepath. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -64,25 +69,14 @@ parser.add_argument( ...@@ -64,25 +69,14 @@ parser.add_argument(
default='datasets/vocab/eng_vocab.txt', default='datasets/vocab/eng_vocab.txt',
type=str, type=str,
help="Vocabulary filepath. (default: %(default)s)") help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search_nproc',
type=str,
help="Method for decoding, beam_search or beam_search_nproc. (default: %(default)s)"
)
parser.add_argument( parser.add_argument(
"--beam_size", "--beam_size",
default=500, default=500,
type=int, type=int,
help="Width for beam search decoding. (default: %(default)d)") help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
"--num_results_per_sample",
default=1,
type=int,
help="Number of outputs per sample in beam search. (default: %(default)d)")
parser.add_argument( parser.add_argument(
"--language_model_path", "--language_model_path",
default="data/en.00.UNKNOWN.klm", default="lm/data/en.00.UNKNOWN.klm",
type=str, type=str,
help="Path for language model. (default: %(default)s)") help="Path for language model. (default: %(default)s)")
parser.add_argument( parser.add_argument(
...@@ -137,7 +131,8 @@ def tune(): ...@@ -137,7 +131,8 @@ def tune():
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath, mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}') augmentation_config='{}',
num_threads=args.num_threads_data)
# create network config # create network config
# paddle.data_type.dense_array is used for variable batch input. # paddle.data_type.dense_array is used for variable batch input.
...@@ -188,42 +183,22 @@ def tune(): ...@@ -188,42 +183,22 @@ def tune():
## tune parameters in loop ## tune parameters in loop
for (alpha, beta) in params_grid: for (alpha, beta) in params_grid:
wer_sum, wer_counter = 0, 0 wer_sum, wer_counter = 0, 0
ext_scorer = Scorer(alpha, beta, args.language_model_path) ext_scorer = LmScorer(alpha, beta, args.language_model_path)
# beam search decode
if args.decode_method == "beam_search":
for i, probs in enumerate(probs_split):
target_transcription = ''.join([
data_generator.vocab_list[index]
for index in infer_data[i][1]
])
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=data_generator.vocab_list,
beam_size=args.beam_size,
blank_id=len(data_generator.vocab_list),
cutoff_prob=args.cutoff_prob,
ext_scoring_func=ext_scorer, )
wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1
# beam search using multiple processes # beam search using multiple processes
elif args.decode_method == "beam_search_nproc": beam_search_results = ctc_beam_search_decoder_batch(
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split, probs_split=probs_split,
vocabulary=data_generator.vocab_list, vocabulary=data_generator.vocab_list,
beam_size=args.beam_size, beam_size=args.beam_size,
cutoff_prob=args.cutoff_prob, cutoff_prob=args.cutoff_prob,
blank_id=len(data_generator.vocab_list), blank_id=len(data_generator.vocab_list),
num_processes=args.num_processes_beam_search,
ext_scoring_func=ext_scorer, ) ext_scoring_func=ext_scorer, )
for i, beam_search_result in enumerate(beam_search_nproc_results): for i, beam_search_result in enumerate(beam_search_results):
target_transcription = ''.join([ target_transcription = ''.join([
data_generator.vocab_list[index] data_generator.vocab_list[index] for index in infer_data[i][1]
for index in infer_data[i][1]
]) ])
wer_sum += wer(target_transcription, beam_search_result[0][1]) wer_sum += wer(target_transcription, beam_search_result[0][1])
wer_counter += 1 wer_counter += 1
else:
raise ValueError("Decoding method [%s] is not supported." %
decode_method)
print("alpha = %f\tbeta = %f\tWER = %f" % print("alpha = %f\tbeta = %f\tWER = %f" %
(alpha, beta, wer_sum / wer_counter)) (alpha, beta, wer_sum / wer_counter))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册