# swig_decoders.py
"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import swig_ctc_decoders
import multiprocessing


def ctc_best_path_decoder(probs_seq, vocabulary):
    """Wrapper for ctc best path decoder in swig.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      exposing ``tolist()`` (e.g. a numpy array) is also
                      accepted and is converted to a list before the call.
    :type probs_seq: 2-D list or object providing ``tolist()``
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # Only convert when a conversion method exists: a plain 2-D list (the
    # documented input type) has no ``tolist`` and is passed through as-is.
    if hasattr(probs_seq, "tolist"):
        probs_seq = probs_seq.tolist()
    return swig_ctc_decoders.ctc_best_path_decoder(probs_seq, vocabulary)


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id,
                            cutoff_prob=1.0,
                            ext_scoring_func=None):
    """Wrapper for CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      exposing ``tolist()`` (e.g. a numpy array) is also
                      accepted and is converted to a list before the call.
    :type probs_seq: 2-D list or object providing ``tolist()``
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # Only convert when a conversion method exists: a plain 2-D list (the
    # documented input type) has no ``tolist`` and is passed through as-is.
    if hasattr(probs_seq, "tolist"):
        probs_seq = probs_seq.tolist()
    return swig_ctc_decoders.ctc_beam_search_decoder(
        probs_seq, beam_size, vocabulary, blank_id, cutoff_prob,
        ext_scoring_func)


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None):
    """Wrapper for CTC beam search decoder in batch.

    Each element of ``probs_split`` is decoded by one call to
    ``ctc_beam_search_decoder``, dispatched to a ``multiprocessing.Pool``.

    :param probs_split: Iterable of per-utterance inputs; each element is
                        passed as ``probs_seq`` to
                        ``ctc_beam_search_decoder``.
    :type probs_split: iterable
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param num_processes: Number of worker processes; must be positive.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability passed through to
                        ``ctc_beam_search_decoder``, default 1.0.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function passed through to
                             ``ctc_beam_search_decoder``.
    :type ext_scoring_func: callable
    :return: One decoding result per element of ``probs_split``, in the
             same order as the input.
    :rtype: list
    :raises ValueError: If ``num_processes`` is not positive.
    """

    # TODO: to resolve PicklingError
    # NOTE(review): presumably triggered when an argument such as
    # ext_scoring_func cannot be pickled for the worker processes -- confirm.

    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # Build the per-task argument tuples up front so an empty batch can be
    # answered without starting any worker process.
    task_args = [(probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
                  ext_scoring_func) for probs_list in probs_split]
    if not task_args:
        return []

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        pending = [
            pool.apply_async(ctc_beam_search_decoder, args)
            for args in task_args
        ]
        pool.close()
        pool.join()
        # get() re-raises any exception raised inside a worker.
        return [result.get() for result in pending]
    finally:
        # Guarantees the workers are released even when submission fails
        # part-way; harmless after a normal close()/join().
        pool.terminate()