diff --git a/decoder.py b/decoder.py index 458cd9ad3e46c8bb0196e6c09063b69d40d2ba8e..d5bd72f6faeadf9620b1082d130c7c439989e9db 100755 --- a/decoder.py +++ b/decoder.py @@ -92,7 +92,7 @@ def ctc_beam_search_decoder(probs_seq, Search(https://arxiv.org/abs/1408.2873), and the unclear part is redesigned, need to be verified. - :param probs_seq: 2-D list with length max_time_steps, each element + :param probs_seq: 2-D list with length num_time_steps, each element is a list of normalized probabilities over vocabulary and blank for one time step. :type probs_seq: 2-D list @@ -114,7 +114,7 @@ def ctc_beam_search_decoder(probs_seq, for prob_list in probs_seq: if not len(prob_list) == len(vocabulary) + 1: raise ValueError("probs dimension mismatchedd with vocabulary") - max_time_steps = len(probs_seq) + num_time_steps = len(probs_seq) # blank_id check probs_dim = len(probs_seq[0]) @@ -139,10 +139,10 @@ def ctc_beam_search_decoder(probs_seq, ## initialize # the set containing selected prefixes prefix_set_prev = {'-1': 1.0} - probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0} + probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0} ## extend prefix in loop - for time_step in range(max_time_steps): + for time_step in range(num_time_steps): # the set containing candidate prefixes prefix_set_next = {} probs_b_cur, probs_nb_cur = {}, {} @@ -158,33 +158,34 @@ def ctc_beam_search_decoder(probs_seq, # extend prefix by travering vocabulary for c in range(0, probs_dim): if c == blank_id: - probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l]) + probs_b_cur[l] += prob[c] * ( + probs_b_prev[l] + probs_nb_prev[l]) else: l_plus = l + ' ' + str(c) if not prefix_set_next.has_key(l_plus): probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 if c == end_id: - probs_nb_cur[l_plus] += prob[c] * probs_b[l] - probs_nb_cur[l] += prob[c] * probs_nb[l] + probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l] + probs_nb_cur[l] += prob[c] * probs_nb_prev[l] elif c == space_id: if ext_scoring_func is None: score = 1.0 else: - prefix_sent = ids2sentence(ids_list, vocabulary) - score = ext_scoring_func(prefix_sent) + prefix = ids2sentence(ids_list, vocabulary) + score = ext_scoring_func(prefix) probs_nb_cur[l_plus] += score * prob[c] * ( - probs_b[l] + probs_nb[l]) + probs_b_prev[l] + probs_nb_prev[l]) else: probs_nb_cur[l_plus] += prob[c] * ( - probs_b[l] + probs_nb[l]) + probs_b_prev[l] + probs_nb_prev[l]) # add l_plus into prefix_set_next prefix_set_next[l_plus] = probs_nb_cur[ l_plus] + probs_b_cur[l_plus] # add l into prefix_set_next prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] # update probs - probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy( + probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy( probs_nb_cur) ## store top beam_size prefixes