提交 3a7a52e2 编写于 作者: Y Yibing Liu

rename variables in decoder

上级 6698fe64
......@@ -92,7 +92,7 @@ def ctc_beam_search_decoder(probs_seq,
Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned, need to be verified.
:param probs_seq: 2-D list with length max_time_steps, each element
:param probs_seq: 2-D list with length num_time_steps, each element
is a list of normalized probabilities over vocabulary
and blank for one time step.
:type probs_seq: 2-D list
......@@ -114,7 +114,7 @@ def ctc_beam_search_decoder(probs_seq,
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("probs dimension mismatchedd with vocabulary")
max_time_steps = len(probs_seq)
num_time_steps = len(probs_seq)
# blank_id check
probs_dim = len(probs_seq[0])
......@@ -139,10 +139,10 @@ def ctc_beam_search_decoder(probs_seq,
## initialize
# the set containing selected prefixes
prefix_set_prev = {'-1': 1.0}
probs_b, probs_nb = {'-1': 1.0}, {'-1': 0.0}
probs_b_prev, probs_nb_prev = {'-1': 1.0}, {'-1': 0.0}
## extend prefix in loop
for time_step in range(max_time_steps):
for time_step in range(num_time_steps):
# the set containing candidate prefixes
prefix_set_next = {}
probs_b_cur, probs_nb_cur = {}, {}
......@@ -158,33 +158,34 @@ def ctc_beam_search_decoder(probs_seq,
# extend prefix by travering vocabulary
for c in range(0, probs_dim):
if c == blank_id:
probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
probs_b_cur[l] += prob[c] * (
probs_b_prev[l] + probs_nb_prev[l])
else:
l_plus = l + ' ' + str(c)
if not prefix_set_next.has_key(l_plus):
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if c == end_id:
probs_nb_cur[l_plus] += prob[c] * probs_b[l]
probs_nb_cur[l] += prob[c] * probs_nb[l]
probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
elif c == space_id:
if ext_scoring_func is None:
score = 1.0
else:
prefix_sent = ids2sentence(ids_list, vocabulary)
score = ext_scoring_func(prefix_sent)
prefix = ids2sentence(ids_list, vocabulary)
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob[c] * (
probs_b[l] + probs_nb[l])
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob[c] * (
probs_b[l] + probs_nb[l])
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(
probs_b_prev, probs_nb_prev = copy.deepcopy(probs_b_cur), copy.deepcopy(
probs_nb_cur)
## store top beam_size prefixes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册