decoder.py 7.9 KB
Newer Older
1 2 3 4 5 6
"""
    CTC-like decoder utilities.
"""

from itertools import groupby
import numpy as np
Y
Yibing Liu 已提交
7 8
import copy
import kenlm
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40


def ctc_best_path_decode(probs_seq, vocabulary):
    """
    Best path decoding, also called argmax decoding or greedy decoding.
    The path consisting of the most probable token at each time step is
    post-processed to remove consecutive repetitions and all blanks.

    The blank token is expected at index ``len(vocabulary)``, i.e. the last
    column of every time step.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character (``len(vocabulary) + 1`` entries,
                      blank last).
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string; empty when ``probs_seq`` is empty.
    :rtype: str
    :raises ValueError: If a time step does not hold exactly
                        ``len(vocabulary) + 1`` probabilities.
    """
    blank_index = len(vocabulary)
    # dimension verification: one probability per vocabulary entry plus blank
    for probs in probs_seq:
        if len(probs) != blank_index + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # An empty sequence becomes a 1-D array on which argmax(axis=1) fails,
    # so short-circuit it. len() also works when probs_seq is an ndarray.
    if len(probs_seq) == 0:
        return ''
    # argmax to get the best index for each time step
    max_index_list = np.asarray(probs_seq).argmax(axis=1)
    # remove consecutive duplicate indexes
    index_list = [index for index, _ in groupby(max_index_list)]
    # drop blanks and convert the remaining indexes to characters
    return ''.join(
        [vocabulary[index] for index in index_list if index != blank_index])


