"""
    CTC-like decoder utilitis.
"""

from itertools import groupby
import numpy as np
import copy
import kenlm


def ctc_best_path_decode(probs_seq, vocabulary):
    """
    Best path decoding, also called argmax decoding or greedy decoding.
    The path consisting of the most probable tokens is further post-processed
    to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: str
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatchedd with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])
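
# A minimal, hypothetical usage sketch for best path decoding: with
# vocabulary ['a', 'b'] (so the blank index is 2), the argmax path
# [0, 0, 2, 1] collapses to "ab" after merging repeats and removing blanks:
#
#     probs_seq = [[0.7, 0.1, 0.2],  # argmax 0 -> 'a'
#                  [0.6, 0.1, 0.3],  # argmax 0 -> repeat, merged
#                  [0.1, 0.1, 0.8],  # argmax 2 -> blank, removed
#                  [0.1, 0.8, 0.1]]  # argmax 1 -> 'b'
#     assert ctc_best_path_decode(probs_seq, ['a', 'b']) == 'ab'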


class Scorer(object):
    """
    Externally defined scorer to evaluate a sentence in beam search decoding,
    consisting of a language model and word count.

    :param alpha: Parameter associated with the language model.
    :type alpha: float
    :param beta: Parameter associated with the word count.
    :type beta: float
    :param model_path: Path to load the language model.
    :type model_path: str
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        self._language_model = kenlm.LanguageModel(model_path)

    def language_model_score(self, sentence, bos=True, eos=False):
        log_prob = self._language_model.score(sentence, bos, eos)
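        # kenlm's score() returns a log10 probability; convert it back to
        # a linear-scale probability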
        return np.power(10, log_prob)

    def word_count(self, sentence):
        words = sentence.strip().split(' ')
        return len(words)

    # execute evaluation
    def evaluate(self, sentence, bos=True, eos=False):
        lm = self.language_model_score(sentence, bos, eos)
        word_cnt = self.word_count(sentence)
        score = np.power(lm, self._alpha) \
                * np.power(word_cnt, self._beta)
        return score
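
# A hedged usage sketch for the Scorer: the alpha/beta values and the model
# path below are hypothetical (any KenLM arpa or binary file would work).
# Scorer.evaluate can be passed to ctc_beam_search_decoder as its
# ext_scoring_func:
#
#     scorer = Scorer(alpha=0.26, beta=0.1, model_path='lm/common.arpa')
#     score = scorer.evaluate('hello world')  # lm^alpha * word_count^beta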


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            ext_scoring_func=None,
                            blank_id=0):
    """
    Beam search decoder for a CTC-trained network. It uses beam search with
    width beam_size to find the most probable label sequences, and returns
    the top beam_size results ordered by probability. The implementation is
    based on Prefix Beam Search (https://arxiv.org/abs/1408.2873); the parts
    that were unclear there have been redesigned and still need verification.

    :param probs_seq: 2-D list of length max_time_steps; each element is a
                      list of normalized probabilities over the vocabulary
                      and blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: Externally defined scoring function for the
                             partially decoded sentence, e.g. word count
                             and language model.
    :type ext_scoring_func: function
    :param blank_id: id of blank, default 0.
    :type blank_id: int
    :return: List of [log probability, result string] pairs, best first.
    :rtype: list
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("probs dimension mismatchedd with vocabulary")
    max_time_steps = len(probs_seq)

    # blank_id check
    probs_dim = len(probs_seq[0])
    if not blank_id < probs_dim:
        raise ValueError("blank_id shouldn't be greater than probs dimension")

    # assign space_id
    if ' ' not in vocabulary:
        raise ValueError("space doesn't exist in vocabulary")
    space_id = vocabulary.index(' ')

    # convert a space-separated id string to a list of integer ids
    def ids_str2list(ids_str):
        return [int(elem) for elem in ids_str.split(' ')]

    # convert a list of ids to a sentence string using vocab
    def ids2sentence(ids_list, vocab):
        return ''.join([vocab[ids] for ids in ids_list])

    ## initialize
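    # Prefixes are encoded as space-separated id strings; '-1' is a sentinel
    # marking the empty prefix and is stripped before output. probs_b[l] and
    # probs_nb[l] hold the probability of prefix l ending in blank and in
    # non-blank respectively, following the prefix beam search formulation.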
    # the set containing selected prefixes
    prefix_set_prev = {'-1': 1.0}
    probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0}

    ## extend prefix in loop
    for time_step in range(max_time_steps):
        # the set containing candidate prefixes
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for l in prefix_set_prev:
            prob = probs_seq[time_step]

            # convert ids in string to list
            ids_list = ids_str2list(l)
            end_id = ids_list[-1]
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing the vocabulary
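            # Each symbol c falls into one of four cases: the blank extends
            # probs_b of the same prefix; a repeat of end_id either merges
            # into l (non-blank path) or extends l_plus via the blank path;
            # a space marks a word boundary and applies the external scorer
            # if given; any other symbol simply extends the prefix.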
            for c in range(0, probs_dim):
                if c == blank_id:
                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
                else:
                    l_plus = l + ' ' + str(c)
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if c == end_id:
                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
                        probs_nb_cur[l] += prob[c] * probs_nb[l]
                    elif c == space_id:
                        if ext_scoring_func is None:
                            score = 1.0
                        else:
                            prefix_sent = ids2sentence(ids_list, vocabulary)
                            score = ext_scoring_func(prefix_sent)
                        probs_nb_cur[l_plus] += score * prob[c] * (
                            probs_b[l] + probs_nb[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (
                            probs_b[l] + probs_nb[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
            probs_nb_cur)

        ## store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

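    ## collect surviving prefixes and their log probabilities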
    beam_result = []
    for (seq, prob) in prefix_set_prev.items():
        if prob > 0.0:
            ids_list = ids_str2list(seq)[1:]
            result = ids2sentence(ids_list, vocabulary)
            log_prob = np.log(prob)
            beam_result.append([log_prob, result])

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
    return beam_result
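
# A minimal, hypothetical end-to-end sketch. The vocabulary must contain a
# space, and blank_id is set to len(vocabulary) here so that non-blank ids
# line up with vocabulary indices:
#
#     probs_seq = [[0.6, 0.3, 0.1],
#                  [0.2, 0.5, 0.3]]
#     results = ctc_beam_search_decoder(
#         probs_seq, beam_size=4, vocabulary=['a', ' '], blank_id=2)
#     for log_prob, sentence in results:
#         print(log_prob, repr(sentence))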