# swig_wrapper.py
"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import swig_decoders


class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    """

    def __init__(self, alpha, beta, model_path, vocabulary):
        # Forward every argument unchanged to the SWIG-generated base class.
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)


def ctc_greedy_decoder(probs_seq, vocabulary):
    """Wrapper for ctc best path decoder in swig.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      exposing ``tolist()`` (e.g. a 2-D numpy array) is also
                      accepted and is converted to nested lists first.
    :type probs_seq: 2-D list or numpy.ndarray
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # Array-likes (e.g. numpy arrays) are converted to nested lists before
    # reaching the SWIG binding. Plain lists (the documented input type) have
    # no ``tolist`` and are passed through instead of raising AttributeError.
    if hasattr(probs_seq, "tolist"):
        probs_seq = probs_seq.tolist()
    return swig_decoders.ctc_greedy_decoder(probs_seq, vocabulary)


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      exposing ``tolist()`` (e.g. a 2-D numpy array) is also
                      accepted and is converted to nested lists first.
    :type probs_seq: 2-D list or numpy.ndarray
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model. Defaults to None.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # Array-likes (e.g. numpy arrays) are converted to nested lists before
    # reaching the SWIG binding. Plain lists (the documented input type) have
    # no ``tolist`` and are passed through instead of raising AttributeError.
    if hasattr(probs_seq, "tolist"):
        probs_seq = probs_seq.tolist()
    return swig_decoders.ctc_beam_search_decoder(probs_seq, beam_size,
                                                 vocabulary, cutoff_prob,
                                                 cutoff_top_n, ext_scoring_func)


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
                        Instances exposing ``tolist()`` (e.g. 2-D numpy
                        arrays) are also accepted and converted first.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model. Defaults to None.
    :type ext_scoring_func: callable
    :return: Decoding results for the batch, as returned by the SWIG batch
             decoder. Presumably one list of (log probability, sentence)
             tuples per instance, each in descending order of probability.
             TODO confirm against the binding.
    :rtype: list
    """
    # Convert each instance to nested lists for the SWIG binding. Instances
    # that are already lists have no ``tolist`` and are passed through
    # instead of raising AttributeError.
    probs_split = [
        probs_seq.tolist() if hasattr(probs_seq, "tolist") else probs_seq
        for probs_seq in probs_split
    ]
    return swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, beam_size, vocabulary, num_processes, cutoff_prob,
        cutoff_top_n, ext_scoring_func)