提交 6e896b08 · 作者: guosheng

Refine the log calculation in Transformer beam search

上级 6ef54e8e
@@ -121,11 +121,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         predict_all = exe.run(decoder,
                               feed=dict(zip(dec_in_names, dec_in_data)),
                               fetch_list=dec_out_names)[0]
-        predict_all = np.log(predict_all)
-        predict_all = (
-            predict_all.reshape(
-                [len(beam_map) * beam_size, i + 1, -1])[:, -1, :] +
-            scores[beam_map].reshape([len(beam_map) * beam_size, -1])).reshape(
-                [len(beam_map), beam_size, -1])
+        predict_all = np.log(
+            predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
+                                                                        -1, :])
+        predict_all = (predict_all + scores[beam_map].reshape(
+            [len(beam_map) * beam_size, -1])).reshape(
+                [len(beam_map), beam_size, -1])
         active_beams = []
         for inst_idx, beam_idx in enumerate(beam_map):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册