From d7222c0453f76c7cbec4677e03fba00659c1cfc9 Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Mon, 24 Jan 2022 17:27:51 +0800 Subject: [PATCH] [ASR] Support CTC decoder online (#821) * fix the destructer problem for prefixes * unified offline and online in ctcdecoders, test=asr * rename swig_decoders to paddlespeech_ctcdecoders, test=asr * add reset_stage for ctcdecoder * fix some problems * fix ctconline * fix a bug * fix the format * fix 1xt2x --- .../models/ds2/deepspeech2.py | 36 +- .../1xt2x/src_deepspeech2x/test_model.py | 34 +- .../s2t/decoders/ctcdecoder/__init__.py | 5 + .../s2t/decoders/ctcdecoder/swig_wrapper.py | 77 ++-- paddlespeech/s2t/exps/deepspeech2/model.py | 100 +++-- paddlespeech/s2t/models/ds2/__init__.py | 2 +- paddlespeech/s2t/models/ds2/deepspeech2.py | 22 +- .../s2t/models/ds2_online/__init__.py | 2 +- .../s2t/models/ds2_online/deepspeech2.py | 22 +- paddlespeech/s2t/models/u2/u2.py | 6 +- paddlespeech/s2t/models/u2_st/u2_st.py | 9 +- paddlespeech/s2t/modules/ctc.py | 245 ++++++++++-- .../ctc_decoders/ctc_beam_search_decoder.cpp | 370 +++++++++++++++++- .../ctc_decoders/ctc_beam_search_decoder.h | 102 ++++- .../ctc_decoders/ctc_greedy_decoder.cpp | 2 +- third_party/ctc_decoders/ctc_greedy_decoder.h | 2 +- third_party/ctc_decoders/decoders.i | 2 +- third_party/ctc_decoders/path_trie.cpp | 27 +- third_party/ctc_decoders/scorer.cpp | 3 +- third_party/ctc_decoders/scorer.h | 3 +- third_party/ctc_decoders/setup.py | 4 +- 21 files changed, 865 insertions(+), 210 deletions(-) diff --git a/examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py b/examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py index fb8b321c..59be4222 100644 --- a/examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py +++ b/examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py @@ -162,39 +162,17 @@ class DeepSpeech2Model(nn.Layer): return loss @paddle.no_grad() - def decode(self, audio, audio_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes): - # init once + def decode(self, audio, audio_len): # decoders only accept string encoded in utf-8 - self.decoder.init_decode( - beam_alpha=beam_alpha, - beam_beta=beam_beta, - lang_model_path=lang_model_path, - vocab_list=vocab_list, - decoding_method=decoding_method) + # Make sure the decoder has been initialized eouts, eouts_len = self.encoder(audio, audio_len) probs = self.decoder.softmax(eouts) - print("probs.shape", probs.shape) - return self.decoder.decode_probs( - probs.numpy(), eouts_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes) - - def decode_probs_split(self, probs_split, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, - cutoff_prob, cutoff_top_n, num_processes): - self.decoder.init_decode( - beam_alpha=beam_alpha, - beam_beta=beam_beta, - lang_model_path=lang_model_path, - vocab_list=vocab_list, - decoding_method=decoding_method) - return self.decoder.decode_probs_split( - probs_split, vocab_list, decoding_method, lang_model_path, - beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, - num_processes) + batch_size = probs.shape[0] + self.decoder.reset_decoder(batch_size = batch_size) + self.decoder.next(probs, eouts_len) + trans_best, trans_beam = self.decoder.decode() + return trans_best @classmethod def from_pretrained(cls, dataloader, config, checkpoint_path): diff --git a/examples/other/1xt2x/src_deepspeech2x/test_model.py b/examples/other/1xt2x/src_deepspeech2x/test_model.py index 2a38fb5c..11b85442 100644 --- a/examples/other/1xt2x/src_deepspeech2x/test_model.py +++ b/examples/other/1xt2x/src_deepspeech2x/test_model.py @@ -254,12 +254,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer - vocab_list = self.test_loader.collate_fn.vocab_list - target_transcripts = self.ordid2token(texts, texts_len) - result_transcripts = self.compute_result_transcripts(audio, audio_len, - vocab_list, cfg) + result_transcripts = self.compute_result_transcripts(audio, audio_len) + for utt, target, result in zip(utts, target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) @@ -280,19 +278,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): error_rate=errors_sum / len_refs, error_rate_type=cfg.error_rate_type) - def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): - result_transcripts = self.model.decode( - audio, - audio_len, - vocab_list, - decoding_method=cfg.decoding_method, - lang_model_path=cfg.lang_model_path, - beam_alpha=cfg.alpha, - beam_beta=cfg.beta, - beam_size=cfg.beam_size, - cutoff_prob=cfg.cutoff_prob, - cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch) + def compute_result_transcripts(self, audio, audio_len): + result_transcripts = self.model.decode(audio, audio_len) + result_transcripts = [ self._text_featurizer.detokenize(item) for item in result_transcripts @@ -307,6 +295,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cfg = self.config error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 + + # Initialized the decoder in model + decode_cfg = self.config.decode + vocab_list = self.test_loader.collate_fn.vocab_list + decode_batch_size = self.test_loader.batch_size + self.model.decoder.init_decoder( + decode_batch_size, vocab_list, decode_cfg.decoding_method, + decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta, + decode_cfg.beam_size, decode_cfg.cutoff_prob, + decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch) + with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch @@ -326,6 +325,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): msg += "Final error rate [%s] (%d/%d) = %f" % ( error_rate_type, num_ins, num_ins, errors_sum / len_refs) logger.info(msg) + self.model.decoder.del_decoder() def run_test(self): self.resume_or_scratch() diff --git a/paddlespeech/s2t/decoders/ctcdecoder/__init__.py b/paddlespeech/s2t/decoders/ctcdecoder/__init__.py index 185a92b8..37ceae6e 100644 --- a/paddlespeech/s2t/decoders/ctcdecoder/__init__.py +++ b/paddlespeech/s2t/decoders/ctcdecoder/__init__.py @@ -11,3 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .swig_wrapper import ctc_beam_search_decoding +from .swig_wrapper import ctc_beam_search_decoding_batch +from .swig_wrapper import ctc_greedy_decoding +from .swig_wrapper import CTCBeamSearchDecoder +from .swig_wrapper import Scorer diff --git a/paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py b/paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py index d883d430..9e2a8506 100644 --- a/paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py +++ b/paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Wrapper for various CTC decoders in SWIG.""" -import swig_decoders +import paddlespeech_ctcdecoders -class Scorer(swig_decoders.Scorer): +class Scorer(paddlespeech_ctcdecoders.Scorer): """Wrapper for Scorer. :param alpha: Parameter associated with language model. Don't use @@ -26,14 +26,17 @@ class Scorer(swig_decoders.Scorer): :type beta: float :model_path: Path to load language model. :type model_path: str + :param vocabulary: Vocabulary list. + :type vocabulary: list """ def __init__(self, alpha, beta, model_path, vocabulary): - swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) + paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path, + vocabulary) -def ctc_greedy_decoder(probs_seq, vocabulary, blank_id): - """Wrapper for ctc best path decoder in swig. +def ctc_greedy_decoding(probs_seq, vocabulary, blank_id): + """Wrapper for ctc best path decodeing function in swig. :param probs_seq: 2-D list of probability distributions over each time step, with each element being a list of normalized @@ -44,19 +47,19 @@ def ctc_greedy_decoder(probs_seq, vocabulary, blank_id): :return: Decoding result string. :rtype: str """ - result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary, - blank_id) + result = paddlespeech_ctcdecoders.ctc_greedy_decoding(probs_seq.tolist(), + vocabulary, blank_id) return result -def ctc_beam_search_decoder(probs_seq, - vocabulary, - beam_size, - cutoff_prob=1.0, - cutoff_top_n=40, - ext_scoring_func=None, - blank_id=0): - """Wrapper for the CTC Beam Search Decoder. +def ctc_beam_search_decoding(probs_seq, + vocabulary, + beam_size, + cutoff_prob=1.0, + cutoff_top_n=40, + ext_scoring_func=None, + blank_id=0): + """Wrapper for the CTC Beam Search Decoding function. :param probs_seq: 2-D list of probability distributions over each time step, with each element being a list of normalized @@ -81,22 +84,22 @@ def ctc_beam_search_decoder(probs_seq, results, in descending order of the probability. :rtype: list """ - beam_results = swig_decoders.ctc_beam_search_decoder( + beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding( probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, ext_scoring_func, blank_id) beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] return beam_results -def ctc_beam_search_decoder_batch(probs_split, - vocabulary, - beam_size, - num_processes, - cutoff_prob=1.0, - cutoff_top_n=40, - ext_scoring_func=None, - blank_id=0): - """Wrapper for the batched CTC beam search decoder. +def ctc_beam_search_decoding_batch(probs_split, + vocabulary, + beam_size, + num_processes, + cutoff_prob=1.0, + cutoff_top_n=40, + ext_scoring_func=None, + blank_id=0): + """Wrapper for the batched CTC beam search decodeing batch function. :param probs_seq: 3-D list with each element as an instance of 2-D list of probabilities used by ctc_beam_search_decoder(). @@ -126,9 +129,31 @@ def ctc_beam_search_decoder_batch(probs_split, """ probs_split = [probs_seq.tolist() for probs_seq in probs_split] - batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( + batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding_batch( probs_split, vocabulary, beam_size, num_processes, cutoff_prob, cutoff_top_n, ext_scoring_func, blank_id) batch_beam_results = [[(res[0], res[1]) for res in beam_results] for beam_results in batch_beam_results] return batch_beam_results + + +class CTCBeamSearchDecoder(paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch): + """Wrapper for CtcBeamSearchDecoderBatch. + Args: + vocab_list (list): Vocabulary list. + beam_size (int): Width for beam search. + num_processes (int): Number of parallel processes. + param cutoff_prob (float): Cutoff probability in vocabulary pruning, + default 1.0, no pruning. + cutoff_top_n (int): Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + param ext_scorer (Scorer): External scorer for partially decoded sentence, e.g. word count + or language model. + """ + + def __init__(self, vocab_list, batch_size, beam_size, num_processes, + cutoff_prob, cutoff_top_n, _ext_scorer, blank_id): + paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch.__init__( + self, vocab_list, batch_size, beam_size, num_processes, cutoff_prob, + cutoff_top_n, _ext_scorer, blank_id) diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 049311c7..3e9ede76 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -267,12 +267,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer - vocab_list = self.test_loader.collate_fn.vocab_list - target_transcripts = self.ordid2token(texts, texts_len) - result_transcripts = self.compute_result_transcripts( - audio, audio_len, vocab_list, decode_cfg) + result_transcripts = self.compute_result_transcripts(audio, audio_len) for utt, target, result in zip(utts, target_transcripts, result_transcripts): @@ -296,21 +293,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): error_rate=errors_sum / len_refs, error_rate_type=decode_cfg.error_rate_type) - def compute_result_transcripts(self, audio, audio_len, vocab_list, - decode_cfg): - result_transcripts = self.model.decode( - audio, - audio_len, - vocab_list, - decoding_method=decode_cfg.decoding_method, - lang_model_path=decode_cfg.lang_model_path, - beam_alpha=decode_cfg.alpha, - beam_beta=decode_cfg.beta, - beam_size=decode_cfg.beam_size, - cutoff_prob=decode_cfg.cutoff_prob, - cutoff_top_n=decode_cfg.cutoff_top_n, - num_processes=decode_cfg.num_proc_bsearch) - + def compute_result_transcripts(self, audio, audio_len): + result_transcripts = self.model.decode(audio, audio_len) return result_transcripts @mp_tools.rank_zero_only @@ -320,6 +304,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): self.model.eval() error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 + + # Initialized the decoder in model + decode_cfg = self.config.decode + vocab_list = self.test_loader.collate_fn.vocab_list + decode_batch_size = self.test_loader.batch_size + self.model.decoder.init_decoder( + decode_batch_size, vocab_list, decode_cfg.decoding_method, + decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta, + decode_cfg.beam_size, decode_cfg.cutoff_prob, + decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch) + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch @@ -339,6 +334,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): msg += "Final error rate [%s] (%d/%d) = %f" % ( error_rate_type, num_ins, num_ins, errors_sum / len_refs) logger.info(msg) + self.model.decoder.del_decoder() @paddle.no_grad() def export(self): @@ -377,6 +373,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): self.model.eval() error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 + + # Initialized the decoder in model + decode_cfg = self.config.decode + vocab_list = self.test_loader.collate_fn.vocab_list + if self.args.model_type == "online": + decode_batch_size = 1 + elif self.args.model_type == "offline": + decode_batch_size = self.test_loader.batch_size + else: + raise Exception("wrong model type") + self.model.decoder.init_decoder( + decode_batch_size, vocab_list, decode_cfg.decoding_method, + decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta, + decode_cfg.beam_size, decode_cfg.cutoff_prob, + decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch) + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch @@ -388,7 +400,6 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): error_rate_type = metrics['error_rate_type'] logger.info("Error rate [%s] (%d/?) = %f" % (error_rate_type, num_ins, errors_sum / len_refs)) - # logging msg = "Test: " msg += "epoch: {}, ".format(self.epoch) @@ -398,30 +409,31 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): logger.info(msg) if self.args.enable_auto_log is True: self.autolog.report() + self.model.decoder.del_decoder() - def compute_result_transcripts(self, audio, audio_len, vocab_list, - decode_cfg): + def compute_result_transcripts(self, audio, audio_len): if self.args.model_type == "online": - output_probs, output_lens = self.static_forward_online(audio, - audio_len) + output_probs, output_lens, trans_batch = self.static_forward_online( + audio, audio_len, decoder_chunk_size=1) + result_transcripts = [trans[-1] for trans in trans_batch] elif self.args.model_type == "offline": output_probs, output_lens = self.static_forward_offline(audio, audio_len) + batch_size = output_probs.shape[0] + self.model.decoder.reset_decoder(batch_size=batch_size) + + self.model.decoder.next(output_probs, output_lens) + + trans_best, trans_beam = self.model.decoder.decode() + + result_transcripts = trans_best + else: raise Exception("wrong model type") self.predictor.clear_intermediate_tensor() self.predictor.try_shrink_memory() - self.model.decoder.init_decode(decode_cfg.alpha, decode_cfg.beta, - decode_cfg.lang_model_path, vocab_list, - decode_cfg.decoding_method) - - result_transcripts = self.model.decoder.decode_probs( - output_probs, output_lens, vocab_list, decode_cfg.decoding_method, - decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta, - decode_cfg.beam_size, decode_cfg.cutoff_prob, - decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch) #replace the with ' ' result_transcripts = [ self._text_featurizer.detokenize(sentence) @@ -451,6 +463,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): ------- output_probs(numpy.array): shape[B, T, vocab_size] output_lens(numpy.array): shape[B] + trans(list(list(str))): shape[B, T] """ output_probs_list = [] output_lens_list = [] @@ -464,14 +477,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): batch_size, Tmax, x_dim = x_batch.shape x_len_batch = audio_len.numpy().astype(np.int64) if (Tmax - chunk_size) % chunk_stride != 0: - padding_len_batch = chunk_stride - ( - Tmax - chunk_size - ) % chunk_stride # The length of padding for the batch + # The length of padding for the batch + padding_len_batch = chunk_stride - (Tmax - chunk_size + ) % chunk_stride else: padding_len_batch = 0 x_list = np.split(x_batch, batch_size, axis=0) x_len_list = np.split(x_len_batch, batch_size, axis=0) + trans_batch = [] for x, x_len in zip(x_list, x_len_list): if self.args.enable_auto_log is True: self.autolog.times.start() @@ -504,12 +518,14 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): h_box_handle = self.predictor.get_input_handle(input_names[2]) c_box_handle = self.predictor.get_input_handle(input_names[3]) + trans = [] probs_chunk_list = [] probs_chunk_lens_list = [] if self.args.enable_auto_log is True: # record the model preprocessing time self.autolog.times.stamp() + self.model.decoder.reset_decoder(batch_size=1) for i in range(0, num_chunk): start = i * chunk_stride end = start + chunk_size @@ -518,9 +534,8 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): x_chunk_lens = 0 else: x_chunk_lens = min(x_len - i * chunk_stride, chunk_size) - - if (x_chunk_lens < - receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob + #means the number of input frames in the chunk is not enough for predicting one prob + if (x_chunk_lens < receptive_field_length): break x_chunk_lens = np.array([x_chunk_lens]) audio_handle.reshape(x_chunk.shape) @@ -549,9 +564,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): output_chunk_lens = output_lens_handle.copy_to_cpu() chunk_state_h_box = output_state_h_handle.copy_to_cpu() chunk_state_c_box = output_state_c_handle.copy_to_cpu() - + self.model.decoder.next(output_chunk_probs, output_chunk_lens) probs_chunk_list.append(output_chunk_probs) probs_chunk_lens_list.append(output_chunk_lens) + trans_best, trans_beam = self.model.decoder.decode() + trans.append(trans_best[0]) + trans_batch.append(trans) output_probs = np.concatenate(probs_chunk_list, axis=1) output_lens = np.sum(probs_chunk_lens_list, axis=0) vocab_size = output_probs.shape[2] @@ -573,7 +591,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester): self.autolog.times.end() output_probs = np.concatenate(output_probs_list, axis=0) output_lens = np.concatenate(output_lens_list, axis=0) - return output_probs, output_lens + return output_probs, output_lens, trans_batch def static_forward_offline(self, audio, audio_len): """ diff --git a/paddlespeech/s2t/models/ds2/__init__.py b/paddlespeech/s2t/models/ds2/__init__.py index 8d5959c8..b3222067 100644 --- a/paddlespeech/s2t/models/ds2/__init__.py +++ b/paddlespeech/s2t/models/ds2/__init__.py @@ -16,7 +16,7 @@ from .deepspeech2 import DeepSpeech2Model from paddlespeech.s2t.utils import dynamic_pip_install try: - import swig_decoders + import paddlespeech_ctcdecoders except ImportError: try: package_name = 'paddlespeech_ctcdecoders' diff --git a/paddlespeech/s2t/models/ds2/deepspeech2.py b/paddlespeech/s2t/models/ds2/deepspeech2.py index 4a4d67ce..9c6b66c2 100644 --- a/paddlespeech/s2t/models/ds2/deepspeech2.py +++ b/paddlespeech/s2t/models/ds2/deepspeech2.py @@ -164,24 +164,18 @@ class DeepSpeech2Model(nn.Layer): return loss @paddle.no_grad() - def decode(self, audio, audio_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes): - # init once + def decode(self, audio, audio_len): # decoders only accept string encoded in utf-8 - self.decoder.init_decode( - beam_alpha=beam_alpha, - beam_beta=beam_beta, - lang_model_path=lang_model_path, - vocab_list=vocab_list, - decoding_method=decoding_method) + # Make sure the decoder has been initialized eouts, eouts_len = self.encoder(audio, audio_len) probs = self.decoder.softmax(eouts) - return self.decoder.decode_probs( - probs.numpy(), eouts_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes) + batch_size = probs.shape[0] + self.decoder.reset_decoder(batch_size=batch_size) + self.decoder.next(probs, eouts_len) + trans_best, trans_beam = self.decoder.decode() + + return trans_best @classmethod def from_pretrained(cls, dataloader, config, checkpoint_path): diff --git a/paddlespeech/s2t/models/ds2_online/__init__.py b/paddlespeech/s2t/models/ds2_online/__init__.py index 2d304237..c5fdab1b 100644 --- a/paddlespeech/s2t/models/ds2_online/__init__.py +++ b/paddlespeech/s2t/models/ds2_online/__init__.py @@ -16,7 +16,7 @@ from .deepspeech2 import DeepSpeech2ModelOnline from paddlespeech.s2t.utils import dynamic_pip_install try: - import swig_decoders + import paddlespeech_ctcdecoders except ImportError: try: package_name = 'paddlespeech_ctcdecoders' diff --git a/paddlespeech/s2t/models/ds2_online/deepspeech2.py b/paddlespeech/s2t/models/ds2_online/deepspeech2.py index 5e4981c0..9574a62b 100644 --- a/paddlespeech/s2t/models/ds2_online/deepspeech2.py +++ b/paddlespeech/s2t/models/ds2_online/deepspeech2.py @@ -293,25 +293,17 @@ class DeepSpeech2ModelOnline(nn.Layer): return loss @paddle.no_grad() - def decode(self, audio, audio_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes): - # init once + def decode(self, audio, audio_len): # decoders only accept string encoded in utf-8 - self.decoder.init_decode( - beam_alpha=beam_alpha, - beam_beta=beam_beta, - lang_model_path=lang_model_path, - vocab_list=vocab_list, - decoding_method=decoding_method) - + # Make sure the decoder has been initialized eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( audio, audio_len, None, None) probs = self.decoder.softmax(eouts) - return self.decoder.decode_probs( - probs.numpy(), eouts_len, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, - cutoff_top_n, num_processes) + batch_size = probs.shape[0] + self.decoder.reset_decoder(batch_size=batch_size) + self.decoder.next(probs, eouts_len) + trans_best, trans_beam = self.decoder.decode() + return trans_best @classmethod def from_pretrained(cls, dataloader, config, checkpoint_path): diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index ff4012e8..b6ec5f90 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -32,7 +32,7 @@ from paddlespeech.s2t.frontend.utility import IGNORE_ID from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.models.asr_interface import ASRInterface from paddlespeech.s2t.modules.cmvn import GlobalCMVN -from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.modules.ctc import CTCDecoderBase from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder @@ -63,7 +63,7 @@ class U2BaseModel(ASRInterface, nn.Layer): vocab_size: int, encoder: TransformerEncoder, decoder: TransformerDecoder, - ctc: CTCDecoder, + ctc: CTCDecoderBase, ctc_weight: float=0.5, ignore_id: int=IGNORE_ID, lsm_weight: float=0.0, @@ -840,7 +840,7 @@ class U2Model(U2DecodeModel): model_conf = configs.get('model_conf', dict()) dropout_rate = model_conf.get('ctc_dropout_rate', 0.0) grad_norm_type = model_conf.get('ctc_grad_norm_type', None) - ctc = CTCDecoder( + ctc = CTCDecoderBase( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index bc76de7a..f7b05714 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -28,7 +28,7 @@ from paddle import nn from paddlespeech.s2t.frontend.utility import IGNORE_ID from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.modules.cmvn import GlobalCMVN -from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.modules.ctc import CTCDecoderBase from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder @@ -56,7 +56,7 @@ class U2STBaseModel(nn.Layer): encoder: TransformerEncoder, st_decoder: TransformerDecoder, decoder: TransformerDecoder=None, - ctc: CTCDecoder=None, + ctc: CTCDecoderBase=None, ctc_weight: float=0.0, asr_weight: float=0.0, ignore_id: int=IGNORE_ID, @@ -313,8 +313,7 @@ class U2STBaseModel(nn.Layer): cache = [ paddle.ones( (len(hyps), i - 1, hyp_cache.shape[-1]), - dtype=paddle.float32) - for hyp_cache in hyps[0]["cache"] + dtype=paddle.float32) for hyp_cache in hyps[0]["cache"] ] for j, hyp in enumerate(hyps): ys[j, :] = paddle.to_tensor(hyp["yseq"]) @@ -596,7 +595,7 @@ class U2STModel(U2STBaseModel): model_conf = configs['model_conf'] dropout_rate = model_conf.get('ctc_dropout_rate', 0.0) grad_norm_type = model_conf.get('ctc_grad_norm_type', None) - ctc = CTCDecoder( + ctc = CTCDecoderBase( odim=vocab_size, enc_n_units=encoder.output_size(), blank_id=0, diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 1f983807..2094182a 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -25,17 +25,19 @@ from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() try: - from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 - from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401 - from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import CTCBeamSearchDecoder # noqa: F401 except ImportError: try: from paddlespeech.s2t.utils import dynamic_pip_install package_name = 'paddlespeech_ctcdecoders' dynamic_pip_install.install(package_name) - from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 - from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401 - from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 + from paddlespeech.s2t.decoders.ctcdecoder import CTCBeamSearchDecoder # noqa: F401 except Exception as e: logger.info("paddlespeech_ctcdecoders not installed!") @@ -139,9 +141,11 @@ class CTCDecoder(CTCDecoderBase): super().__init__(*args, **kwargs) # CTCDecoder LM Score handle self._ext_scorer = None + self.beam_search_decoder = None - def _decode_batch_greedy(self, probs_split, vocab_list): - """Decode by best path for a batch of probs matrix input. + def _decode_batch_greedy_offline(self, probs_split, vocab_list): + """This function will be deprecated in future. + Decode by best path for a batch of probs matrix input. :param probs_split: List of 2-D probability matrix, and each consists of prob vectors for one speech utterancce. :param probs_split: List of matrix @@ -152,7 +156,7 @@ class CTCDecoder(CTCDecoderBase): """ results = [] for i, probs in enumerate(probs_split): - output_transcription = ctc_greedy_decoder( + output_transcription = ctc_greedy_decoding( probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id) results.append(output_transcription) return results @@ -194,10 +198,12 @@ class CTCDecoder(CTCDecoderBase): logger.info("no language model provided, " "decoding by pure beam search without scorer.") - def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, - beam_size, cutoff_prob, cutoff_top_n, - vocab_list, num_processes): - """Decode by beam search for a batch of probs matrix input. + def _decode_batch_beam_search_offline( + self, probs_split, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, vocab_list, num_processes): + """ + This function will be deprecated in future. + Decode by beam search for a batch of probs matrix input. :param probs_split: List of 2-D probability matrix, and each consists of prob vectors for one speech utterancce. :param probs_split: List of matrix @@ -226,7 +232,7 @@ class CTCDecoder(CTCDecoderBase): # beam search decode num_processes = min(num_processes, len(probs_split)) - beam_search_results = ctc_beam_search_decoder_batch( + beam_search_results = ctc_beam_search_decoding_batch( probs_split=probs_split, vocabulary=vocab_list, beam_size=beam_size, @@ -239,30 +245,69 @@ class CTCDecoder(CTCDecoderBase): results = [result[0][1] for result in beam_search_results] return results - def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, - decoding_method): + def init_decoder(self, batch_size, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, + cutoff_prob, cutoff_top_n, num_processes): + """ + init ctc decoders + Args: + batch_size(int): Batch size for input data + vocab_list (list): List of tokens in the vocabulary, for decoding + decoding_method (str): ctc_beam_search + lang_model_path (str): language model path + beam_alpha (float): beam_alpha + beam_beta (float): beam_beta + beam_size (int): beam_size + cutoff_prob (float): cutoff probability in beam search + cutoff_top_n (int): cutoff_top_n + num_processes (int): num_processes + + Raises: + ValueError: when decoding_method not support. + Returns: + CTCBeamSearchDecoder + """ + self.batch_size = batch_size + self.vocab_list = vocab_list + self.decoding_method = decoding_method + self.beam_size = beam_size + self.cutoff_prob = cutoff_prob + self.cutoff_top_n = cutoff_top_n + self.num_processes = num_processes if decoding_method == "ctc_beam_search": self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, vocab_list) + if self.beam_search_decoder is None: + self.beam_search_decoder = self.get_decoder( + vocab_list, batch_size, beam_alpha, beam_beta, beam_size, + num_processes, cutoff_prob, cutoff_top_n) + return self.beam_search_decoder + elif decoding_method == "ctc_greedy": + self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, + vocab_list) + else: + raise ValueError(f"Not support: {decoding_method}") - def decode_probs(self, probs, logits_lens, vocab_list, decoding_method, - lang_model_path, beam_alpha, beam_beta, beam_size, - cutoff_prob, cutoff_top_n, num_processes): - """ctc decoding with probs. - + def decode_probs_offline(self, probs, logits_lens, vocab_list, + decoding_method, lang_model_path, beam_alpha, + beam_beta, beam_size, cutoff_prob, cutoff_top_n, + num_processes): + """ + This function will be deprecated in future. + ctc decoding with probs. Args: probs (Tensor): activation after softmax logits_lens (Tensor): audio output lens - vocab_list ([type]): [description] - decoding_method ([type]): [description] - lang_model_path ([type]): [description] - beam_alpha ([type]): [description] - beam_beta ([type]): [description] - beam_size ([type]): [description] - cutoff_prob ([type]): [description] - cutoff_top_n ([type]): [description] - num_processes ([type]): [description] + vocab_list (list): List of tokens in the vocabulary, for decoding + decoding_method (str): ctc_beam_search + lang_model_path (str): language model path + beam_alpha (float): beam_alpha + beam_beta (float): beam_beta + beam_size (int): beam_size + cutoff_prob (float): cutoff probability in beam search + cutoff_top_n (int): cutoff_top_n + num_processes (int): num_processes Raises: ValueError: when decoding_method not support. @@ -270,13 +315,14 @@ class CTCDecoder(CTCDecoderBase): Returns: List[str]: transcripts. """ - + logger.warn( + "This function will be deprecated in future: decode_probs_offline") probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)] if decoding_method == "ctc_greedy": - result_transcripts = self._decode_batch_greedy( + result_transcripts = self._decode_batch_greedy_offline( probs_split=probs_split, vocab_list=vocab_list) elif decoding_method == "ctc_beam_search": - result_transcripts = self._decode_batch_beam_search( + result_transcripts = self._decode_batch_beam_search_offline( probs_split=probs_split, beam_alpha=beam_alpha, beam_beta=beam_beta, @@ -288,3 +334,136 @@ class CTCDecoder(CTCDecoderBase): else: raise ValueError(f"Not support: {decoding_method}") return result_transcripts + + def get_decoder(self, vocab_list, batch_size, beam_alpha, beam_beta, + beam_size, num_processes, cutoff_prob, cutoff_top_n): + """ + init get ctc decoder + Args: + vocab_list (list): List of tokens in the vocabulary, for decoding. + batch_size(int): Batch size for input data + beam_alpha (float): beam_alpha + beam_beta (float): beam_beta + beam_size (int): beam_size + num_processes (int): num_processes + cutoff_prob (float): cutoff probability in beam search + cutoff_top_n (int): cutoff_top_n + + Raises: + ValueError: when decoding_method not support. + + Returns: + CTCBeamSearchDecoder + """ + num_processes = min(num_processes, batch_size) + if self._ext_scorer is not None: + self._ext_scorer.reset_params(beam_alpha, beam_beta) + if self.decoding_method == "ctc_beam_search": + beam_search_decoder = CTCBeamSearchDecoder( + vocab_list, batch_size, beam_size, num_processes, cutoff_prob, + cutoff_top_n, self._ext_scorer, self.blank_id) + else: + raise ValueError(f"Not support: {decoding_method}") + return beam_search_decoder + + def next(self, probs, logits_lens): + """ + Input probs into ctc decoder + Args: + probs (list(list(float))): probs for a batch of data + logits_lens (list(int)): logits lens for a batch of data + Raises: + Exception: when the ctc decoder is not initialized + ValueError: when decoding_method not support. + """ + + if self.beam_search_decoder is None: + raise Exception( + "You need to initialize the beam_search_decoder firstly") + beam_search_decoder = self.beam_search_decoder + + has_value = (logits_lens > 0).tolist() + has_value = [ + "true" if has_value[i] is True else "false" + for i in range(len(has_value)) + ] + probs_split = [ + probs[i, :l, :].tolist() if has_value[i] else probs[i].tolist() + for i, l in enumerate(logits_lens) + ] + if self.decoding_method == "ctc_beam_search": + beam_search_decoder.next(probs_split, has_value) + else: + raise ValueError(f"Not support: {decoding_method}") + + return + + def decode(self): + """ + Get the decoding result + Raises: + Exception: when the ctc decoder is not initialized + ValueError: when decoding_method not support. + Returns: + results_best (list(str)): The best result for a batch of data + results_beam (list(list(str))): The beam search result for a batch of data + """ + if self.beam_search_decoder is None: + raise Exception( + "You need to initialize the beam_search_decoder firstly") + + beam_search_decoder = self.beam_search_decoder + if self.decoding_method == "ctc_beam_search": + batch_beam_results = beam_search_decoder.decode() + batch_beam_results = [[(res[0], res[1]) for res in beam_results] + for beam_results in batch_beam_results] + results_best = [result[0][1] for result in batch_beam_results] + results_beam = [[trans[1] for trans in result] + for result in batch_beam_results] + + else: + raise ValueError(f"Not support: {decoding_method}") + + return results_best, results_beam + + def reset_decoder(self, + batch_size=-1, + beam_size=-1, + num_processes=-1, + cutoff_prob=-1.0, + cutoff_top_n=-1): + if batch_size > 0: + self.batch_size = batch_size + if beam_size > 0: + self.beam_size = beam_size + if num_processes > 0: + self.num_processes = num_processes + if cutoff_prob > 0: + self.cutoff_prob = cutoff_prob + if cutoff_top_n > 0: + self.cutoff_top_n = cutoff_top_n + """ + Reset the decoder state + Args: + batch_size(int): Batch size for input data + beam_size (int): beam_size + num_processes (int): num_processes + cutoff_prob (float): cutoff probability in beam search + cutoff_top_n (int): cutoff_top_n + Raises: + Exception: when the ctc decoder is not initialized + """ + if self.beam_search_decoder is None: + raise Exception( + "You need to initialize the beam_search_decoder firstly") + self.beam_search_decoder.reset_state( + self.batch_size, self.beam_size, self.num_processes, + self.cutoff_prob, self.cutoff_top_n) + + def del_decoder(self): + """ + Delete the decoder + """ + if self.beam_search_decoder is not None: + del self.beam_search_decoder + self.beam_search_decoder = None diff --git a/third_party/ctc_decoders/ctc_beam_search_decoder.cpp b/third_party/ctc_decoders/ctc_beam_search_decoder.cpp index db742fbb..ebea5c22 100644 --- a/third_party/ctc_decoders/ctc_beam_search_decoder.cpp +++ b/third_party/ctc_decoders/ctc_beam_search_decoder.cpp @@ -29,7 +29,8 @@ using FSTMATCH = fst::SortedMatcher; -std::vector> ctc_beam_search_decoder( + +std::vector> ctc_beam_search_decoding( const std::vector> &probs_seq, const std::vector &vocabulary, size_t beam_size, @@ -46,6 +47,8 @@ std::vector> ctc_beam_search_decoder( "The shape of probs_seq does not match with " "the shape of the vocabulary"); } + + // assign space id auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); int space_id = it - vocabulary.begin(); @@ -206,7 +209,7 @@ std::vector> ctc_beam_search_decoder( std::vector>> -ctc_beam_search_decoder_batch( +ctc_beam_search_decoding_batch( const std::vector>> &probs_split, const std::vector &vocabulary, size_t beam_size, @@ -224,7 +227,7 @@ ctc_beam_search_decoder_batch( // enqueue the tasks of decoding std::vector>>> res; for (size_t i = 0; i < batch_size; ++i) { - res.emplace_back(pool.enqueue(ctc_beam_search_decoder, + res.emplace_back(pool.enqueue(ctc_beam_search_decoding, probs_split[i], vocabulary, beam_size, @@ -241,3 +244,364 @@ ctc_beam_search_decoder_batch( } return batch_results; } + +void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) { + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + auto fst_dict = + static_cast(ext_scorer->dictionary); + fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); + root->set_dictionary(dict_ptr); + auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); + root->set_matcher(matcher); + } +} + +void ctc_beam_search_decode_chunk( + PathTrie *root, + std::vector &prefixes, + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id) { + // dimension check + size_t num_time_steps = probs_seq.size(); + for (size_t i = 0; i < num_time_steps; ++i) { + VALID_CHECK_EQ(probs_seq[i].size(), + // vocabulary.size() + 1, + vocabulary.size(), + "The shape of probs_seq does not match with " + "the shape of the vocabulary"); + } + + // assign space id + auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); + int space_id = it - vocabulary.begin(); + // if no space in vocabulary + if ((size_t)space_id >= vocabulary.size()) { + space_id = -2; + } + // init prefixes' root + // + // prefix search over time + for (size_t time_step = 0; time_step < num_time_steps; ++time_step) { + auto &prob = probs_seq[time_step]; + + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (ext_scorer != nullptr) { + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort(prefixes.begin(), + prefixes.begin() + num_prefixes, + prefix_compare); + min_cutoff = prefixes[num_prefixes - 1]->score + + std::log(prob[blank_id]) - + std::max(0.0, ext_scorer->beta); + full_beam = (num_prefixes == beam_size); + } + + std::vector> log_prob_idx = + get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n); + // loop over chars + for (size_t index = 0; index < log_prob_idx.size(); index++) { + auto c = log_prob_idx[index].first; + auto log_prob_c = log_prob_idx[index].second; + + for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { + auto prefix = prefixes[i]; + if (full_beam && log_prob_c + prefix->score < min_cutoff) { + break; + } + // blank + if (c == blank_id) { + prefix->log_prob_b_cur = log_sum_exp( + prefix->log_prob_b_cur, log_prob_c + prefix->score); + continue; + } + // repeated character + if (c == prefix->character) { + prefix->log_prob_nb_cur = + log_sum_exp(prefix->log_prob_nb_cur, + log_prob_c + prefix->log_prob_nb_prev); + } + // get new prefix + auto prefix_new = prefix->get_path_trie(c); + + if (prefix_new != nullptr) { + float log_p = -NUM_FLT_INF; + + if (c == prefix->character && + prefix->log_prob_b_prev > -NUM_FLT_INF) { + log_p = log_prob_c + prefix->log_prob_b_prev; + } else if (c != prefix->character) { + log_p = log_prob_c + prefix->score; + } + + // language model scoring + if (ext_scorer != nullptr && + (c == space_id || ext_scorer->is_character_based())) { + PathTrie *prefix_to_score = nullptr; + // skip scoring the space + if (ext_scorer->is_character_based()) { + prefix_to_score = prefix_new; + } else { + prefix_to_score = prefix; + } + + float score = 0.0; + std::vector ngram; + ngram = ext_scorer->make_ngram(prefix_to_score); + score = ext_scorer->get_log_cond_prob(ngram) * + ext_scorer->alpha; + log_p += score; + log_p += ext_scorer->beta; + } + prefix_new->log_prob_nb_cur = + log_sum_exp(prefix_new->log_prob_nb_cur, log_p); + } + } // end of loop over prefix + } // end of loop over vocabulary + + prefixes.clear(); + // update log probs + + root->iterate_to_vec(prefixes); + + // only preserve top beam_size prefixes + if (prefixes.size() >= beam_size) { + std::nth_element(prefixes.begin(), + prefixes.begin() + beam_size, + prefixes.end(), + prefix_compare); + for (size_t i = beam_size; i < prefixes.size(); ++i) { + prefixes[i]->remove(); + } + } + } // end of loop over time + + return; +} + + +std::vector> get_decode_result( + std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size, + Scorer *ext_scorer) { + auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); + int space_id = it - vocabulary.begin(); + // if no space in vocabulary + if ((size_t)space_id >= vocabulary.size()) { + space_id = -2; + } + // score the last word of each prefix that doesn't end with space + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = + ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + prefix->score += score; + } + } + } + + size_t num_prefixes = std::min(prefixes.size(), beam_size); + std::sort( + prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + + // compute aproximate ctc score as the return score, without affecting the + // return order of decoding result. To delete when decoder gets stable. + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + double approx_ctc = prefixes[i]->score; + if (ext_scorer != nullptr) { + std::vector output; + prefixes[i]->get_path_vec(output); + auto prefix_length = output.size(); + auto words = ext_scorer->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; + // remove language model weight: + approx_ctc -= + (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; + } + prefixes[i]->approx_ctc = approx_ctc; + } + + std::vector> res = + get_beam_search_result(prefixes, vocabulary, beam_size); + + // pay back the last word of each prefix that doesn't end with space (for + // decoding by chunk) + if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + auto prefix = prefixes[i]; + if (!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = + ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + prefix->score -= score; + } + } + } + return res; +} + + +void free_storage(std::unique_ptr &storage) { + storage = nullptr; +} + + +CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {} + +CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch( + const std::vector &vocabulary, + size_t batch_size, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id) + : batch_size(batch_size), + beam_size(beam_size), + num_processes(num_processes), + cutoff_prob(cutoff_prob), + cutoff_top_n(cutoff_top_n), + ext_scorer(ext_scorer), + blank_id(blank_id) { + VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); + VALID_CHECK_GT( + this->num_processes, 0, "num_processes must be nonnegative!"); + this->vocabulary = vocabulary; + for (size_t i = 0; i < batch_size; i++) { + this->decoder_storage_vector.push_back( + std::unique_ptr( + new CtcBeamSearchDecoderStorage())); + ctc_beam_search_decode_chunk_begin( + this->decoder_storage_vector[i]->root, ext_scorer); + } +}; + +/** + * Input + * probs_split: shape [B, T, D] + */ +void CtcBeamSearchDecoderBatch::next( + const std::vector>> &probs_split, + const std::vector &has_value) { + VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + size_t num_has_value = 0; + for (int i = 0; i < has_value.size(); i++) + if (has_value[i] == "true") num_has_value += 1; + ThreadPool pool(std::min(num_processes, num_has_value)); + // number of samples + size_t probs_num = probs_split.size(); + VALID_CHECK_EQ(this->batch_size, + probs_num, + "The batch size of the current input data should be same " + "with the input data before"); + + // enqueue the tasks of decoding + std::vector> res; + for (size_t i = 0; i < batch_size; ++i) { + if (has_value[i] == "true") { + res.emplace_back(pool.enqueue( + ctc_beam_search_decode_chunk, + std::ref(this->decoder_storage_vector[i]->root), + std::ref(this->decoder_storage_vector[i]->prefixes), + probs_split[i], + this->vocabulary, + this->beam_size, + this->cutoff_prob, + this->cutoff_top_n, + this->ext_scorer, + this->blank_id)); + } + } + + for (size_t i = 0; i < batch_size; ++i) { + res[i].get(); + } + return; +}; + +/** + * Return + * batch_result: shape[B, beam_size,(-approx_ctc score, string)] + */ +std::vector>> +CtcBeamSearchDecoderBatch::decode() { + VALID_CHECK_GT( + this->num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + ThreadPool pool(this->num_processes); + // number of samples + // enqueue the tasks of decoding + std::vector>>> res; + for (size_t i = 0; i < this->batch_size; ++i) { + res.emplace_back( + pool.enqueue(get_decode_result, + std::ref(this->decoder_storage_vector[i]->prefixes), + this->vocabulary, + this->beam_size, + this->ext_scorer)); + } + // get decoding results + std::vector>> batch_results; + for (size_t i = 0; i < this->batch_size; ++i) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; +} + + +/** + * reset the state of ctcBeamSearchDecoderBatch + */ +void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n) { + this->batch_size = batch_size; + this->beam_size = beam_size; + this->num_processes = num_processes; + this->cutoff_prob = cutoff_prob; + this->cutoff_top_n = cutoff_top_n; + + VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!"); + VALID_CHECK_GT( + this->num_processes, 0, "num_processes must be nonnegative!"); + // thread pool + ThreadPool pool(this->num_processes); + // number of samples + // enqueue the tasks of decoding + std::vector> res; + size_t storage_size = decoder_storage_vector.size(); + for (size_t i = 0; i < storage_size; i++) { + res.emplace_back(pool.enqueue( + free_storage, std::ref(this->decoder_storage_vector[i]))); + } + for (size_t i = 0; i < storage_size; ++i) { + res[i].get(); + } + std::vector>().swap( + decoder_storage_vector); + for (size_t i = 0; i < this->batch_size; i++) { + this->decoder_storage_vector.push_back( + std::unique_ptr( + new CtcBeamSearchDecoderStorage())); + ctc_beam_search_decode_chunk_begin( + this->decoder_storage_vector[i]->root, this->ext_scorer); + } +} \ No newline at end of file diff --git a/third_party/ctc_decoders/ctc_beam_search_decoder.h b/third_party/ctc_decoders/ctc_beam_search_decoder.h index 58422657..92d2b855 100644 --- a/third_party/ctc_decoders/ctc_beam_search_decoder.h +++ b/third_party/ctc_decoders/ctc_beam_search_decoder.h @@ -37,7 +37,7 @@ * A vector that each element is a pair of score and decoding result, * in desending order. */ -std::vector> ctc_beam_search_decoder( +std::vector> ctc_beam_search_decoding( const std::vector> &probs_seq, const std::vector &vocabulary, size_t beam_size, @@ -46,6 +46,7 @@ std::vector> ctc_beam_search_decoder( Scorer *ext_scorer = nullptr, size_t blank_id = 0); + /* CTC Beam Search Decoder for batch data * Parameters: @@ -64,7 +65,7 @@ std::vector> ctc_beam_search_decoder( * result for one audio sample. */ std::vector>> -ctc_beam_search_decoder_batch( +ctc_beam_search_decoding_batch( const std::vector>> &probs_split, const std::vector &vocabulary, size_t beam_size, @@ -74,4 +75,101 @@ ctc_beam_search_decoder_batch( Scorer *ext_scorer = nullptr, size_t blank_id = 0); +/** + * Store the root and prefixes for decoder + */ + +class CtcBeamSearchDecoderStorage { + public: + PathTrie *root = nullptr; + std::vector prefixes; + + CtcBeamSearchDecoderStorage() { + // init prefixes' root + this->root = new PathTrie(); + this->root->log_prob_b_prev = 0.0; + // The score of root is in log scale.Since the prob=1.0, the prob score + // in log scale is 0.0 + this->root->score = root->log_prob_b_prev; + // std::vector prefixes; + this->prefixes.push_back(root); + }; + + ~CtcBeamSearchDecoderStorage() { + if (root != nullptr) { + delete root; + root = nullptr; + } + }; +}; + +/** + * The ctc beam search decoder, support batchsize >= 1 + */ +class CtcBeamSearchDecoderBatch { + public: + CtcBeamSearchDecoderBatch(const std::vector &vocabulary, + size_t batch_size, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id); + + ~CtcBeamSearchDecoderBatch(); + void next(const std::vector>> &probs_split, + const std::vector &has_value); + + std::vector>> decode(); + + void reset_state(size_t batch_size, + size_t beam_size, + size_t num_processes, + double cutoff_prob, + size_t cutoff_top_n); + + private: + std::vector vocabulary; + size_t batch_size; + size_t beam_size; + size_t num_processes; + double cutoff_prob; + size_t cutoff_top_n; + Scorer *ext_scorer; + size_t blank_id; + std::vector> + decoder_storage_vector; +}; + +/** + * function for chunk decoding + */ +void ctc_beam_search_decode_chunk( + PathTrie *root, + std::vector &prefixes, + const std::vector> &probs_seq, + const std::vector &vocabulary, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer, + size_t blank_id); + +std::vector> get_decode_result( + std::vector &prefixes, + const std::vector &vocabulary, + size_t beam_size, + Scorer *ext_scorer); + +/** + * free the CtcBeamSearchDecoderStorage + */ +void free_storage(std::unique_ptr &storage); + +/** + * initialize the root + */ +void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer); + #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/third_party/ctc_decoders/ctc_greedy_decoder.cpp b/third_party/ctc_decoders/ctc_greedy_decoder.cpp index a178c673..6aa3c996 100644 --- a/third_party/ctc_decoders/ctc_greedy_decoder.cpp +++ b/third_party/ctc_decoders/ctc_greedy_decoder.cpp @@ -15,7 +15,7 @@ #include "ctc_greedy_decoder.h" #include "decoder_utils.h" -std::string ctc_greedy_decoder( +std::string ctc_greedy_decoding( const std::vector> &probs_seq, const std::vector &vocabulary, size_t blank_id) { diff --git a/third_party/ctc_decoders/ctc_greedy_decoder.h b/third_party/ctc_decoders/ctc_greedy_decoder.h index 4d60beaf..4451600d 100644 --- a/third_party/ctc_decoders/ctc_greedy_decoder.h +++ b/third_party/ctc_decoders/ctc_greedy_decoder.h @@ -27,7 +27,7 @@ * Return: * The decoding result in string */ -std::string ctc_greedy_decoder( +std::string ctc_greedy_decoding( const std::vector>& probs_seq, const std::vector& vocabulary, size_t blank_id); diff --git a/third_party/ctc_decoders/decoders.i b/third_party/ctc_decoders/decoders.i index 4227d4a3..8fe3b279 100644 --- a/third_party/ctc_decoders/decoders.i +++ b/third_party/ctc_decoders/decoders.i @@ -1,4 +1,4 @@ -%module swig_decoders +%module paddlespeech_ctcdecoders %{ #include "scorer.h" #include "ctc_greedy_decoder.h" diff --git a/third_party/ctc_decoders/path_trie.cpp b/third_party/ctc_decoders/path_trie.cpp index a5e7dd3d..777ca052 100644 --- a/third_party/ctc_decoders/path_trie.cpp +++ b/third_party/ctc_decoders/path_trie.cpp @@ -44,6 +44,7 @@ PathTrie::PathTrie() { PathTrie::~PathTrie() { for (auto child : children_) { delete child.second; + child.second = nullptr; } } @@ -131,26 +132,26 @@ void PathTrie::iterate_to_vec(std::vector& output) { void PathTrie::remove() { exists_ = false; - if (children_.size() == 0) { - auto child = parent->children_.begin(); - for (child = parent->children_.begin(); - child != parent->children_.end(); - ++child) { - if (child->first == character) { - parent->children_.erase(child); - break; + if (parent != nullptr) { + auto child = parent->children_.begin(); + for (child = parent->children_.begin(); + child != parent->children_.end(); + ++child) { + if (child->first == character) { + parent->children_.erase(child); + break; + } + } + if (parent->children_.size() == 0 && !parent->exists_) { + parent->remove(); } } - - if (parent->children_.size() == 0 && !parent->exists_) { - parent->remove(); - } - delete this; } } + void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { dictionary_ = dictionary; dictionary_state_ = dictionary->Start(); diff --git a/third_party/ctc_decoders/scorer.cpp b/third_party/ctc_decoders/scorer.cpp index 977112d1..6c1d96be 100644 --- a/third_party/ctc_decoders/scorer.cpp +++ b/third_party/ctc_decoders/scorer.cpp @@ -1,4 +1,5 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the "COPYING.LESSER.3"); +// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the +// "COPYING.LESSER.3"); #include "scorer.h" diff --git a/third_party/ctc_decoders/scorer.h b/third_party/ctc_decoders/scorer.h index 5739339d..08e109b7 100644 --- a/third_party/ctc_decoders/scorer.h +++ b/third_party/ctc_decoders/scorer.h @@ -1,4 +1,5 @@ -// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the "COPYING.LESSER.3"); +// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the +// "COPYING.LESSER.3"); #ifndef SCORER_H_ #define SCORER_H_ diff --git a/third_party/ctc_decoders/setup.py b/third_party/ctc_decoders/setup.py index 6484b87c..4a11b890 100644 --- a/third_party/ctc_decoders/setup.py +++ b/third_party/ctc_decoders/setup.py @@ -112,7 +112,7 @@ os.system('swig -python -c++ ./decoders.i') decoders_module = [ Extension( - name='_swig_decoders', + name='_paddlespeech_ctcdecoders', sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), language='c++', include_dirs=[ @@ -134,4 +134,4 @@ setup( url="https://github.com/PaddlePaddle/PaddleSpeech", license='Apache 2.0, GNU Lesser General Public License v3 (LGPLv3) (LGPL-3)', ext_modules=decoders_module, - py_modules=['swig_decoders']) + py_modules=['paddlespeech_ctcdecoders']) -- GitLab