diff --git a/decoder.py b/decoder.py index 05400d1b4113e78bd79de5765e4640727086fa85..0eab3651968d16e06225c94ecf67592e16240e5b 100755 --- a/decoder.py +++ b/decoder.py @@ -121,25 +121,10 @@ def ctc_beam_search_decoder(probs_seq, if not blank_id < probs_dim: raise ValueError("blank_id shouldn't be greater than probs dimension") - # assign space_id - if ' ' not in vocabulary: - raise ValueError("space doesn't exist in vocabulary") - space_id = vocabulary.index(' ') - - # function to convert ids in string to list - def ids_str2list(ids_str): - ids_str = ids_str.split(' ') - ids_list = [int(elem) for elem in ids_str] - return ids_list - - # function to convert ids list to sentence - def ids2sentence(ids_list, vocab): - return ''.join([vocab[ids] for ids in ids_list]) - ## initialize # the set containing selected prefixes - prefix_set_prev = {'-1': 1.0} - probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0} + prefix_set_prev = {'\t': 1.0} + probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} ## extend prefix in loop for time_step in range(num_time_steps): @@ -148,10 +133,6 @@ def ctc_beam_search_decoder(probs_seq, probs_b_cur, probs_nb_cur = {}, {} for l in prefix_set_prev: prob = probs_seq[time_step] - - # convert ids in string to list - ids_list = ids_str2list(l) - end_id = ids_list[-1] if not prefix_set_next.has_key(l): probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0 @@ -161,18 +142,20 @@ def ctc_beam_search_decoder(probs_seq, probs_b_cur[l] += prob[c] * ( probs_b_prev[l] + probs_nb_prev[l]) else: - l_plus = l + ' ' + str(c) + last_char = l[-1] + new_char = vocabulary[c] + l_plus = l + new_char if not prefix_set_next.has_key(l_plus): probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 - if c == end_id: + if new_char == last_char: probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l] probs_nb_cur[l] += prob[c] * probs_nb_prev[l] - elif c == space_id: - if ext_scoring_func is None: + elif new_char == ' ': + if (ext_scoring_func is None) or (len(l) == 1): score = 1.0 else: - prefix = ids2sentence(ids_list, vocabulary) + prefix = l[1:] score = ext_scoring_func(prefix) probs_nb_cur[l_plus] += score * prob[c] * ( probs_b_prev[l] + probs_nb_prev[l]) @@ -185,8 +168,7 @@ def ctc_beam_search_decoder(probs_seq, # add l into prefix_set_next prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] # update probs - probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy( - probs_nb_cur) + probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur ## store top beam_size prefixes prefix_set_prev = sorted( @@ -198,8 +180,7 @@ def ctc_beam_search_decoder(probs_seq, beam_result = [] for (seq, prob) in prefix_set_prev.items(): if prob > 0.0: - ids_list = ids_str2list(seq)[1:] - result = ids2sentence(ids_list, vocabulary) + result = seq[1:] log_prob = np.log(prob) beam_result.append([log_prob, result])