"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import swig_decoders


def _to_nested_list(probs):
    """Convert probabilities to plain nested Python lists.

    The decoders are documented as taking lists, but callers commonly pass
    array objects. Anything exposing ``tolist()`` is converted through it;
    lists and tuples are converted element by element, so a list whose rows
    are arrays is handled too. Any other value is returned unchanged.

    :param probs: Array-like object, nested list/tuple, or scalar.
    :return: The same data expressed with plain Python lists.
    """
    if hasattr(probs, "tolist"):
        return probs.tolist()
    if isinstance(probs, (list, tuple)):
        return [_to_nested_list(item) for item in probs]
    return probs


class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        # Explicit base-class call (rather than zero-argument super()) keeps
        # the module usable on Python 2, which the __future__ imports target.
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path)


def ctc_greedy_decoder(probs_seq, vocabulary):
    """Wrapper for ctc best path decoder in swig.

    :param probs_seq: 2-D list (or array-like object providing ``tolist()``)
                      of probability distributions over each time step, with
                      each element being a list of normalized probabilities
                      over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    return swig_decoders.ctc_greedy_decoder(
        _to_nested_list(probs_seq), vocabulary)


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D list (or array-like object providing ``tolist()``)
                      of probability distributions over each time step, with
                      each element being a list of normalized probabilities
                      over vocabulary and blank.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    return swig_decoders.ctc_beam_search_decoder(
        _to_nested_list(probs_seq), beam_size, vocabulary, blank_id,
        cutoff_prob, cutoff_top_n, ext_scoring_func)


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D list with each element as an instance of 2-D
                        list of probabilities used by
                        ctc_beam_search_decoder(). Elements may also be
                        array-like objects providing ``tolist()``.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
             NOTE(review): presumably one such list per input sequence;
             confirm against the swig_decoders binding.
    :rtype: list
    """
    # Convert each utterance separately so that a list of arrays with
    # differing numbers of time steps is supported.
    probs_split = [_to_nested_list(probs_seq) for probs_seq in probs_split]
    return swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, beam_size, vocabulary, blank_id, num_processes,
        cutoff_prob, cutoff_top_n, ext_scoring_func)