提交 44efbed7 编写于 作者: Y Yibing Liu

rename variables in decoder

上级 21ff590e
...@@ -92,7 +92,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -92,7 +92,7 @@ def ctc_beam_search_decoder(probs_seq,
Search(https://arxiv.org/abs/1408.2873), and the unclear part is Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned and needs to be verified. redesigned and needs to be verified.
:param probs_seq: 2-D list with length max_time_steps, each element :param probs_seq: 2-D list with length num_time_steps, each element
is a list of normalized probabilities over vocabulary is a list of normalized probabilities over vocabulary
and blank for one time step. and blank for one time step.
:type probs_seq: 2-D list :type probs_seq: 2-D list
...@@ -114,7 +114,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -114,7 +114,7 @@ def ctc_beam_search_decoder(probs_seq,
for prob_list in probs_seq: for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1: if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary") raise ValueError("probs dimension mismatchedd with vocabulary")
max_time_steps = len(probs_seq) num_time_steps = len(probs_seq)
# blank_id check # blank_id check
probs_dim = len(probs_seq[0]) probs_dim = len(probs_seq[0])
...@@ -139,10 +139,10 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -139,10 +139,10 @@ def ctc_beam_search_decoder(probs_seq,
## initialize ## initialize
# the set containing selected prefixes # the set containing selected prefixes
prefix_set_prev = {'-1': 1.0} prefix_set_prev = {'-1': 1.0}
probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0} probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0}
## extend prefix in loop ## extend prefix in loop
for time_step in range(max_time_steps): for time_step in range(num_time_steps):
# the set containing candidate prefixes # the set containing candidate prefixes
prefix_set_next = {} prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {} probs_b_cur, probs_nb_cur = {}, {}
...@@ -158,33 +158,34 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -158,33 +158,34 @@ def ctc_beam_search_decoder(probs_seq,
# extend prefix by traversing vocabulary # extend prefix by traversing vocabulary
for c in range(0, probs_dim): for c in range(0, probs_dim):
if c == blank_id: if c == blank_id:
probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l]) probs_b_cur[l] += prob[c] * (
probs_b_prev[l] + probs_nb_prev[l])
else: else:
l_plus = l + ' ' + str(c) l_plus = l + ' ' + str(c)
if not prefix_set_next.has_key(l_plus): if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0 probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if c == end_id: if c == end_id:
probs_nb_cur[l_plus] += prob[c] * probs_b[l] probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
probs_nb_cur[l] += prob[c] * probs_nb[l] probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
elif c == space_id: elif c == space_id:
if ext_scoring_func is None: if ext_scoring_func is None:
score = 1.0 score = 1.0
else: else:
prefix_sent = ids2sentence(ids_list, vocabulary) prefix = ids2sentence(ids_list, vocabulary)
score = ext_scoring_func(prefix_sent) score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob[c] * ( probs_nb_cur[l_plus] += score * prob[c] * (
probs_b[l] + probs_nb[l]) probs_b_prev[l] + probs_nb_prev[l])
else: else:
probs_nb_cur[l_plus] += prob[c] * ( probs_nb_cur[l_plus] += prob[c] * (
probs_b[l] + probs_nb[l]) probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next # add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[ prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus] l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next # add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l] prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs # update probs
probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy( probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy(
probs_nb_cur) probs_nb_cur)
## store top beam_size prefixes ## store top beam_size prefixes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册