"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

import swig_decoders


class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        # Pure pass-through to the SWIG-generated base class constructor.
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path)


def _as_nested_list(probs_seq):
    """Return ``probs_seq`` in the form handed to the SWIG decoders.

    Objects exposing a callable ``tolist`` attribute (e.g. numpy arrays)
    are converted through it; anything else (e.g. a plain 2-D list, as
    the public docstrings advertise) is returned unchanged.
    """
    to_list = getattr(probs_seq, "tolist", None)
    if callable(to_list):
        return to_list()
    return probs_seq


def ctc_best_path_decoder(probs_seq, vocabulary):
    """Wrapper for ctc best path decoder in swig.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. Objects
                      providing ``tolist()`` are accepted as well.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    return swig_decoders.ctc_best_path_decoder(
        _as_nested_list(probs_seq), vocabulary)


def ctc_beam_search_decoder(
        probs_seq,
        beam_size,
        vocabulary,
        blank_id,
        cutoff_prob=1.0,
        ext_scoring_func=None, ):
    """Wrapper for CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. Objects
                      providing ``tolist()`` are accepted as well.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    return swig_decoders.ctc_beam_search_decoder(
        _as_nested_list(probs_seq), beam_size, vocabulary, blank_id,
        cutoff_prob, ext_scoring_func)


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None):
    """Wrapper for CTC beam search decoder in batch.

    Submits one ``ctc_beam_search_decoder`` call per element of
    ``probs_split`` to a ``multiprocessing.Pool`` and collects the results.

    :param probs_split: Iterable of per-utterance probability sequences;
                        each element is forwarded as ``probs_seq`` to
                        ``ctc_beam_search_decoder``.
    :type probs_split: iterable
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param num_processes: Number of worker processes; must be positive.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function forwarded to every
                             ``ctc_beam_search_decoder`` call.
    :type ext_scoring_func: callable
    :return: One ``ctc_beam_search_decoder`` result per element of
             ``probs_split``, in the same order as the input.
    :rtype: list
    :raises ValueError: If ``num_processes`` is not positive.
    """
    # TODO: to resolve PicklingError
    # NOTE(review): presumably triggered when the arguments (e.g. a
    # SWIG-backed ext_scoring_func) cannot be pickled for the worker
    # processes -- confirm before relying on this path.
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # `Pool` must be reached through the imported `multiprocessing` module;
    # the bare name is not defined in this file.
    pool = multiprocessing.Pool(processes=num_processes)
    try:
        results = [
            pool.apply_async(ctc_beam_search_decoder,
                             (probs_seq, beam_size, vocabulary, blank_id,
                              cutoff_prob, ext_scoring_func))
            for probs_seq in probs_split
        ]
    finally:
        # Always release the workers, even if submitting a task failed.
        pool.close()
        pool.join()

    # Results are fetched in submission order, so output order matches input.
    return [result.get() for result in results]