diff --git a/deep_speech_2/decoder.py b/deep_speech_2/decoder.py
index 824ac970162dde6e85b82b3a5a3910e82f31d075..2ee89cbd028553421d6b5aa02017fb7d20f03e16 100755
--- a/deep_speech_2/decoder.py
+++ b/deep_speech_2/decoder.py
@@ -5,7 +5,6 @@ import os
 from itertools import groupby
 import numpy as np
-import copy
 import kenlm
 import multiprocessing
@@ -73,11 +72,25 @@ class Scorer(object):
         return len(words)
 
     # execute evaluation
-    def __call__(self, sentence):
+    def __call__(self, sentence, log=False):
+        """
+        Evaluation function.
+
+        :param sentence: The input sentence for evaluation.
+        :type sentence: basestring
+        :param log: Whether to return the score in log representation.
+        :type log: bool
+        :return: Evaluation score, in decimal or log representation.
+        :rtype: float
+        """
         lm = self.language_model_score(sentence)
         word_cnt = self.word_count(sentence)
-        score = np.power(lm, self._alpha) \
-                * np.power(word_cnt, self._beta)
+        if not log:
+            score = np.power(lm, self._alpha) \
+                    * np.power(word_cnt, self._beta)
+        else:
+            score = self._alpha * np.log(lm) \
+                    + self._beta * np.log(word_cnt)
         return score
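Note that the new log branch returns exactly the log of the decimal branch's score, so callers can combine it additively with log-domain CTC probabilities. A minimal sketch verifying the equivalence, with made-up stand-ins for the KenLM score and word count (the real `language_model_score` queries a KenLM model):

```python
import numpy as np

# hypothetical stand-ins for language_model_score() and word_count()
lm_score, word_cnt = 1e-12, 5
alpha, beta = 0.26, 0.1

decimal = np.power(lm_score, alpha) * np.power(word_cnt, beta)
logged = alpha * np.log(lm_score) + beta * np.log(word_cnt)

# the log-mode score is the log of the decimal-mode score
assert np.isclose(np.log(decimal), logged)
```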
@@ -85,13 +98,14 @@
 def ctc_beam_search_decoder(probs_seq,
                             beam_size,
                             vocabulary,
                             blank_id=0,
+                            cutoff_prob=1.0,
                             ext_scoring_func=None,
                             nproc=False):
     '''
     Beam search decoder for CTC-trained network, using beam search with width
     beam_size to find many paths to one label, return beam_size labels in
-    the order of probabilities. The implementation is based on Prefix Beam
-    Search(https://arxiv.org/abs/1408.2873), and the unclear part is
+    the descending order of probabilities. The implementation is based on
+    Prefix Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
     redesigned, need to be verified.
 
     :param probs_seq: 2-D list with length num_time_steps, each element
@@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq,
     :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
+    :param blank_id: ID of blank, default 0.
+    :type blank_id: int
+    :param cutoff_prob: Cutoff probability in pruning,
+                        default 1.0, no pruning.
+    :type cutoff_prob: float
     :param ext_scoring_func: External defined scoring function for
                             partially decoded sentence, e.g. word count
                             and language model.
     :type external_scoring_function: function
-    :param blank_id: id of blank, default 0.
-    :type blank_id: int
     :param nproc: Whether the decoder used in multiprocesses.
     :type nproc: bool
-    :return: Decoding log probability and result string.
+    :return: Decoding log probabilities and result sentences in descending order.
     :rtype: list
     '''
     # dimension check
     for prob_list in probs_seq:
         if not len(prob_list) == len(vocabulary) + 1:
-            raise ValueError("probs dimension mismatchedd with vocabulary")
+            raise ValueError("probs dimension mismatched with vocabulary")
     num_time_steps = len(probs_seq)
 
     # blank_id check
@@ -137,19 +154,35 @@
     probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
 
     ## extend prefix in loop
-    for time_step in range(num_time_steps):
+    for time_step in xrange(num_time_steps):
         # the set containing candidate prefixes
         prefix_set_next = {}
         probs_b_cur, probs_nb_cur = {}, {}
+        prob = probs_seq[time_step]
+        prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
+        cutoff_len = len(prob_idx)
+        # if pruning is enabled, keep the chars covering cutoff_prob of the mass
+        if cutoff_prob < 1.0:
+            prob_idx = sorted(prob_idx, key=lambda x: x[1], reverse=True)
+            cutoff_len = 0
+            cum_prob = 0.0
+            for i in xrange(len(prob_idx)):
+                cum_prob += prob_idx[i][1]
+                cutoff_len += 1
+                if cum_prob >= cutoff_prob:
+                    break
+            prob_idx = prob_idx[0:cutoff_len]
+
         for l in prefix_set_prev:
-            prob = probs_seq[time_step]
             if not prefix_set_next.has_key(l):
                 probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
 
-            # extend prefix by travering vocabulary
-            for c in range(0, probs_dim):
+            # extend prefix by traversing the pruned prob_idx
+            for index in xrange(cutoff_len):
+                c, prob_c = prob_idx[index][0], prob_idx[index][1]
+
                 if c == blank_id:
-                    probs_b_cur[l] += prob[c] * (
+                    probs_b_cur[l] += prob_c * (
                         probs_b_prev[l] + probs_nb_prev[l])
                 else:
                     last_char = l[-1]
@@ -159,18 +192,18 @@
                     probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
 
                     if new_char == last_char:
-                        probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
-                        probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
+                        probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
+                        probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                     elif new_char == ' ':
                         if (ext_scoring_func is None) or (len(l) == 1):
                             score = 1.0
                         else:
                             prefix = l[1:]
                             score = ext_scoring_func(prefix)
-                        probs_nb_cur[l_plus] += score * prob[c] * (
+                        probs_nb_cur[l_plus] += score * prob_c * (
                             probs_b_prev[l] + probs_nb_prev[l])
                     else:
-                        probs_nb_cur[l_plus] += prob[c] * (
+                        probs_nb_cur[l_plus] += prob_c * (
                             probs_b_prev[l] + probs_nb_prev[l])
                     # add l_plus into prefix_set_next
                     prefix_set_next[l_plus] = probs_nb_cur[
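The new `cutoff_prob` logic prunes each time step's vocabulary down to the fewest most-probable characters whose probabilities sum to at least `cutoff_prob`. The same step in isolation, as a self-contained sketch (not the decoder itself):

```python
def prune_step(prob, cutoff_prob=1.0):
    """Return (char_index, prob) pairs covering cutoff_prob of the mass."""
    prob_idx = list(enumerate(prob))
    if cutoff_prob >= 1.0:  # pruning disabled
        return prob_idx
    prob_idx.sort(key=lambda x: x[1], reverse=True)
    cum_prob, cutoff_len = 0.0, 0
    for _, p in prob_idx:
        cum_prob += p
        cutoff_len += 1
        if cum_prob >= cutoff_prob:
            break
    return prob_idx[:cutoff_len]

# with cutoff_prob=0.9 only the two most probable characters survive
print(prune_step([0.6, 0.3, 0.06, 0.04], cutoff_prob=0.9))  # [(0, 0.6), (1, 0.3)]
```

With `cutoff_prob=0.99` (the default the scripts below use), most of the low-probability tail is dropped at each step, shrinking the inner loop from the full vocabulary to a handful of candidates.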
@@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
                                   beam_size,
                                   vocabulary,
                                   blank_id=0,
+                                  cutoff_prob=1.0,
                                   ext_scoring_func=None,
                                   num_processes=None):
     '''
@@ -216,16 +250,19 @@
     :type beam_size: int
     :param vocabulary: Vocabulary list.
     :type vocabulary: list
+    :param blank_id: ID of blank, default 0.
+    :type blank_id: int
+    :param cutoff_prob: Cutoff probability in pruning,
+                        default 1.0, no pruning.
+    :type cutoff_prob: float
     :param ext_scoring_func: External defined scoring function for
                             partially decoded sentence, e.g. word count
                             and language model.
     :type external_scoring_function: function
-    :param blank_id: id of blank, default 0.
-    :type blank_id: int
     :param num_processes: Number of processes, default None, equal to the
                           number of CPUs.
     :type num_processes: int
-    :return: Decoding log probability and result string.
+    :return: Decoding log probabilities and result sentences in descending order.
     :rtype: list
     '''
@@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split,
     pool = multiprocessing.Pool(processes=num_processes)
     results = []
     for i, probs_list in enumerate(probs_split):
-        args = (probs_list, beam_size, vocabulary, blank_id, None, nproc)
+        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
+                nproc)
         results.append(pool.apply_async(ctc_beam_search_decoder, args))
 
     pool.close()
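Because `apply_async` receives a plain positional tuple, its order must match `ctc_beam_search_decoder`'s parameter order exactly, which is why `cutoff_prob` is spliced in before the `ext_scoring_func` slot above. A toy sketch of the same pool pattern (with a dummy worker standing in for the decoder):

```python
import multiprocessing


def decode_one(probs_list, beam_size):
    # dummy stand-in for ctc_beam_search_decoder
    return len(probs_list) * beam_size


if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=2)
    results = [
        pool.apply_async(decode_one, ([0.1] * n, 500)) for n in (3, 5, 7)
    ]
    pool.close()
    pool.join()
    print([r.get() for r in results])  # [1500, 2500, 3500], in submission order
```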
+ """ + # initialize data generator + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + normalizer_manifest_path=args.normalizer_manifest_path, + normalizer_num_samples=200, + max_duration=20.0, + min_duration=0.0, + stride_ms=10, + window_ms=20) + + # create network config + dict_size = data_generator.vocabulary_size() + vocab_list = data_generator.vocabulary_list() + audio_data = paddle.layer.data( + name="audio_spectrogram", + height=161, + width=2000, + type=paddle.data_type.dense_vector(322000)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(dict_size)) + output_probs = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=dict_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size, + is_inference=True) + + # load parameters + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.model_filepath)) + + # prepare infer data + feeding = data_generator.data_name_feeding() + test_batch_reader = data_generator.batch_reader_creator( + manifest_path=args.decode_manifest_path, + batch_size=args.num_samples, + padding_to=2000, + flatten=True, + sort_by_duration=False, + shuffle=False) + + # define inferer + inferer = paddle.inference.Inference( + output_layer=output_probs, parameters=parameters) + + # initialize external scorer for beam search decoding + if args.decode_method == 'beam_search' or \ + args.decode_method == 'beam_search_nproc': + ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + + wer_counter, wer_sum = 0, 0.0 + for infer_data in test_batch_reader(): + # run inference + infer_results = inferer.infer(input=infer_data) + num_steps = len(infer_results) / len(infer_data) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(0, len(infer_data)) + ] + + # decode and print + # best path decode + if args.decode_method == "best_path": + for i, probs in enumerate(probs_split): + output_transcription = ctc_decode( + probs_seq=probs, vocabulary=vocab_list, method="best_path") + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + wer_sum += wer(target_transcription, output_transcription) + wer_counter += 1 + # beam search decode in single process + elif args.decode_method == "beam_search": + for i, probs in enumerate(probs_split): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + beam_search_result = ctc_beam_search_decoder( + probs_seq=probs, + vocabulary=vocab_list, + beam_size=args.beam_size, + blank_id=len(vocab_list), + ext_scoring_func=ext_scorer, + cutoff_prob=args.cutoff_prob, ) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 + # beam search using multiple processes + elif args.decode_method == "beam_search_nproc": + beam_search_nproc_results = ctc_beam_search_decoder_nproc( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=args.beam_size, + blank_id=len(vocab_list), + ext_scoring_func=ext_scorer, + cutoff_prob=args.cutoff_prob, ) + for i, beam_search_result in enumerate(beam_search_nproc_results): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 + else: + raise ValueError("Decoding method [%s] is not supported." 
diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py
index bb9dfa0a66ae55a499edef10b7e9e60ddd92f8d2..64fe1524e5bc1d5bb6646962319460c433a3a37f 100644
--- a/deep_speech_2/infer.py
+++ b/deep_speech_2/infer.py
@@ -9,14 +9,14 @@ import gzip
 from audio_data_utils import DataGenerator
 from model import deep_speech2
 from decoder import *
-import kenlm
 from error_rate import wer
+import time
 
 parser = argparse.ArgumentParser(
     description='Simplified version of DeepSpeech2 inference.')
 parser.add_argument(
     "--num_samples",
-    default=10,
+    default=100,
     type=int,
     help="Number of samples for inference. (default: %(default)s)")
 parser.add_argument(
@@ -46,7 +46,7 @@
     help="Manifest path for normalizer. (default: %(default)s)")
 parser.add_argument(
     "--decode_manifest_path",
-    default='data/manifest.libri.test-clean',
+    default='data/manifest.libri.test-100sample',
     type=str,
     help="Manifest path for decoding. (default: %(default)s)")
 parser.add_argument(
@@ -63,11 +63,13 @@
     "--decode_method",
     default='beam_search_nproc',
     type=str,
-    help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)"
-)
+    help="Method for ctc decoding:"
+    " best_path,"
+    " beam_search,"
+    " or beam_search_nproc. (default: %(default)s)")
 parser.add_argument(
     "--beam_size",
-    default=50,
+    default=500,
     type=int,
     help="Width for beam search decoding. (default: %(default)d)")
 parser.add_argument(
@@ -82,14 +84,20 @@
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
     "--alpha",
-    default=0.0,
+    default=0.26,
     type=float,
     help="Parameter associated with language model. (default: %(default)f)")
 parser.add_argument(
     "--beta",
-    default=0.0,
+    default=0.1,
     type=float,
     help="Parameter associated with word count. (default: %(default)f)")
+parser.add_argument(
+    "--cutoff_prob",
+    default=0.99,
+    type=float,
+    help="The cutoff probability of pruning "
+    "in beam search. (default: %(default)f)")
 
 args = parser.parse_args()
@@ -154,6 +162,7 @@
     ## decode and print
     # best path decode
     wer_sum, wer_counter = 0, 0
+    total_time = 0.0
    if args.decode_method == "best_path":
         for i, probs in enumerate(probs_split):
             target_transcription = ''.join(
@@ -177,11 +186,12 @@
                 probs_seq=probs,
                 vocabulary=vocab_list,
                 beam_size=args.beam_size,
-                ext_scoring_func=ext_scorer,
-                blank_id=len(vocab_list))
+                blank_id=len(vocab_list),
+                cutoff_prob=args.cutoff_prob,
+                ext_scoring_func=ext_scorer, )
             print("\nTarget Transcription:\t%s" % target_transcription)
-            for index in range(args.num_results_per_sample):
+            for index in xrange(args.num_results_per_sample):
                 result = beam_search_result[index]
                 #output: index, log prob, beam result
                 print("Beam %d: %f \t%s" % (index, result[0], result[1]))
@@ -190,21 +200,21 @@
             wer_counter += 1
             print("cur wer = %f , average wer = %f" %
                   (wer_cur, wer_sum / wer_counter))
-
     # beam search using multiple processes
     elif args.decode_method == "beam_search_nproc":
         ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
         beam_search_nproc_results = ctc_beam_search_decoder_nproc(
             probs_split=probs_split,
             vocabulary=vocab_list,
             beam_size=args.beam_size,
-            ext_scoring_func=ext_scorer,
-            blank_id=len(vocab_list))
+            blank_id=len(vocab_list),
+            cutoff_prob=args.cutoff_prob,
+            ext_scoring_func=ext_scorer, )
         for i, beam_search_result in enumerate(beam_search_nproc_results):
             target_transcription = ''.join(
                 [vocab_list[index] for index in infer_data[i][1]])
             print("\nTarget Transcription:\t%s" % target_transcription)
-            for index in range(args.num_results_per_sample):
+            for index in xrange(args.num_results_per_sample):
                 result = beam_search_result[index]
                 #output: index, log prob, beam result
                 print("Beam %d: %f \t%s" % (index, result[0], result[1]))
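infer.py now imports `time` and initializes `total_time`, presumably groundwork for timing the decoders; the hunks shown don't include the measurement itself. If it follows the usual pattern, it would look something like this (a hypothetical sketch, not code from this diff):

```python
import time

total_time = 0.0
for _ in range(3):  # stand-in for the per-utterance decode loop
    start = time.time()
    # ... a ctc_beam_search_decoder(...) call would go here ...
    total_time += time.time() - start
print("total decoding time: %f s" % total_time)
```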
(default: %(default)f)") args = parser.parse_args() @@ -154,6 +162,7 @@ def infer(): ## decode and print # best path decode wer_sum, wer_counter = 0, 0 + total_time = 0.0 if args.decode_method == "best_path": for i, probs in enumerate(probs_split): target_transcription = ''.join( @@ -177,11 +186,12 @@ def infer(): probs_seq=probs, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, - blank_id=len(vocab_list)) + blank_id=len(vocab_list), + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) print("\nTarget Transcription:\t%s" % target_transcription) - for index in range(args.num_results_per_sample): + for index in xrange(args.num_results_per_sample): result = beam_search_result[index] #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) @@ -190,21 +200,21 @@ def infer(): wer_counter += 1 print("cur wer = %f , average wer = %f" % (wer_cur, wer_sum / wer_counter)) - # beam search using multiple processes elif args.decode_method == "beam_search_nproc": ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) beam_search_nproc_results = ctc_beam_search_decoder_nproc( probs_split=probs_split, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, - blank_id=len(vocab_list)) + blank_id=len(vocab_list), + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) for i, beam_search_result in enumerate(beam_search_nproc_results): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) print("\nTarget Transcription:\t%s" % target_transcription) - for index in range(args.num_results_per_sample): + for index in xrange(args.num_results_per_sample): result = beam_search_result[index] #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) diff --git a/deep_speech_2/tune.py b/deep_speech_2/tune.py index 3eb82648993a1eefeb694399a426fbc4200b98c4..58a8a0d1b48795ea7ae0db193d7fe3ac9016de68 100644 --- a/deep_speech_2/tune.py +++ b/deep_speech_2/tune.py @@ -1,5 +1,5 @@ """ - Tune parameters for beam search decoder in Deep Speech 2. + Parameters tuning for beam search decoder in Deep Speech 2. """ import paddle.v2 as paddle @@ -12,7 +12,7 @@ from decoder import * from error_rate import wer parser = argparse.ArgumentParser( - description='Parameters tuning script for ctc beam search decoder in Deep Speech 2.' + description='Parameters tuning for ctc beam search decoder in Deep Speech 2.' ) parser.add_argument( "--num_samples", @@ -82,34 +82,40 @@ parser.add_argument( help="Path for language model. (default: %(default)s)") parser.add_argument( "--alpha_from", - default=0.0, + default=0.1, type=float, - help="Where alpha starts from, <= alpha_to. (default: %(default)f)") + help="Where alpha starts from. (default: %(default)f)") parser.add_argument( - "--alpha_stride", - default=0.001, - type=float, - help="Step length for varying alpha. (default: %(default)f)") + "--num_alphas", + default=14, + type=int, + help="Number of candidate alphas. (default: %(default)d)") parser.add_argument( "--alpha_to", - default=0.01, + default=0.36, type=float, - help="Where alpha ends with, >= alpha_from. (default: %(default)f)") + help="Where alpha ends with. (default: %(default)f)") parser.add_argument( "--beta_from", - default=0.0, + default=0.05, type=float, - help="Where beta starts from, <= beta_to. (default: %(default)f)") + help="Where beta starts from. 
(default: %(default)f)") parser.add_argument( - "--beta_stride", - default=0.01, + "--num_betas", + default=20, type=float, - help="Step length for varying beta. (default: %(default)f)") + help="Number of candidate betas. (default: %(default)d)") parser.add_argument( "--beta_to", - default=0.0, + default=1.0, type=float, - help="Where beta ends with, >= beta_from. (default: %(default)f)") + help="Where beta ends with. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") args = parser.parse_args() @@ -118,15 +124,11 @@ def tune(): Tune parameters alpha and beta on one minibatch. """ - if not args.alpha_from <= args.alpha_to: - raise ValueError("alpha_from <= alpha_to doesn't satisfy!") - if not args.alpha_stride > 0: - raise ValueError("alpha_stride shouldn't be negative!") + if not args.num_alphas >= 0: + raise ValueError("num_alphas must be non-negative!") - if not args.beta_from <= args.beta_to: - raise ValueError("beta_from <= beta_to doesn't satisfy!") - if not args.beta_stride > 0: - raise ValueError("beta_stride shouldn't be negative!") + if not args.num_betas >= 0: + raise ValueError("num_betas must be non-negative!") # initialize data generator data_generator = DataGenerator( @@ -171,6 +173,7 @@ def tune(): flatten=True, sort_by_duration=False, shuffle=False) + # get one batch data for tuning infer_data = test_batch_reader().next() # run inference @@ -182,11 +185,12 @@ def tune(): for i in xrange(0, len(infer_data)) ] - cand_alpha = np.arange(args.alpha_from, args.alpha_to + args.alpha_stride, - args.alpha_stride) - cand_beta = np.arange(args.beta_from, args.beta_to + args.beta_stride, - args.beta_stride) - params_grid = [(alpha, beta) for alpha in cand_alpha for beta in cand_beta] + # create grid for search + cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) + cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas) + params_grid = [(alpha, beta) for alpha in cand_alphas + for beta in cand_betas] + ## tune parameters in loop for (alpha, beta) in params_grid: wer_sum, wer_counter = 0, 0 @@ -200,8 +204,9 @@ def tune(): probs_seq=probs, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, - blank_id=len(vocab_list)) + blank_id=len(vocab_list), + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) wer_sum += wer(target_transcription, beam_search_result[0][1]) wer_counter += 1 # beam search using multiple processes @@ -210,9 +215,9 @@ def tune(): probs_split=probs_split, vocabulary=vocab_list, beam_size=args.beam_size, - ext_scoring_func=ext_scorer, + cutoff_prob=args.cutoff_prob, blank_id=len(vocab_list), - num_processes=1) + ext_scoring_func=ext_scorer, ) for i, beam_search_result in enumerate(beam_search_nproc_results): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]])