class Scorer(object):
    """
    External defined scorer to evaluate a sentence in beam search
    decoding, consisting of language model and word count.

    ``evaluate`` returns ``lm ** alpha * word_count ** beta``, where ``lm``
    is the value of ``language_model_score`` for the sentence.

    :param alpha: Parameter associated with language model.
    :type alpha: float
    :param beta: Parameter associated with word count.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: str
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        self._language_model = kenlm.LanguageModel(model_path)

    # language model scoring
    def language_model_score(self, sentence, bos=True, eos=False):
        """
        Score the last word of ``sentence`` given the words before it.

        Words are obtained by stripping the sentence and splitting on single
        spaces. For a one-word sentence the language model score of the whole
        sentence is used. Otherwise it is the score of the sentence minus the
        score of the sentence without its last word. The result is converted
        with ``10 ** log_prob``, so the model score is treated as a log10
        value.

        :param sentence: Sentence to score.
        :type sentence: str
        :param bos: Forwarded to the language model's ``score``.
        :type bos: bool
        :param eos: Forwarded to the language model's ``score``.
        :type eos: bool
        :return: ``10 ** log_prob`` as computed above.
        :rtype: float
        """
        words = sentence.strip().split(' ')
        length = len(words)
        if length == 1:
            log_prob = self._language_model.score(sentence, bos, eos)
        else:
            prefix_sent = ' '.join(words[0:length - 1])
            log_prob = (self._language_model.score(sentence, bos, eos) -
                        self._language_model.score(prefix_sent, bos, eos))
        return np.power(10, log_prob)

    # word insertion term
    def word_count(self, sentence):
        """
        Count the words in ``sentence``.

        The sentence is stripped and split on single spaces, so an empty
        sentence counts as one word.

        :param sentence: Sentence to count.
        :type sentence: str
        :return: Number of space-separated tokens.
        :rtype: int
        """
        words = sentence.strip().split(' ')
        return len(words)

    # execute evaluation
    def evaluate(self, sentence, bos=True, eos=False):
        """
        Combine the language model score and word count of ``sentence``.

        :param sentence: Sentence to evaluate.
        :type sentence: str
        :param bos: Forwarded to ``language_model_score``.
        :type bos: bool
        :param eos: Forwarded to ``language_model_score``.
        :type eos: bool
        :return: ``lm ** alpha * word_count ** beta``.
        :rtype: float
        """
        lm = self.language_model_score(sentence, bos, eos)
        word_cnt = self.word_count(sentence)
        score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
        return score


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            ext_scoring_func=None,
                            blank_id=0):
    '''
    Beam search decoder for CTC-trained network, using beam search with width
    beam_size to find many paths to one label, return beam_size labels in
    the order of probabilities. The implementation is based on Prefix Beam
    Search (https://arxiv.org/abs/1408.2873), and the unclear part is
    redesigned, need to be verified.

    Every time step holds ``len(vocabulary) + 1`` probabilities ("classes").
    Class ``blank_id`` is the CTC blank. The remaining classes map, in order,
    onto ``vocabulary``: class ``c`` is ``vocabulary[c]`` when
    ``c < blank_id`` and ``vocabulary[c - 1]`` otherwise. With
    ``blank_id == len(vocabulary)`` this is simply ``vocabulary[c]``.

    :param probs_seq: 2-D list with length num_time_steps, each element
                      is a list of normalized probabilities over vocabulary
                      and blank for one time step.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list; must contain the space character.
    :type vocabulary: list
    :param ext_scoring_func: External defined scoring function for
                             partially decoded sentence, e.g. word count
                             and language model. It is called with the
                             decoded prefix string whenever a space is
                             appended to a non-empty prefix, and its return
                             value multiplies that extension's probability.
    :type ext_scoring_func: function
    :param blank_id: id of blank, default 0.
    :type blank_id: int
    :return: ``[[log_prob, sentence], ...]`` sorted by descending log
             probability; prefixes with zero probability are dropped.
    :rtype: list
    :raises ValueError: If a time step has the wrong number of
                        probabilities, ``blank_id`` is out of range, or
                        ``vocabulary`` has no space character.
    '''
    num_classes = len(vocabulary) + 1

    # dimension check
    for prob_list in probs_seq:
        if len(prob_list) != num_classes:
            raise ValueError("probs dimension mismatched with vocabulary")

    # blank_id check. The bound comes from the vocabulary (it equals every
    # row length after the check above) so an empty probs_seq is accepted.
    if blank_id < 0:
        raise ValueError("blank_id shouldn't be negative")
    if not blank_id < num_classes:
        raise ValueError("blank_id shouldn't be greater than probs dimension")

    if ' ' not in vocabulary:
        raise ValueError("space doesn't exist in vocabulary")

    def class_to_vocab(class_id):
        # Skip over the blank slot when mapping a class id to the vocabulary.
        return class_id if class_id < blank_id else class_id - 1

    def vocab_to_class(vocab_index):
        return vocab_index if vocab_index < blank_id else vocab_index + 1

    def ids2sentence(class_ids):
        return ''.join([vocabulary[class_to_vocab(c)] for c in class_ids])

    space_id = vocab_to_class(vocabulary.index(' '))

    ## initialize
    # Prefixes are tuples of non-blank class ids; () is the empty prefix.
    # probs_b_* / probs_nb_* hold a prefix's probability over paths that end
    # in a blank / in a non-blank class.
    empty_prefix = ()
    prefix_set_prev = {empty_prefix: 1.0}
    probs_b_prev, probs_nb_prev = {empty_prefix: 1.0}, {empty_prefix: 0.0}

    ## extend prefix in loop
    for probs in probs_seq:
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}
        for prefix in prefix_set_prev:
            probs_b_cur.setdefault(prefix, 0.0)
            probs_nb_cur.setdefault(prefix, 0.0)
            end_id = prefix[-1] if prefix else None
            prob_prev = probs_b_prev[prefix] + probs_nb_prev[prefix]

            # extend prefix by traversing every class
            for c in range(num_classes):
                prob_c = probs[c]
                if c == blank_id:
                    probs_b_cur[prefix] += prob_c * prob_prev
                    continue

                prefix_plus = prefix + (c,)
                probs_b_cur.setdefault(prefix_plus, 0.0)
                probs_nb_cur.setdefault(prefix_plus, 0.0)

                if c == end_id:
                    # Repeating the last class extends the prefix only from
                    # blank-ending paths; otherwise it collapses into prefix.
                    probs_nb_cur[prefix_plus] += prob_c * probs_b_prev[prefix]
                    probs_nb_cur[prefix] += prob_c * probs_nb_prev[prefix]
                elif c == space_id:
                    # Appending a space weighs the current prefix with the
                    # external scorer; the empty prefix has nothing to score.
                    if ext_scoring_func is None or not prefix:
                        score = 1.0
                    else:
                        score = ext_scoring_func(ids2sentence(prefix))
                    probs_nb_cur[prefix_plus] += score * prob_c * prob_prev
                else:
                    probs_nb_cur[prefix_plus] += prob_c * prob_prev
                # add prefix_plus into prefix_set_next
                prefix_set_next[prefix_plus] = (
                    probs_nb_cur[prefix_plus] + probs_b_cur[prefix_plus])
            # add prefix into prefix_set_next
            prefix_set_next[prefix] = probs_b_cur[prefix] + probs_nb_cur[prefix]

        # The *_cur dicts are rebuilt every step, so no copy is needed.
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        ranked = sorted(
            prefix_set_next.items(), key=lambda item: item[1], reverse=True)
        prefix_set_prev = dict(ranked[:beam_size])

    beam_result = []
    for prefix, prob in prefix_set_prev.items():
        if prob > 0.0:
            beam_result.append([np.log(prob), ids2sentence(prefix)])

    ## output top beam_size decoding results
    return sorted(beam_result, key=lambda item: item[0], reverse=True)