提交 6e896b08 编写于 作者: G guosheng

Refine the log calculation in Transformer beam search

上级 6ef54e8e
......@@ -121,11 +121,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
predict_all = exe.run(decoder,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(predict_all)
predict_all = (
predict_all.reshape(
[len(beam_map) * beam_size, i + 1, -1])[:, -1, :] +
scores[beam_map].reshape([len(beam_map) * beam_size, -1])).reshape(
predict_all = np.log(
predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-1, :])
predict_all = (predict_all + scores[beam_map].reshape(
[len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册