From ec70ff4558ff20d0e1b9563e0b4bdf5a5e4552f0 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Thu, 29 Mar 2018 00:35:14 +0800
Subject: [PATCH] Fix the inference batch_size in Transformer.

---
 fluid/neural_machine_translation/transformer/config.py | 3 +--
 fluid/neural_machine_translation/transformer/infer.py  | 7 +++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 1bf3f8d8..b91a8672 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -22,8 +22,7 @@ class TrainTaskConfig(object):
 
 class InferTaskConfig(object):
     use_gpu = False
     # the number of examples in one run for sequence generation.
-    # currently the batch size can only be set to 1.
-    batch_size = 1
+    batch_size = 10
     # the parameters for beam search.
     beam_size = 5
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index f24b7f6e..5d572c56 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -84,8 +84,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         [-1e9]).astype("float32")
     # This is used to remove attention on the paddings of source sequences.
     trg_src_attn_bias = np.tile(
-        src_slf_attn_bias[:, :, ::src_max_length, :],
-        [beam_size, 1, trg_max_len, 1])
+        src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
+        [1, beam_size, 1, trg_max_len, 1]).reshape([
+            -1, src_slf_attn_bias.shape[1], trg_max_len,
+            src_slf_attn_bias.shape[-1]
+        ])
     trg_data_shape = np.array(
         [batch_size * beam_size, trg_max_len, d_model], dtype="int32")
     enc_output = np.tile(
--
GitLab
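
Note (not part of the patch): the infer.py change switches trg_src_attn_bias from a
beam-major tile to a batch-major tile, so that row i * beam_size + k of the expanded
bias holds beam k of source example i. This matches the [batch_size * beam_size, ...]
layout used by trg_data_shape and enc_output, which is what allows batch_size > 1 in
config.py. Below is a minimal numpy sketch of the two layouts; the toy shapes and
variable names are chosen for illustration only.

    import numpy as np

    batch_size, beam_size, n_head, src_max_len = 2, 3, 1, 4
    # Toy per-example bias of shape [batch, n_head, 1, src_max_len];
    # example i is filled with the value i so rows are easy to track.
    bias = np.arange(batch_size, dtype="float32").reshape(
        [batch_size, 1, 1, 1]) * np.ones(
            [batch_size, n_head, 1, src_max_len], dtype="float32")

    # Beam-major tiling (the old behavior): repeats the whole batch
    # beam_size times, giving example order [0, 1, 0, 1, 0, 1] on axis 0.
    beam_major = np.tile(bias, [beam_size, 1, 1, 1])

    # Batch-major tiling (the fixed behavior): repeats each example
    # beam_size times in place via an inserted beam axis, giving example
    # order [0, 0, 0, 1, 1, 1] on axis 0 after the reshape.
    batch_major = np.tile(bias[:, np.newaxis],
                          [1, beam_size, 1, 1, 1]).reshape(
                              [-1, bias.shape[1], 1, bias.shape[-1]])

    print(beam_major[:, 0, 0, 0])   # [0. 1. 0. 1. 0. 1.]
    print(batch_major[:, 0, 0, 0])  # [0. 0. 0. 1. 1. 1.]

With batch_size = 1 the two layouts coincide, which is why the old tiling worked only
under the former batch_size = 1 restriction.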