diff --git a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py index c4b37df3a09f93fe965ae28ce783f06f5018020d..ccb7a4f9ab20803c9201edbe4a5aa471b23c1083 100644 --- a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py @@ -127,9 +127,19 @@ def decode(context, is_sparse): current_score = pd.fc(input=current_state_with_lod, size=target_dict_dim, act='softmax') - topk_scores, topk_indices = pd.topk(current_score, k=topk_size) + topk_scores, topk_indices = pd.topk(current_score, k=beam_size) + # calculate accumulated scores after topk to reduce computation cost + accu_scores = pd.elementwise_add( + x=pd.log(topk_scores), y=pd.reshape( + pre_score, shape=[-1]), axis=0) selected_ids, selected_scores = pd.beam_search( - pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0) + pre_ids, + pre_score, + topk_indices, + accu_scores, + beam_size, + end_id=10, + level=0) pd.increment(x=counter, value=1, in_place=True) @@ -141,7 +151,7 @@ def decode(context, is_sparse): pd.less_than(x=counter, y=array_len, cond=cond) translation_ids, translation_scores = pd.beam_search_decode( - ids=ids_array, scores=scores_array) + ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10) # return init_ids, init_scores