diff --git a/README.md b/README.md index 2912ff3143516ee21f21732f25992fadcd33c270..3010c0e536da732f1c4f042c82badaae21179f87 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,13 @@ python datasets/librispeech/librispeech.py --help python compute_mean_std.py ``` -`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. +It will compute mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, and the mfcc feature is also supported. To train and infer based on mfcc feature, please generate this file by + +``` +python compute_mean_std.py --specgram_type mfcc +``` + +and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluator.py or tune.py. More help for arguments: @@ -66,14 +72,69 @@ More help for arguments: python train.py --help ``` -### Inferencing +### Preparing language model + +The following steps, inference, parameters tuning and evaluating, will require a language model during decoding. +A compressed language model is provided and can be accessed by + +``` +cd ./lm +sh run.sh +cd .. +``` + +### Inference + +For GPU inference ``` CUDA_VISIBLE_DEVICES=0 python infer.py ``` +For CPU inference + +``` +python infer.py --use_gpu=False +``` + More help for arguments: ``` python infer.py --help ``` + +### Evaluating + +``` +CUDA_VISIBLE_DEVICES=0 python evaluate.py +``` + +More help for arguments: + +``` +python evaluate.py --help +``` + +### Parameters tuning + +Usually, the parameters $\alpha$ and $\beta$ for the CTC [prefix beam search](https://arxiv.org/abs/1408.2873) decoder need to be tuned after retraining the acoustic model. + +For GPU tuning + +``` +CUDA_VISIBLE_DEVICES=0 python tune.py +``` + +For CPU tuning + +``` +python tune.py --use_gpu=False +``` + +More help for arguments: + +``` +python tune.py --help +``` + +Then reset parameters with the tuning result before inference or evaluating. diff --git a/compute_mean_std.py b/compute_mean_std.py index 9c301c93f6d2ce3ae099caa96830912f76ce6c58..0cc84e73022ecb1333b805457cace39adcc68ce4 100644 --- a/compute_mean_std.py +++ b/compute_mean_std.py @@ -10,6 +10,12 @@ from data_utils.featurizer.audio_featurizer import AudioFeaturizer parser = argparse.ArgumentParser( description='Computing mean and stddev for feature normalizer.') +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--manifest_path", default='datasets/manifest.train', @@ -39,7 +45,7 @@ args = parser.parse_args() def main(): augmentation_pipeline = AugmentationPipeline(args.augmentation_config) - audio_featurizer = AudioFeaturizer() + audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type) def augment_and_featurize(audio_segment): augmentation_pipeline.transform_audio(audio_segment) diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py index 4b4d02c60f4193d753badae1aaa3b17ab3b7ea43..271e535b6a9f1cded27caf4f63adcc51abf3e835 100644 --- a/data_utils/featurizer/audio_featurizer.py +++ b/data_utils/featurizer/audio_featurizer.py @@ -6,13 +6,15 @@ from __future__ import print_function import numpy as np from data_utils import utils from data_utils.audio import AudioSegment +from python_speech_features import mfcc +from python_speech_features import delta class AudioFeaturizer(object): """Audio featurizer, for extracting features from audio contents of AudioSegment or SpeechSegment. - Currently, it only supports feature type of linear spectrogram. + Currently, it supports feature types of linear spectrogram and mfcc. :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str @@ -20,9 +22,10 @@ class AudioFeaturizer(object): :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_freq: Used when specgram_type is 'linear', only FFT bins + :param max_freq: When specgram_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned. + returned; when specgram_type is 'mfcc', max_feq is the + highest band edge of mel filters. :types max_freq: None|float :param target_sample_rate: Audio are resampled (if upsampling or downsampling is allowed) to this before @@ -91,6 +94,9 @@ class AudioFeaturizer(object): return self._compute_linear_specgram( samples, sample_rate, self._stride_ms, self._window_ms, self._max_freq) + elif self._specgram_type == 'mfcc': + return self._compute_mfcc(samples, sample_rate, self._stride_ms, + self._window_ms, self._max_freq) else: raise ValueError("Unknown specgram_type %s. " "Supported values: linear." % self._specgram_type) @@ -142,3 +148,39 @@ class AudioFeaturizer(object): # prepare fft frequency list freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) return fft, freqs + + def _compute_mfcc(self, + samples, + sample_rate, + stride_ms=10.0, + window_ms=20.0, + max_freq=None): + """Compute mfcc from samples.""" + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + # compute 13 cepstral coefficients, and the first one is replaced + # by log(frame energy) + mfcc_feat = mfcc( + signal=samples, + samplerate=sample_rate, + winlen=0.001 * window_ms, + winstep=0.001 * stride_ms, + highfreq=max_freq) + # Deltas + d_mfcc_feat = delta(mfcc_feat, 2) + # Deltas-Deltas + dd_mfcc_feat = delta(d_mfcc_feat, 2) + # concat above three features + concat_mfcc_feat = [ + np.concatenate((mfcc_feat[i], d_mfcc_feat[i], dd_mfcc_feat[i])) + for i in xrange(len(mfcc_feat)) + ] + # transpose to be consistent with the linear specgram situation + concat_mfcc_feat = np.transpose(concat_mfcc_feat) + return concat_mfcc_feat diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py index 26283892e85beb8b41351fb2d1b876c6284da887..a947588db4a29d7d49b9650c2da28731259cc0e0 100644 --- a/data_utils/featurizer/speech_featurizer.py +++ b/data_utils/featurizer/speech_featurizer.py @@ -11,23 +11,24 @@ class SpeechFeaturizer(object): """Speech featurizer, for extracting features from both audio and transcript contents of SpeechSegment. - Currently, for audio parts, it only supports feature type of linear - spectrogram; for transcript parts, it only supports char-level tokenizing - and conversion into a list of token indices. Note that the token indexing - order follows the given vocabulary file. + Currently, for audio parts, it supports feature types of linear + spectrogram and mfcc; for transcript parts, it only supports char-level + tokenizing and conversion into a list of token indices. Note that the + token indexing order follows the given vocabulary file. :param vocab_filepath: Filepath to load vocabulary for token indices conversion. :type specgram_type: basestring - :param specgram_type: Specgram feature type. Options: 'linear'. + :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'. :type specgram_type: str :param stride_ms: Striding size (in milliseconds) for generating frames. :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_freq: Used when specgram_type is 'linear', only FFT bins + :param max_freq: When specgram_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned. + returned; when specgram_type is 'mfcc', max_freq is the + highest band edge of mel filters. :types max_freq: None|float :param target_sample_rate: Speech are resampled (if upsampling or downsampling is allowed) to this before diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py index c123d25d20600140b47da1e93655b15c0053dfea..1f4aae9a0913f323480c46c2d449f9515a65bb7e 100644 --- a/data_utils/normalizer.py +++ b/data_utils/normalizer.py @@ -16,7 +16,7 @@ class FeatureNormalizer(object): if mean_std_filepath is provided (not None), the normalizer will directly initilize from the file. Otherwise, both manifest_path and featurize_func should be given for on-the-fly mean and stddev computing. - + :param mean_std_filepath: File containing the pre-computed mean and stddev. :type mean_std_filepath: None|basestring :param manifest_path: Manifest of instances for computing mean and stddev. diff --git a/decoder.py b/decoder.py index 77d950b8db072d539788fd1b2bc7ac0525ffa0f9..a1fadc2c81ac5036f5082e1a60b018106ab90277 100644 --- a/decoder.py +++ b/decoder.py @@ -1,14 +1,16 @@ -"""Contains various CTC decoder.""" +"""Contains various CTC decoders.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np from itertools import groupby +import numpy as np +from math import log +import multiprocessing -def ctc_best_path_decode(probs_seq, vocabulary): - """Best path decoding, also called argmax decoding or greedy decoding. +def ctc_best_path_decoder(probs_seq, vocabulary): + """Best path decoder, also called argmax decoder or greedy decoder. Path consisting of the most probable tokens are further post-processed to remove consecutive repetitions and all blanks. @@ -36,24 +38,200 @@ def ctc_best_path_decode(probs_seq, vocabulary): return ''.join([vocabulary[index] for index in index_list]) -def ctc_decode(probs_seq, vocabulary, method): - """CTC-like sequence decoding from a sequence of likelihood probablilites. +def ctc_beam_search_decoder(probs_seq, + beam_size, + vocabulary, + blank_id, + cutoff_prob=1.0, + ext_scoring_func=None, + nproc=False): + """Beam search decoder for CTC-trained network. It utilizes beam search + to approximately select top best decoding labels and returning results + in the descending order. The implementation is based on Prefix + Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is + redesigned. Two important modifications: 1) in the iterative computation + of probabilities, the assignment operation is changed to accumulation for + one prefix may comes from different paths; 2) the if condition "if l^+ not + in A_prev then" after probabilities' computation is deprecated for it is + hard to understand and seems unnecessary. - :param probs_seq: 2-D list of probabilities over the vocabulary for each - character. Each element is a list of float probabilities - for one character. - :type probs_seq: list + :param probs_seq: 2-D list of probability distributions over each time + step, with each element being a list of normalized + probabilities over vocabulary and blank. + :type probs_seq: 2-D list + :param beam_size: Width for beam search. + :type beam_size: int :param vocabulary: Vocabulary list. :type vocabulary: list - :param method: Decoding method name, with options: "best_path". - :type method: basestring - :return: Decoding result string. - :rtype: baseline + :param blank_id: ID of blank. + :type blank_id: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param ext_scoring_func: External scoring function for + partially decoded sentence, e.g. word count + or language model. + :type external_scoring_func: callable + :param nproc: Whether the decoder used in multiprocesses. + :type nproc: bool + :return: List of tuples of log probability and sentence as decoding + results, in descending order of the probability. + :rtype: list """ + # dimension check for prob_list in probs_seq: if not len(prob_list) == len(vocabulary) + 1: - raise ValueError("probs dimension mismatchedd with vocabulary") - if method == "best_path": - return ctc_best_path_decode(probs_seq, vocabulary) - else: - raise ValueError("Decoding method [%s] is not supported.") + raise ValueError("The shape of prob_seq does not match with the " + "shape of the vocabulary.") + + # blank_id check + if not blank_id < len(probs_seq[0]): + raise ValueError("blank_id shouldn't be greater than probs dimension") + + # If the decoder called in the multiprocesses, then use the global scorer + # instantiated in ctc_beam_search_decoder_batch(). + if nproc is True: + global ext_nproc_scorer + ext_scoring_func = ext_nproc_scorer + + ## initialize + # prefix_set_prev: the set containing selected prefixes + # probs_b_prev: prefixes' probability ending with blank in previous step + # probs_nb_prev: prefixes' probability ending with non-blank in previous step + prefix_set_prev = {'\t': 1.0} + probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} + + ## extend prefix in loop + for time_step in xrange(len(probs_seq)): + # prefix_set_next: the set containing candidate prefixes + # probs_b_cur: prefixes' probability ending with blank in current step + # probs_nb_cur: prefixes' probability ending with non-blank in current step + prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {} + + prob_idx = list(enumerate(probs_seq[time_step])) + cutoff_len = len(prob_idx) + #If pruning is enabled + if cutoff_prob < 1.0: + prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True) + cutoff_len, cum_prob = 0, 0.0 + for i in xrange(len(prob_idx)): + cum_prob += prob_idx[i][1] + cutoff_len += 1 + if cum_prob >= cutoff_prob: + break + prob_idx = prob_idx[0:cutoff_len] + + for l in prefix_set_prev: + if not prefix_set_next.has_key(l): + probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0 + + # extend prefix by travering prob_idx + for index in xrange(cutoff_len): + c, prob_c = prob_idx[index][0], prob_idx[index][1] + + if c == blank_id: + probs_b_cur[l] += prob_c * ( + probs_b_prev[l] + probs_nb_prev[l]) + else: + last_char = l[-1] + new_char = vocabulary[c] + l_plus = l + new_char + if not prefix_set_next.has_key(l_plus): + probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 + + if new_char == last_char: + probs_nb_cur[l_plus] += prob_c * probs_b_prev[l] + probs_nb_cur[l] += prob_c * probs_nb_prev[l] + elif new_char == ' ': + if (ext_scoring_func is None) or (len(l) == 1): + score = 1.0 + else: + prefix = l[1:] + score = ext_scoring_func(prefix) + probs_nb_cur[l_plus] += score * prob_c * ( + probs_b_prev[l] + probs_nb_prev[l]) + else: + probs_nb_cur[l_plus] += prob_c * ( + probs_b_prev[l] + probs_nb_prev[l]) + # add l_plus into prefix_set_next + prefix_set_next[l_plus] = probs_nb_cur[ + l_plus] + probs_b_cur[l_plus] + # add l into prefix_set_next + prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] + # update probs + probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur + + ## store top beam_size prefixes + prefix_set_prev = sorted( + prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True) + if beam_size < len(prefix_set_prev): + prefix_set_prev = prefix_set_prev[:beam_size] + prefix_set_prev = dict(prefix_set_prev) + + beam_result = [] + for seq, prob in prefix_set_prev.items(): + if prob > 0.0 and len(seq) > 1: + result = seq[1:] + # score last word by external scorer + if (ext_scoring_func is not None) and (result[-1] != ' '): + prob = prob * ext_scoring_func(result) + log_prob = log(prob) + beam_result.append((log_prob, result)) + + ## output top beam_size decoding results + beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) + return beam_result + + +def ctc_beam_search_decoder_batch(probs_split, + beam_size, + vocabulary, + blank_id, + num_processes, + cutoff_prob=1.0, + ext_scoring_func=None): + """CTC beam search decoder using multiple processes. + + :param probs_seq: 3-D list with each element as an instance of 2-D list + of probabilities used by ctc_beam_search_decoder(). + :type probs_seq: 3-D list + :param beam_size: Width for beam search. + :type beam_size: int + :param vocabulary: Vocabulary list. + :type vocabulary: list + :param blank_id: ID of blank. + :type blank_id: int + :param num_processes: Number of parallel processes. + :type num_processes: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :param num_processes: Number of parallel processes. + :type num_processes: int + :type cutoff_prob: float + :param ext_scoring_func: External scoring function for + partially decoded sentence, e.g. word count + or language model. + :type external_scoring_function: callable + :return: List of tuples of log probability and sentence as decoding + results, in descending order of the probability. + :rtype: list + """ + if not num_processes > 0: + raise ValueError("Number of processes must be positive!") + + # use global variable to pass the externnal scorer to beam search decoder + global ext_nproc_scorer + ext_nproc_scorer = ext_scoring_func + nproc = True + + pool = multiprocessing.Pool(processes=num_processes) + results = [] + for i, probs_list in enumerate(probs_split): + args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None, + nproc) + results.append(pool.apply_async(ctc_beam_search_decoder, args)) + + pool.close() + pool.join() + beam_search_results = [result.get() for result in results] + return beam_search_results diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..19eabf4e5aff090ed2f529e3ea3cd7f10ae57cb7 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,212 @@ +"""Evaluation for DeepSpeech2 model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import distutils.util +import argparse +import gzip +import paddle.v2 as paddle +from data_utils.data import DataGenerator +from model import deep_speech2 +from decoder import * +from lm.lm_scorer import LmScorer +from error_rate import wer + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--batch_size", + default=100, + type=int, + help="Minibatch size for evaluation. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") +parser.add_argument( + "--mean_std_filepath", + default='mean_std.npz', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--decode_method", + default='beam_search', + type=str, + help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)" +) +parser.add_argument( + "--language_model_path", + default="lm/data/common_crawl_00.prune01111.trie.klm", + type=str, + help="Path for language model. (default: %(default)s)") +parser.add_argument( + "--alpha", + default=0.26, + type=float, + help="Parameter associated with language model. (default: %(default)f)") +parser.add_argument( + "--beta", + default=0.1, + type=float, + help="Parameter associated with word count. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") +parser.add_argument( + "--beam_size", + default=500, + type=int, + help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") +parser.add_argument( + "--decode_manifest_path", + default='datasets/manifest.test', + type=str, + help="Manifest path for decoding. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='checkpoints/params.latest.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='datasets/vocab/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +args = parser.parse_args() + + +def evaluate(): + """Evaluate on whole test data for DeepSpeech2.""" + # initialize data generator + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config='{}', + specgram_type=args.specgram_type, + num_threads=args.num_threads_data) + + # create network config + # paddle.data_type.dense_array is used for variable batch input. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. + audio_data = paddle.layer.data( + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) + output_probs = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size, + is_inference=True) + + # load parameters + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.model_filepath)) + + # prepare infer data + batch_reader = data_generator.batch_reader_creator( + manifest_path=args.decode_manifest_path, + batch_size=args.batch_size, + min_batch_size=1, + sortagrad=False, + shuffle_method=None) + + # define inferer + inferer = paddle.inference.Inference( + output_layer=output_probs, parameters=parameters) + + # initialize external scorer for beam search decoding + if args.decode_method == 'beam_search': + ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) + + wer_counter, wer_sum = 0, 0.0 + for infer_data in batch_reader(): + # run inference + infer_results = inferer.infer(input=infer_data) + num_steps = len(infer_results) // len(infer_data) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(0, len(infer_data)) + ] + # target transcription + target_transcription = [ + ''.join([ + data_generator.vocab_list[index] for index in infer_data[i][1] + ]) for i, probs in enumerate(probs_split) + ] + # decode and print + # best path decode + if args.decode_method == "best_path": + for i, probs in enumerate(probs_split): + output_transcription = ctc_best_path_decoder( + probs_seq=probs, vocabulary=data_generator.vocab_list) + wer_sum += wer(target_transcription[i], output_transcription) + wer_counter += 1 + # beam search decode + elif args.decode_method == "beam_search": + # beam search using multiple processes + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=data_generator.vocab_list, + beam_size=args.beam_size, + blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, + ext_scoring_func=ext_scorer, + cutoff_prob=args.cutoff_prob, ) + for i, beam_search_result in enumerate(beam_search_results): + wer_sum += wer(target_transcription[i], + beam_search_result[0][1]) + wer_counter += 1 + else: + raise ValueError("Decoding method [%s] is not supported." % + decode_method) + + print("Final WER = %f" % (wer_sum / wer_counter)) + + +def main(): + paddle.init(use_gpu=args.use_gpu, trainer_count=1) + evaluate() + + +if __name__ == '__main__': + main() diff --git a/infer.py b/infer.py index 9037a108e2c5cbf8f5d8544b6fa07057067c9340..817526302764b3d6044688da97ad0cc072c14144 100644 --- a/infer.py +++ b/infer.py @@ -10,7 +10,9 @@ import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 -from decoder import ctc_decode +from decoder import * +from lm.lm_scorer import LmScorer +from error_rate import wer import utils parser = argparse.ArgumentParser(description=__doc__) @@ -44,6 +46,17 @@ parser.add_argument( default=multiprocessing.cpu_count(), type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -64,16 +77,54 @@ parser.add_argument( default='datasets/vocab/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") +parser.add_argument( + "--decode_method", + default='beam_search', + type=str, + help="Method for ctc decoding: best_path or beam_search. (default: %(default)s)" +) +parser.add_argument( + "--beam_size", + default=500, + type=int, + help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--num_results_per_sample", + default=1, + type=int, + help="Number of output per sample in beam search. (default: %(default)d)") +parser.add_argument( + "--language_model_path", + default="lm/data/common_crawl_00.prune01111.trie.klm", + type=str, + help="Path for language model. (default: %(default)s)") +parser.add_argument( + "--alpha", + default=0.26, + type=float, + help="Parameter associated with language model. (default: %(default)f)") +parser.add_argument( + "--beta", + default=0.1, + type=float, + help="Parameter associated with word count. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") args = parser.parse_args() def infer(): - """Max-ctc-decoding for DeepSpeech2.""" + """Inference for DeepSpeech2.""" # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', + specgram_type=args.specgram_type, num_threads=args.num_threads_data) # create network config @@ -102,6 +153,7 @@ def infer(): batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.num_samples, + min_batch_size=1, sortagrad=False, shuffle_method=None) infer_data = batch_reader().next() @@ -115,16 +167,52 @@ def infer(): for i in xrange(len(infer_data)) ] - # decode and print - for i, probs in enumerate(probs_split): - output_transcription = ctc_decode( - probs_seq=probs, - vocabulary=data_generator.vocab_list, - method="best_path") - target_transcription = ''.join( + # targe transcription + target_transcription = [ + ''.join( [data_generator.vocab_list[index] for index in infer_data[i][1]]) - print("Target Transcription: %s \nOutput Transcription: %s \n" % - (target_transcription, output_transcription)) + for i, probs in enumerate(probs_split) + ] + + ## decode and print + # best path decode + wer_sum, wer_counter = 0, 0 + if args.decode_method == "best_path": + for i, probs in enumerate(probs_split): + best_path_transcription = ctc_best_path_decoder( + probs_seq=probs, vocabulary=data_generator.vocab_list) + print("\nTarget Transcription: %s\nOutput Transcription: %s" % + (target_transcription[i], best_path_transcription)) + wer_cur = wer(target_transcription[i], best_path_transcription) + wer_sum += wer_cur + wer_counter += 1 + print("cur wer = %f, average wer = %f" % + (wer_cur, wer_sum / wer_counter)) + # beam search decode + elif args.decode_method == "beam_search": + ext_scorer = LmScorer(args.alpha, args.beta, args.language_model_path) + beam_search_batch_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=data_generator.vocab_list, + beam_size=args.beam_size, + blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, + cutoff_prob=args.cutoff_prob, + ext_scoring_func=ext_scorer, ) + for i, beam_search_result in enumerate(beam_search_batch_results): + print("\nTarget Transcription:\t%s" % target_transcription[i]) + for index in xrange(args.num_results_per_sample): + result = beam_search_result[index] + #output: index, log prob, beam result + print("Beam %d: %f \t%s" % (index, result[0], result[1])) + wer_cur = wer(target_transcription[i], beam_search_result[0][1]) + wer_sum += wer_cur + wer_counter += 1 + print("cur wer = %f , average wer = %f" % + (wer_cur, wer_sum / wer_counter)) + else: + raise ValueError("Decoding method [%s] is not supported." % + decode_method) def main(): diff --git a/lm/__init__.py b/lm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lm/lm_scorer.py b/lm/lm_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..463e96d6653b29207fb6105527a1f79c41c7fb84 --- /dev/null +++ b/lm/lm_scorer.py @@ -0,0 +1,68 @@ +"""External Scorer for Beam Search Decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import kenlm +import numpy as np + + +class LmScorer(object): + """External scorer to evaluate a prefix or whole sentence in + beam search decoding, including the score from n-gram language + model and word count. + + :param alpha: Parameter associated with language model. Don't use + language model when alpha = 0. + :type alpha: float + :param beta: Parameter associated with word count. Don't use word + count when beta = 0. + :type beta: float + :model_path: Path to load language model. + :type model_path: basestring + """ + + def __init__(self, alpha, beta, model_path): + self._alpha = alpha + self._beta = beta + if not os.path.isfile(model_path): + raise IOError("Invaid language model path: %s" % model_path) + self._language_model = kenlm.LanguageModel(model_path) + + # n-gram language model scoring + def _language_model_score(self, sentence): + #log10 prob of last word + log_cond_prob = list( + self._language_model.full_scores(sentence, eos=False))[-1][0] + return np.power(10, log_cond_prob) + + # word insertion term + def _word_count(self, sentence): + words = sentence.strip().split(' ') + return len(words) + + # reset alpha and beta + def reset_params(self, alpha, beta): + self._alpha = alpha + self._beta = beta + + # execute evaluation + def __call__(self, sentence, log=False): + """Evaluation function, gathering all the different scores + and return the final one. + + :param sentence: The input sentence for evalutation + :type sentence: basestring + :param log: Whether return the score in log representation. + :type log: bool + :return: Evaluation score, in the decimal or log. + :rtype: float + """ + lm = self._language_model_score(sentence) + word_cnt = self._word_count(sentence) + if log == False: + score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta) + else: + score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt) + return score diff --git a/lm/run.sh b/lm/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..2108ea55f1205f4c4c32b8994602544ca4e63edd --- /dev/null +++ b/lm/run.sh @@ -0,0 +1,19 @@ +echo "Downloading language model ..." + +mkdir data + +LM=common_crawl_00.prune01111.trie.klm +MD5="099a601759d467cd0a8523ff939819c5" + +wget -c http://paddlepaddle.bj.bcebos.com/model_zoo/speech/$LM -P ./data + +echo "Checking md5sum ..." +md5_tmp=`md5sum ./data/$LM | awk -F[' '] '{print $1}'` + +if [ $MD5 != $md5_tmp ]; then + echo "Fail to download the language model!" + exit 1 +fi + + + diff --git a/requirements.txt b/requirements.txt index 967b4f8c3148c62cd5b7a511567848af6c5c8f93..721fa2811081e530a9cec3b2e403ad2372b59269 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ wget==3.2 scipy==0.13.1 -resampy==0.1.5 \ No newline at end of file +resampy==0.1.5 +https://github.com/kpu/kenlm/archive/master.zip +python_speech_features diff --git a/tests/test_decoders.py b/tests/test_decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..99d8a8289d93574c58ced50923716c39cfb96558 --- /dev/null +++ b/tests/test_decoders.py @@ -0,0 +1,91 @@ +"""Test decoders.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +from decoder import * + + +class TestDecoders(unittest.TestCase): + def setUp(self): + self.vocab_list = ["\'", ' ', 'a', 'b', 'c', 'd'] + self.beam_size = 20 + self.probs_seq1 = [[ + 0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254, + 0.18184413, 0.16493624 + ], [ + 0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462, + 0.0094893, 0.06890021 + ], [ + 0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535, + 0.08424043, 0.08120984 + ], [ + 0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305, + 0.05206269, 0.09772094 + ], [ + 0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985, + 0.41317442, 0.01946335 + ], [ + 0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937, + 0.04377724, 0.01457421 + ]] + self.probs_seq2 = [[ + 0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441, + 0.04468023, 0.10903471 + ], [ + 0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123, + 0.10219457, 0.20640612 + ], [ + 0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316, + 0.12298585, 0.01654384 + ], [ + 0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055, + 0.22538587, 0.13483174 + ], [ + 0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313, + 0.07113197, 0.04139363 + ], [ + 0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306, + 0.05294827, 0.22298418 + ]] + self.best_path_result = ["ac'bdc", "b'da"] + self.beam_search_result = ['acdc', "b'a"] + + def test_best_path_decoder_1(self): + bst_result = ctc_best_path_decoder(self.probs_seq1, self.vocab_list) + self.assertEqual(bst_result, self.best_path_result[0]) + + def test_best_path_decoder_2(self): + bst_result = ctc_best_path_decoder(self.probs_seq2, self.vocab_list) + self.assertEqual(bst_result, self.best_path_result[1]) + + def test_beam_search_decoder_1(self): + beam_result = ctc_beam_search_decoder( + probs_seq=self.probs_seq1, + beam_size=self.beam_size, + vocabulary=self.vocab_list, + blank_id=len(self.vocab_list)) + self.assertEqual(beam_result[0][1], self.beam_search_result[0]) + + def test_beam_search_decoder_2(self): + beam_result = ctc_beam_search_decoder( + probs_seq=self.probs_seq2, + beam_size=self.beam_size, + vocabulary=self.vocab_list, + blank_id=len(self.vocab_list)) + self.assertEqual(beam_result[0][1], self.beam_search_result[1]) + + def test_beam_search_decoder_batch(self): + beam_results = ctc_beam_search_decoder_batch( + probs_split=[self.probs_seq1, self.probs_seq2], + beam_size=self.beam_size, + vocabulary=self.vocab_list, + blank_id=len(self.vocab_list), + num_processes=24) + self.assertEqual(beam_results[0][0][1], self.beam_search_result[0]) + self.assertEqual(beam_results[1][0][1], self.beam_search_result[1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/train.py b/train.py index 3a2d0cad9ec9635c7e44e0149e426842a5e892b6..6481074c6e58f98f57f81c6e42480fa00a261bbe 100644 --- a/train.py +++ b/train.py @@ -53,6 +53,12 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--max_duration", default=27.0, @@ -130,6 +136,7 @@ def train(): augmentation_config=args.augmentation_config, max_duration=args.max_duration, min_duration=args.min_duration, + specgram_type=args.specgram_type, num_threads=args.num_threads_data) train_generator = data_generator() diff --git a/tune.py b/tune.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcca48628aa0aba7fd2e09a1d9ba90582492f89 --- /dev/null +++ b/tune.py @@ -0,0 +1,224 @@ +"""Parameters tuning for DeepSpeech2 model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import distutils.util +import argparse +import gzip +import paddle.v2 as paddle +from data_utils.data import DataGenerator +from model import deep_speech2 +from decoder import * +from lm.lm_scorer import LmScorer +from error_rate import wer +import utils + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--num_samples", + default=100, + type=int, + help="Number of samples for parameters tuning. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--num_processes_beam_search", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu processes for beam search. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") +parser.add_argument( + "--mean_std_filepath", + default='mean_std.npz', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--decode_manifest_path", + default='datasets/manifest.test', + type=str, + help="Manifest path for decoding. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='checkpoints/params.latest.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='datasets/vocab/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +parser.add_argument( + "--beam_size", + default=500, + type=int, + help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--language_model_path", + default="lm/data/common_crawl_00.prune01111.trie.klm", + type=str, + help="Path for language model. (default: %(default)s)") +parser.add_argument( + "--alpha_from", + default=0.1, + type=float, + help="Where alpha starts from. (default: %(default)f)") +parser.add_argument( + "--num_alphas", + default=14, + type=int, + help="Number of candidate alphas. (default: %(default)d)") +parser.add_argument( + "--alpha_to", + default=0.36, + type=float, + help="Where alpha ends with. (default: %(default)f)") +parser.add_argument( + "--beta_from", + default=0.05, + type=float, + help="Where beta starts from. (default: %(default)f)") +parser.add_argument( + "--num_betas", + default=20, + type=float, + help="Number of candidate betas. (default: %(default)d)") +parser.add_argument( + "--beta_to", + default=1.0, + type=float, + help="Where beta ends with. (default: %(default)f)") +parser.add_argument( + "--cutoff_prob", + default=0.99, + type=float, + help="The cutoff probability of pruning" + "in beam search. (default: %(default)f)") +args = parser.parse_args() + + +def tune(): + """Tune parameters alpha and beta on one minibatch.""" + + if not args.num_alphas >= 0: + raise ValueError("num_alphas must be non-negative!") + + if not args.num_betas >= 0: + raise ValueError("num_betas must be non-negative!") + + # initialize data generator + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config='{}', + specgram_type=args.specgram_type, + num_threads=args.num_threads_data) + + # create network config + # paddle.data_type.dense_array is used for variable batch input. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. + audio_data = paddle.layer.data( + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) + output_probs = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size, + is_inference=True) + + # load parameters + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.model_filepath)) + + # prepare infer data + batch_reader = data_generator.batch_reader_creator( + manifest_path=args.decode_manifest_path, + batch_size=args.num_samples, + sortagrad=False, + shuffle_method=None) + # get one batch data for tuning + infer_data = batch_reader().next() + + # run inference + infer_results = paddle.infer( + output_layer=output_probs, parameters=parameters, input=infer_data) + num_steps = len(infer_results) // len(infer_data) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(0, len(infer_data)) + ] + + # create grid for search + cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas) + cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas) + params_grid = [(alpha, beta) for alpha in cand_alphas + for beta in cand_betas] + + ext_scorer = LmScorer(args.alpha_from, args.beta_from, + args.language_model_path) + ## tune parameters in loop + for alpha, beta in params_grid: + wer_sum, wer_counter = 0, 0 + # reset scorer + ext_scorer.reset_params(alpha, beta) + # beam search using multiple processes + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=data_generator.vocab_list, + beam_size=args.beam_size, + cutoff_prob=args.cutoff_prob, + blank_id=len(data_generator.vocab_list), + num_processes=args.num_processes_beam_search, + ext_scoring_func=ext_scorer, ) + for i, beam_search_result in enumerate(beam_search_results): + target_transcription = ''.join([ + data_generator.vocab_list[index] for index in infer_data[i][1] + ]) + wer_sum += wer(target_transcription, beam_search_result[0][1]) + wer_counter += 1 + + print("alpha = %f\tbeta = %f\tWER = %f" % + (alpha, beta, wer_sum / wer_counter)) + + +def main(): + paddle.init(use_gpu=args.use_gpu, trainer_count=1) + tune() + + +if __name__ == '__main__': + main()