diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py index 0866efd8bfc97f4695af01b6e96ce66bcb82b41d..2338975f980e472224e599d83891f2b27f814ece 100644 --- a/fluid/neural_machine_translation/transformer/infer.py +++ b/fluid/neural_machine_translation/transformer/infer.py @@ -203,7 +203,7 @@ def translate_batch(exe, predict_all = np.log( predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1]) [:, -1, :]) - predict_all = (predict_all + scores[beam_inst_map].reshape( + predict_all = (predict_all + scores[active_beams].reshape( [len(beam_inst_map) * beam_size, -1])).reshape( [len(beam_inst_map), beam_size, -1]) if not output_unk: # To exclude the token.