From 6e896b0879265c0afa6788e3320bba586d365b2e Mon Sep 17 00:00:00 2001 From: guosheng Date: Wed, 21 Mar 2018 10:38:36 +0800 Subject: [PATCH] Refine the log calculation in Transformer beam search --- fluid/neural_machine_translation/transformer/infer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py index f5fdfb33..e4dee220 100644 --- a/fluid/neural_machine_translation/transformer/infer.py +++ b/fluid/neural_machine_translation/transformer/infer.py @@ -121,11 +121,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, predict_all = exe.run(decoder, feed=dict(zip(dec_in_names, dec_in_data)), fetch_list=dec_out_names)[0] - predict_all = np.log(predict_all) - predict_all = ( - predict_all.reshape( - [len(beam_map) * beam_size, i + 1, -1])[:, -1, :] + - scores[beam_map].reshape([len(beam_map) * beam_size, -1])).reshape( + predict_all = np.log( + predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:, + -1, :]) + predict_all = (predict_all + scores[beam_map].reshape( + [len(beam_map) * beam_size, -1])).reshape( [len(beam_map), beam_size, -1]) active_beams = [] for inst_idx, beam_idx in enumerate(beam_map): -- GitLab