From 44efbed798966f1d57276e5fde3d8541e8fddc48 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 7 Jun 2017 16:59:11 +0800 Subject: [PATCH] rename variables in decoder --- decoder.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/decoder.py b/decoder.py index 458cd9ad..d5bd72f6 100755 --- a/decoder.py +++ b/decoder.py @@ -92,7 +92,7 @@ def ctc_beam_search_decoder(probs_seq, Search(https://arxiv.org/abs/1408.2873), and the unclear part is redesigned, need to be verified. - :param probs_seq: 2-D list with length max_time_steps, each element + :param probs_seq: 2-D list with length num_time_steps, each element is a list of normalized probabilities over vocabulary and blank for one time step. :type probs_seq: 2-D list @@ -114,7 +114,7 @@ def ctc_beam_search_decoder(probs_seq, for prob_list in probs_seq: if not len(prob_list) == len(vocabulary) + 1: raise ValueError("probs dimension mismatchedd with vocabulary") - max_time_steps = len(probs_seq) + num_time_steps = len(probs_seq) # blank_id check probs_dim = len(probs_seq[0]) @@ -139,10 +139,10 @@ def ctc_beam_search_decoder(probs_seq, ## initialize # the set containing selected prefixes prefix_set_prev = {'-1': 1.0} - probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0} + probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0} ## extend prefix in loop - for time_step in range(max_time_steps): + for time_step in range(num_time_steps): # the set containing candidate prefixes prefix_set_next = {} probs_b_cur, probs_nb_cur = {}, {} @@ -158,33 +158,34 @@ def ctc_beam_search_decoder(probs_seq, # extend prefix by travering vocabulary for c in range(0, probs_dim): if c == blank_id: - probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l]) + probs_b_cur[l] += prob[c] * ( + probs_b_prev[l] + probs_nb_prev[l]) else: l_plus = l + ' ' + str(c) if not prefix_set_next.has_key(l_plus): probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 if c == end_id: - probs_nb_cur[l_plus] += prob[c] * probs_b[l] - probs_nb_cur[l] += prob[c] * probs_nb[l] + probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l] + probs_nb_cur[l] += prob[c] * probs_nb_prev[l] elif c == space_id: if ext_scoring_func is None: score = 1.0 else: - prefix_sent = ids2sentence(ids_list, vocabulary) - score = ext_scoring_func(prefix_sent) + prefix = ids2sentence(ids_list, vocabulary) + score = ext_scoring_func(prefix) probs_nb_cur[l_plus] += score * prob[c] * ( - probs_b[l] + probs_nb[l]) + probs_b_prev[l] + probs_nb_prev[l]) else: probs_nb_cur[l_plus] += prob[c] * ( - probs_b[l] + probs_nb[l]) + probs_b_prev[l] + probs_nb_prev[l]) # add l_plus into prefix_set_next prefix_set_next[l_plus] = probs_nb_cur[ l_plus] + probs_b_cur[l_plus] # add l into prefix_set_next prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] # update probs - probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy( + probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy( probs_nb_cur) ## store top beam_size prefixes -- GitLab