Commit ec70ff45 authored by guosheng

Fix the inference batch_size in Transformer.

Parent f3c247d3
@@ -22,8 +22,7 @@ class TrainTaskConfig(object):
 class InferTaskConfig(object):
     use_gpu = False
     # the number of examples in one run for sequence generation.
-    # currently the batch size can only be set to 1.
-    batch_size = 1
+    batch_size = 10
     # the parameters for beam search.
     beam_size = 5
@@ -84,8 +84,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         [-1e9]).astype("float32")
     # This is used to remove attention on the paddings of source sequences.
     trg_src_attn_bias = np.tile(
-        src_slf_attn_bias[:, :, ::src_max_length, :],
-        [beam_size, 1, trg_max_len, 1])
+        src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
+        [1, beam_size, 1, trg_max_len, 1]).reshape([
+            -1, src_slf_attn_bias.shape[1], trg_max_len,
+            src_slf_attn_bias.shape[-1]
+        ])
     trg_data_shape = np.array(
         [batch_size * beam_size, trg_max_len, d_model], dtype="int32")
     enc_output = np.tile(
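
Note on the fix: the decoder inputs are laid out with the beam_size candidates of each source sentence stored contiguously (see trg_data_shape, whose leading dimension is batch_size * beam_size). The old np.tile with [beam_size, 1, trg_max_len, 1] repeats the whole batch along axis 0, producing a beam-major order that coincides with the expected instance-major order only when batch_size is 1, which is presumably why the config previously capped batch_size at 1. Inserting a beam axis per sentence before tiling and flattening afterwards yields the instance-major order. A minimal NumPy sketch of the two tilings; all dimensions below are made up for illustration and do not come from the commit:

import numpy as np

# Hypothetical small dimensions so the row ordering is easy to see.
batch_size, n_head, src_max_len, trg_max_len, beam_size = 2, 1, 3, 2, 2

# Give each source sentence a distinct bias value to make the order visible.
src_slf_attn_bias = np.arange(batch_size, dtype="float32").reshape(
    [batch_size, 1, 1, 1]) * np.ones(
        [batch_size, n_head, src_max_len, src_max_len], dtype="float32")

# Keep only the first attention row per sentence: [batch, n_head, 1, src_len].
first_row = src_slf_attn_bias[:, :, ::src_max_len, :]

# Old tiling: repeats the whole batch along axis 0 (beam-major order).
old = np.tile(first_row, [beam_size, 1, trg_max_len, 1])

# New tiling: insert a beam axis per sentence, tile, then flatten back,
# so each sentence's beam candidates end up contiguous (instance-major).
new = np.tile(first_row[:, np.newaxis],
              [1, beam_size, 1, trg_max_len, 1]).reshape(
                  [-1, n_head, trg_max_len, src_max_len])

print(old[:, 0, 0, 0])  # [0. 1. 0. 1.] -- sentences interleaved across beams
print(new[:, 0, 0, 0])  # [0. 0. 1. 1.] -- beams grouped per sentence

With batch_size == 1 both orderings produce identical arrays, which matches the config change above: the restriction to a single sentence per inference run is no longer needed.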