提交 6e38cc33 编写于 作者: G guosheng

Fix the beam_search in test_machine_translation.py

上级 35e32a8e
...@@ -127,9 +127,19 @@ def decode(context, is_sparse): ...@@ -127,9 +127,19 @@ def decode(context, is_sparse):
current_score = pd.fc(input=current_state_with_lod, current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim, size=target_dict_dim,
act='softmax') act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size) topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(
pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search( selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0) pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)
pd.increment(x=counter, value=1, in_place=True) pd.increment(x=counter, value=1, in_place=True)
...@@ -141,7 +151,7 @@ def decode(context, is_sparse): ...@@ -141,7 +151,7 @@ def decode(context, is_sparse):
pd.less_than(x=counter, y=array_len, cond=cond) pd.less_than(x=counter, y=array_len, cond=cond)
translation_ids, translation_scores = pd.beam_search_decode( translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array) ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)
# return init_ids, init_scores # return init_ids, init_scores
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册