From 26510f74a63307786f83db3f9faa2f579292e1f4 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 27 Jun 2017 17:42:44 +0800 Subject: [PATCH] refine ctc_beam_search_decoder --- decoder.py | 128 +++++++++++++++++++---------------- evaluate.py | 89 +++++++++++------------- infer.py | 79 ++++++++------------- lm/__init__.py | 0 scorer.py => lm/lm_scorer.py | 21 +++--- lm/run.sh | 3 + requirements.txt | 1 + tests/test_decoders.py | 6 +- tune.py | 89 +++++++++--------------- 9 files changed, 187 insertions(+), 229 deletions(-) create mode 100644 lm/__init__.py rename scorer.py => lm/lm_scorer.py (73%) create mode 100644 lm/run.sh diff --git a/decoder.py b/decoder.py index 00659367..4676b02b 100644 --- a/decoder.py +++ b/decoder.py @@ -8,8 +8,8 @@ import numpy as np import multiprocessing -def ctc_best_path_decode(probs_seq, vocabulary): - """Best path decoding, also called argmax decoding or greedy decoding. +def ctc_best_path_decoder(probs_seq, vocabulary): + """Best path decoder, also called argmax decoder or greedy decoder. Path consisting of the most probable tokens are further post-processed to remove consecutive repetitions and all blanks. @@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary): def ctc_beam_search_decoder(probs_seq, beam_size, vocabulary, - blank_id=0, + blank_id, cutoff_prob=1.0, ext_scoring_func=None, nproc=False): - '''Beam search decoder for CTC-trained network, using beam search with width - beam_size to find many paths to one label, return beam_size labels in - the descending order of probabilities. The implementation is based on Prefix - Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is - redesigned. - - :param probs_seq: 2-D list with length num_time_steps, each element - is a list of normalized probabilities over vocabulary - and blank for one time step. + """Beam search decoder for CTC-trained network. It utilizes beam search + to approximately select top best decoding labels and returning results + in the descending order. The implementation is based on Prefix + Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is + redesigned. Two important modifications: 1) in the iterative computation + of probabilities, the assignment operation is changed to accumulation for + one prefix may comes from different paths; 2) the if condition "if l^+ not + in A_prev then" after probabilities' computation is deprecated for it is + hard to understand and seems unnecessary. + + :param probs_seq: 2-D list of probability distributions over each time + step, with each element being a list of normalized + probabilities over vocabulary and blank. :type probs_seq: 2-D list :param beam_size: Width for beam search. :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list - :param blank_id: ID of blank, default 0. + :param blank_id: ID of blank. :type blank_id: int :param cutoff_prob: Cutoff probability in pruning, default 1.0, no pruning. :type cutoff_prob: float - :param ext_scoring_func: External defined scoring function for + :param ext_scoring_func: External scoring function for partially decoded sentence, e.g. word count - and language model. - :type external_scoring_function: function + or language model. + :type external_scoring_func: callable :param nproc: Whether the decoder used in multiprocesses. :type nproc: bool - :return: Decoding log probabilities and result sentences in descending order. + :return: List of tuples of log probability and sentence as decoding + results, in descending order of the probability. :rtype: list - ''' + """ # dimension check for prob_list in probs_seq: if not len(prob_list) == len(vocabulary) + 1: - raise ValueError("probs dimension mismatched with vocabulary") - num_time_steps = len(probs_seq) + raise ValueError("The shape of prob_seq does not match with the " + "shape of the vocabulary.") # blank_id check - probs_dim = len(probs_seq[0]) - if not blank_id < probs_dim: + if not blank_id < len(probs_seq[0]): raise ValueError("blank_id shouldn't be greater than probs dimension") # If the decoder called in the multiprocesses, then use the global scorer - # instantiated in ctc_beam_search_decoder_nproc(). + # instantiated in ctc_beam_search_decoder_batch(). if nproc is True: global ext_nproc_scorer ext_scoring_func = ext_nproc_scorer ## initialize - # the set containing selected prefixes - prefix_set_prev = {'\t': 1.0} - probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} + # prefix_set_prev: the set containing selected prefixes + # probs_b_prev: prefixes' probability ending with blank in previous step + # probs_nb_prev: prefixes' probability ending with non-blank in previous step + prefix_set_prev, probs_b_prev, probs_nb_prev = { + '\t': 1.0 + }, { + '\t': 1.0 + }, { + '\t': 0.0 + } ## extend prefix in loop - for time_step in xrange(num_time_steps): - # the set containing candidate prefixes - prefix_set_next = {} - probs_b_cur, probs_nb_cur = {}, {} - prob = probs_seq[time_step] - prob_idx = [[i, prob[i]] for i in xrange(len(prob))] + for time_step in xrange(len(probs_seq)): + # prefix_set_next: the set containing candidate prefixes + # probs_b_cur: prefixes' probability ending with blank in current step + # probs_nb_cur: prefixes' probability ending with non-blank in current step + prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {} + + prob_idx = list(enumerate(probs_seq[time_step])) cutoff_len = len(prob_idx) #If pruning is enabled - if (cutoff_prob < 1.0): + if cutoff_prob < 1.0: prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True) - cutoff_len = 0 - cum_prob = 0.0 + cutoff_len, cum_prob = 0, 0.0 for i in xrange(len(prob_idx)): cum_prob += prob_idx[i][1] cutoff_len += 1 @@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq, prefix_set_prev = dict(prefix_set_prev) beam_result = [] - for (seq, prob) in prefix_set_prev.items(): + for seq, prob in prefix_set_prev.items(): if prob > 0.0 and len(seq) > 1: result = seq[1:] # score last word by external scorer if (ext_scoring_func is not None) and (result[-1] != ' '): prob = prob * ext_scoring_func(result) log_prob = np.log(prob) - beam_result.append([log_prob, result]) + beam_result.append((log_prob, result)) ## output top beam_size decoding results beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) return beam_result -def ctc_beam_search_decoder_nproc(probs_split, +def ctc_beam_search_decoder_batch(probs_split, beam_size, vocabulary, - blank_id=0, + blank_id, + num_processes, cutoff_prob=1.0, - ext_scoring_func=None, - num_processes=None): - '''Beam search decoder using multiple processes. + ext_scoring_func=None): + """CTC beam search decoder using multiple processes. - :param probs_seq: 3-D list with length batch_size, each element - is a 2-D list of probabilities can be used by - ctc_beam_search_decoder. + :param probs_seq: 3-D list with each element as an instance of 2-D list + of probabilities used by ctc_beam_search_decoder(). :type probs_seq: 3-D list :param beam_size: Width for beam search. :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list - :param blank_id: ID of blank, default 0. + :param blank_id: ID of blank. :type blank_id: int + :param num_processes: Number of parallel processes. + :type num_processes: int :param cutoff_prob: Cutoff probability in pruning, - default 0, no pruning. + default 1.0, no pruning. + :param num_processes: Number of parallel processes. + :type num_processes: int :type cutoff_prob: float - :param ext_scoring_func: External defined scoring function for + :param ext_scoring_func: External scoring function for partially decoded sentence, e.g. word count - and language model. - :type external_scoring_function: function - :param num_processes: Number of processes, default None, equal to the - number of CPUs. - :type num_processes: int - :return: Decoding log probabilities and result sentences in descending order. + or language model. + :type external_scoring_function: callable + :return: List of tuples of log probability and sentence as decoding + results, in descending order of the probability. :rtype: list - ''' - if num_processes is None: - num_processes = multiprocessing.cpu_count() + """ if not num_processes > 0: raise ValueError("Number of processes must be positive!") @@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split, pool.close() pool.join() - beam_search_results = [] - for result in results: - beam_search_results.append(result.get()) + beam_search_results = [result.get() for result in results] return beam_search_results diff --git a/evaluate.py b/evaluate.py index a7b8e221..7ef32ad1 100644 --- a/evaluate.py +++ b/evaluate.py @@ -3,22 +3,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import paddle.v2 as paddle import distutils.util import argparse import gzip +import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * -from scorer import Scorer +from lm.lm_scorer import LmScorer from error_rate import wer parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--num_samples", + "--batch_size", default=100, type=int, - help="Number of samples for evaluation. (default: %(default)s)") + help="Minibatch size for evaluation. (default: %(default)s)") parser.add_argument( "--num_conv_layers", default=2, @@ -39,6 +39,16 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -46,10 +56,10 @@ parser.add_argument( help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_method", - default='beam_search_nproc', + default='beam_search', type=str, - help="Method for ctc decoding, best_path, " - "beam_search or beam_search_nproc. (default: %(default)s)") + help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" +) parser.add_argument( "--language_model_path", default="data/en.00.UNKNOWN.klm", @@ -76,11 +86,6 @@ parser.add_argument( default=500, type=int, help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', - type=str, - help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", default='data/manifest.libri.test-clean', @@ -88,7 +93,7 @@ parser.add_argument( help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='./params.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( @@ -101,12 +106,12 @@ args = parser.parse_args() def evaluate(): """Evaluate on whole test data for DeepSpeech2.""" - # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config='{}') + augmentation_config='{}', + num_threads=args.num_threads_data) # create network config # paddle.data_type.dense_array is used for variable batch input. @@ -133,7 +138,7 @@ def evaluate(): # prepare infer data batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, - batch_size=args.num_samples, + batch_size=args.batch_size, sortagrad=False, shuffle_method=None) @@ -142,9 +147,8 @@ def evaluate(): output_layer=output_probs, parameters=parameters) # initialize external scorer for beam search decoding - if args.decode_method == 'beam_search' or \ - args.decode_method == 'beam_search_nproc': - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + if args.decode_method == 'beam_search': + ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) wer_counter, wer_sum = 0, 0.0 for infer_data in batch_reader(): @@ -155,56 +159,39 @@ def evaluate(): infer_results[i * num_steps:(i + 1) * num_steps] for i in xrange(0, len(infer_data)) ] - + # target transcription + target_transcription = [ + ''.join([ + data_generator.vocab_list[index] for index in infer_data[i][1] + ]) for i, probs in enumerate(probs_split) + ] # decode and print # best path decode if args.decode_method == "best_path": for i, probs in enumerate(probs_split): - output_transcription = ctc_best_path_decode( + output_transcription = ctc_best_path_decoder( probs_seq=probs, vocabulary=data_generator.vocab_list) - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - wer_sum += wer(target_transcription, output_transcription) + wer_sum += wer(target_transcription[i], output_transcription) wer_counter += 1 - # beam search decode in single process + # beam search decode elif args.decode_method == "beam_search": - for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - beam_search_result = ctc_beam_search_decoder( - probs_seq=probs, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - ext_scoring_func=ext_scorer, - cutoff_prob=args.cutoff_prob, ) - wer_sum += wer(target_transcription, beam_search_result[0][1]) - wer_counter += 1 - # beam search using multiple processes - elif args.decode_method == "beam_search_nproc": - beam_search_nproc_results = ctc_beam_search_decoder_nproc( + # beam search using multiple processes + beam_search_results = ctc_beam_search_decoder_batch( probs_split=probs_split, vocabulary=data_generator.vocab_list, beam_size=args.beam_size, blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, ext_scoring_func=ext_scorer, cutoff_prob=args.cutoff_prob, ) - for i, beam_search_result in enumerate(beam_search_nproc_results): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - wer_sum += wer(target_transcription, beam_search_result[0][1]) + for i, beam_search_result in enumerate(beam_search_results): + wer_sum += wer(target_transcription[i], + beam_search_result[0][1]) wer_counter += 1 else: raise ValueError("Decoding method [%s] is not supported." % decode_method) - print("Cur WER = %f" % (wer_sum / wer_counter)) print("Final WER = %f" % (wer_sum / wer_counter)) diff --git a/infer.py b/infer.py index 069b9e3e..5f0f268a 100644 --- a/infer.py +++ b/infer.py @@ -11,14 +11,14 @@ import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * -from scorer import Scorer +from lm.lm_scorer import LmScorer from error_rate import wer import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_samples", - default=100, + default=10, type=int, help="Number of samples for inference. (default: %(default)s)") parser.add_argument( @@ -46,6 +46,11 @@ parser.add_argument( default=multiprocessing.cpu_count(), type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -53,12 +58,12 @@ parser.add_argument( help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", - default='data/manifest.libri.test-100sample', + default='datasets/manifest.test', type=str, help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='checkpoints/params.latest.tar.gz', + default='checkpoints/params.tar.gz.41', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( @@ -68,12 +73,10 @@ parser.add_argument( help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( "--decode_method", - default='beam_search_nproc', + default='beam_search', type=str, - help="Method for ctc decoding:" - " best_path," - " beam_search, " - " or beam_search_nproc. (default: %(default)s)") + help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)" +) parser.add_argument( "--beam_size", default=500, @@ -86,7 +89,7 @@ parser.add_argument( help="Number of output per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="data/en.00.UNKNOWN.klm", + default="lm/data/en.00.UNKNOWN.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( @@ -143,6 +146,7 @@ def infer(): batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.num_samples, + min_batch_size=1, sortagrad=False, shuffle_method=None) infer_data = batch_reader().next() @@ -156,68 +160,45 @@ def infer(): for i in xrange(len(infer_data)) ] + # targe transcription + target_transcription = [ + ''.join( + [data_generator.vocab_list[index] for index in infer_data[i][1]]) + for i, probs in enumerate(probs_split) + ] + ## decode and print # best path decode wer_sum, wer_counter = 0, 0 if args.decode_method == "best_path": for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) - best_path_transcription = ctc_best_path_decode( + best_path_transcription = ctc_best_path_decoder( probs_seq=probs, vocabulary=data_generator.vocab_list) print("\nTarget Transcription: %s\nOutput Transcription: %s" % - (target_transcription, best_path_transcription)) - wer_cur = wer(target_transcription, best_path_transcription) + (target_transcription[i], best_path_transcription)) + wer_cur = wer(target_transcription[i], best_path_transcription) wer_sum += wer_cur wer_counter += 1 print("cur wer = %f, average wer = %f" % (wer_cur, wer_sum / wer_counter)) # beam search decode elif args.decode_method == "beam_search": - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) - for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) - beam_search_result = ctc_beam_search_decoder( - probs_seq=probs, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - cutoff_prob=args.cutoff_prob, - ext_scoring_func=ext_scorer, ) - print("\nTarget Transcription:\t%s" % target_transcription) - - for index in xrange(args.num_results_per_sample): - result = beam_search_result[index] - #output: index, log prob, beam result - print("Beam %d: %f \t%s" % (index, result[0], result[1])) - wer_cur = wer(target_transcription, beam_search_result[0][1]) - wer_sum += wer_cur - wer_counter += 1 - print("cur wer = %f , average wer = %f" % - (wer_cur, wer_sum / wer_counter)) - elif args.decode_method == "beam_search_nproc": - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) - beam_search_nproc_results = ctc_beam_search_decoder_nproc( + ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) + beam_search_batch_results = ctc_beam_search_decoder_batch( probs_split=probs_split, vocabulary=data_generator.vocab_list, beam_size=args.beam_size, blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, cutoff_prob=args.cutoff_prob, ext_scoring_func=ext_scorer, ) - for i, beam_search_result in enumerate(beam_search_nproc_results): - target_transcription = ''.join([ - data_generator.vocab_list[index] for index in infer_data[i][1] - ]) - print("\nTarget Transcription:\t%s" % target_transcription) - + for i, beam_search_result in enumerate(beam_search_batch_results): + print("\nTarget Transcription:\t%s" % target_transcription[i]) for index in xrange(args.num_results_per_sample): result = beam_search_result[index] #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) - wer_cur = wer(target_transcription, beam_search_result[0][1]) + wer_cur = wer(target_transcription[i], beam_search_result[0][1]) wer_sum += wer_cur wer_counter += 1 print("cur wer = %f , average wer = %f" % diff --git a/lm/__init__.py b/lm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scorer.py b/lm/lm_scorer.py similarity index 73% rename from scorer.py rename to lm/lm_scorer.py index 4f468481..1c029e97 100644 --- a/scorer.py +++ b/lm/lm_scorer.py @@ -8,13 +8,16 @@ import kenlm import numpy as np -class Scorer(object): - """External defined scorer to evaluate a sentence in beam search - decoding, consisting of language model and word count. +class LmScorer(object): + """External scorer to evaluate a prefix or whole sentence in + beam search decoding, including the score from n-gram language + model and word count. - :param alpha: Parameter associated with language model. + :param alpha: Parameter associated with language model. Don't use + language model when alpha = 0. :type alpha: float - :param beta: Parameter associated with word count. + :param beta: Parameter associated with word count. Don't use word + count when beta = 0. :type beta: float :model_path: Path to load language model. :type model_path: basestring @@ -28,14 +31,14 @@ class Scorer(object): self._language_model = kenlm.LanguageModel(model_path) # n-gram language model scoring - def language_model_score(self, sentence): + def _language_model_score(self, sentence): #log10 prob of last word log_cond_prob = list( self._language_model.full_scores(sentence, eos=False))[-1][0] return np.power(10, log_cond_prob) # word insertion term - def word_count(self, sentence): + def _word_count(self, sentence): words = sentence.strip().split(' ') return len(words) @@ -51,8 +54,8 @@ class Scorer(object): :return: Evaluation score, in the decimal or log. :rtype: float """ - lm = self.language_model_score(sentence) - word_cnt = self.word_count(sentence) + lm = self._language_model_score(sentence) + word_cnt = self._word_count(sentence) if log == False: score = np.power(lm, self._alpha) \ * np.power(word_cnt, self._beta) diff --git a/lm/run.sh b/lm/run.sh new file mode 100644 index 00000000..bf523740 --- /dev/null +++ b/lm/run.sh @@ -0,0 +1,3 @@ +echo "Downloading language model." + +wget -c ftp://xxx/xxx/en.00.UNKNOWN.klm -P ./data diff --git a/requirements.txt b/requirements.txt index 0183ecf0..ce024591 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ SoundFile==0.9.0.post1 wget==3.2 scipy==0.13.1 +https://github.com/kpu/kenlm/archive/master.zip diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 7fa89c5f..4435355c 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -53,11 +53,11 @@ class TestDecoders(unittest.TestCase): self.beam_search_result = ['acdc', "b'a"] def test_best_path_decoder_1(self): - bst_result = ctc_best_path_decode(self.probs_seq1, self.vocab_list) + bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list) self.assertEqual(bst_result, self.best_path_result[0]) def test_best_path_decoder_2(self): - bst_result = ctc_best_path_decode(self.probs_seq2, self.vocab_list) + bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list) self.assertEqual(bst_result, self.best_path_result[1]) def test_beam_search_decoder_1(self): @@ -77,7 +77,7 @@ class TestDecoders(unittest.TestCase): self.assertEqual(beam_result[0][1], self.beam_search_result[1]) def test_beam_search_nproc_decoder(self): - beam_results = ctc_beam_search_decoder_nproc( + beam_results = ctc_beam_search_decoder_batch( probs_split=[self.probs_seq1, self.probs_seq2], beam_size=self.beam_size, vocabulary=self.vocab_list, diff --git a/tune.py b/tune.py index 02076349..9cea66b9 100644 --- a/tune.py +++ b/tune.py @@ -3,14 +3,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import paddle.v2 as paddle import distutils.util import argparse import gzip +import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import * -from scorer import Scorer +from lm.lm_scorer import LmScorer from error_rate import wer parser = argparse.ArgumentParser(description=__doc__) @@ -39,24 +39,29 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', type=str, help="Manifest path for normalizer. (default: %(default)s)") -parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', - type=str, - help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", - default='data/manifest.libri.test-100sample', + default='datasets/manifest.test', type=str, help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( "--model_filepath", - default='./params.tar.gz', + default='checkpoints/params.latest.tar.gz', type=str, help="Model filepath. (default: %(default)s)") parser.add_argument( @@ -64,25 +69,14 @@ parser.add_argument( default='datasets/vocab/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") -parser.add_argument( - "--decode_method", - default='beam_search_nproc', - type=str, - help="Method for decoding, beam_search or beam_search_nproc. (default: %(default)s)" -) parser.add_argument( "--beam_size", default=500, type=int, help="Width for beam search decoding. (default: %(default)d)") -parser.add_argument( - "--num_results_per_sample", - default=1, - type=int, - help="Number of outputs per sample in beam search. (default: %(default)d)") parser.add_argument( "--language_model_path", - default="data/en.00.UNKNOWN.klm", + default="lm/data/en.00.UNKNOWN.klm", type=str, help="Path for language model. (default: %(default)s)") parser.add_argument( @@ -137,7 +131,8 @@ def tune(): data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config='{}') + augmentation_config='{}', + num_threads=args.num_threads_data) # create network config # paddle.data_type.dense_array is used for variable batch input. @@ -188,42 +183,22 @@ def tune(): ## tune parameters in loop for (alpha, beta) in params_grid: wer_sum, wer_counter = 0, 0 - ext_scorer = Scorer(alpha, beta, args.language_model_path) - # beam search decode - if args.decode_method == "beam_search": - for i, probs in enumerate(probs_split): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - beam_search_result = ctc_beam_search_decoder( - probs_seq=probs, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - blank_id=len(data_generator.vocab_list), - cutoff_prob=args.cutoff_prob, - ext_scoring_func=ext_scorer, ) - wer_sum += wer(target_transcription, beam_search_result[0][1]) - wer_counter += 1 + ext_scorer = LmScorer(alpha, beta, args.language_model_path) # beam search using multiple processes - elif args.decode_method == "beam_search_nproc": - beam_search_nproc_results = ctc_beam_search_decoder_nproc( - probs_split=probs_split, - vocabulary=data_generator.vocab_list, - beam_size=args.beam_size, - cutoff_prob=args.cutoff_prob, - blank_id=len(data_generator.vocab_list), - ext_scoring_func=ext_scorer, ) - for i, beam_search_result in enumerate(beam_search_nproc_results): - target_transcription = ''.join([ - data_generator.vocab_list[index] - for index in infer_data[i][1] - ]) - wer_sum += wer(target_transcription, beam_search_result[0][1]) - wer_counter += 1 - else: - raise ValueError("Decoding method [%s] is not supported." % - decode_method) + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=data_generator.vocab_list, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, + ext_scoring_func=ext_scorer, ) + for i, beam_search_result in enumerate(beam_search_results): + target_transcription = ''.join([ + data_generator.vocab_list[index] for index in infer_data[i][1] + ]) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 print("alpha = %f\tbeta = %f\tWER = %f" % (alpha, beta, wer_sum / wer_counter)) -- GitLab