diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 8ab9efce1a275ea9539b05c0b959dee42d83c759..dca7bca1122b338bf113fc8ee19435ec6a756fcb 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -103,22 +103,23 @@ def merge_cfg_from_list(cfg_list, g_cfgs):
                 break


+batch_size = -1
 # Here list the data shapes and data types of all inputs.
 # The shapes here act as placeholder and are set to pass the infer-shape in
 # compile time.
 input_descs = {
     # The actual data shape of src_word is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
     # The actual data shape of src_pos is:
     # [batch_size * max_src_len_in_batch, 1]
-    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "src_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
     # This input is used to remove attention weights on paddings in the
     # encoder.
     # The actual data shape of src_slf_attn_bias is:
     # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
     "src_slf_attn_bias":
-    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
+    [(batch_size, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
      (ModelHyperParams.max_length + 1)), "float32"],
     # This shape input is used to reshape the output of embedding layer.
     "src_data_shape": [(3L, ), "int32"],
@@ -128,22 +129,22 @@ input_descs = {
     "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
     # The actual data shape of trg_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_word": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
     # The actual data shape of trg_pos is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    "trg_pos": [(batch_size * (ModelHyperParams.max_length + 1), 1L), "int64"],
     # This input is used to remove attention weights on paddings and
     # subsequent words in the decoder.
     # The actual data shape of trg_slf_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
-    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
+    "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head,
                            (ModelHyperParams.max_length + 1),
                            (ModelHyperParams.max_length + 1)), "float32"],
     # This input is used to remove attention weights on paddings of the source
     # input in the encoder-decoder attention.
     # The actual data shape of trg_src_attn_bias is:
     # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
-    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
+    "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head,
                            (ModelHyperParams.max_length + 1),
                            (ModelHyperParams.max_length + 1)), "float32"],
     # This shape input is used to reshape the output of embedding layer.
@@ -170,6 +171,8 @@ input_descs = {
     # The actual data shape of label_weight is:
     # [batch_size * max_trg_len_in_batch, 1]
     "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
+    # These two inputs are used for beam search decoder.
+    # "start_token": [(1 * 1, 1L), "int64"],
 }

 # Names of position encoding table which will be initialized externally.
@@ -200,3 +203,7 @@ decoder_util_input_fields = (
 label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )
+fast_decoder_data_fields = (
+    "trg_word",
+    # "start_token",
+    "trg_src_attn_bias", )
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index e8f7f47dd5c0dc4937b73bd1693b2fd14fb8d55c..b1668b1c4c88ac4bb8525e06574e205cab3f0296 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -7,6 +7,7 @@ import paddle.fluid as fluid
 import model
 from model import wrap_encoder as encoder
 from model import wrap_decoder as decoder
+from model import fast_decode as fast_decoder
 from config import *
 from train import pad_batch_data
 import reader
@@ -416,5 +417,15 @@ def infer(args):


 if __name__ == "__main__":
+    fast_decoder(ModelHyperParams.src_vocab_size,
+                 ModelHyperParams.trg_vocab_size,
+                 ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+                 ModelHyperParams.n_head, ModelHyperParams.d_key,
+                 ModelHyperParams.d_value, ModelHyperParams.d_model,
+                 ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+                 InferTaskConfig.beam_size, InferTaskConfig.max_length,
+                 ModelHyperParams.eos_idx)
+    print(fluid.default_main_program())
+    exit(0)
     args = parse_args()
     infer(args)
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 9c5d8adc312d48eb7c232789e590755e1b349d3a..23c0e0507eceb4d44746a4077274792a3f4c8b96 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -30,7 +30,8 @@ def multi_head_attention(queries,
                          n_head=1,
                          dropout_rate=0.,
                          pre_softmax_shape=None,
-                         post_softmax_shape=None):
+                         post_softmax_shape=None,
+                         cache=None):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
     computing softmax activiation to mask certain selected positions so that
@@ -128,6 +129,12 @@

     q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

+    if cache is not None:  # use cache and concat time steps
+        print cache["k"].shape
+        print k.shape
+        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
+        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
+
     q = __split_heads(q, n_head)
     k = __split_heads(k, n_head)
     v = __split_heads(v, n_head)
@@ -300,7 +307,8 @@ def decoder_layer(dec_input,
                   slf_attn_pre_softmax_shape=None,
                   slf_attn_post_softmax_shape=None,
                   src_attn_pre_softmax_shape=None,
-                  src_attn_post_softmax_shape=None):
+                  src_attn_post_softmax_shape=None,
+                  cache=None):
     """ The layer to be stacked in decoder part. The structure of this module
     is similar to that in the encoder part except a multi-head attention is
     added to implement encoder-decoder attention.
@@ -316,7 +324,8 @@
         n_head,
         dropout_rate,
         slf_attn_pre_softmax_shape,
-        slf_attn_post_softmax_shape, )
+        slf_attn_post_softmax_shape,
+        cache, )
     slf_attn_output = post_process_layer(
         dec_input,
         slf_attn_output,
@@ -365,26 +374,18 @@ def decoder(dec_input,
             slf_attn_pre_softmax_shape=None,
             slf_attn_post_softmax_shape=None,
             src_attn_pre_softmax_shape=None,
-            src_attn_post_softmax_shape=None):
+            src_attn_post_softmax_shape=None,
+            caches=None):
     """
     The decoder is composed of a stack of identical decoder_layer layers.
""" for i in range(n_layer): dec_output = decoder_layer( - dec_input, - enc_output, - dec_slf_attn_bias, - dec_enc_attn_bias, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - dropout_rate, - slf_attn_pre_softmax_shape, - slf_attn_post_softmax_shape, - src_attn_pre_softmax_shape, - src_attn_post_softmax_shape, ) + dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias, n_head, + d_key, d_value, d_model, d_inner_hid, dropout_rate, + slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, + src_attn_pre_softmax_shape, src_attn_post_softmax_shape, None + if caches is None else caches[i]) dec_input = dec_output return dec_output @@ -523,7 +524,8 @@ def wrap_decoder(trg_vocab_size, d_inner_hid, dropout_rate, dec_inputs=None, - enc_output=None): + enc_output=None, + caches=None): """ The wrapper assembles together all needed layers for the decoder. """ @@ -563,7 +565,8 @@ def wrap_decoder(trg_vocab_size, slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, - src_attn_post_softmax_shape, ) + src_attn_post_softmax_shape, + caches, ) # Return logits for training and probs for inference. predict = layers.reshape( x=layers.fc(input=dec_output, @@ -573,3 +576,112 @@ def wrap_decoder(trg_vocab_size, shape=[-1, trg_vocab_size], act="softmax" if dec_inputs is None else None) return predict + + +def fast_decode( + src_vocab_size, + trg_vocab_size, + max_in_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + dropout_rate, + beam_size, + max_out_len, + eos_idx, ): + enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head, + d_key, d_value, d_model, d_inner_hid, + dropout_rate) + start_tokens, trg_src_attn_bias, trg_data_shape, \ + slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \ + src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \ + make_all_inputs(fast_decoder_data_fields + decoder_util_input_fields) + + def beam_search(): + cond = layers.create_tensor(dtype='bool') + while_op = layers.While(cond) + max_len = layers.fill_constant( + shape=[1], dtype='int32', value=max_out_len) + step_idx = layers.fill_constant(shape=[1], dtype='int32', value=0) + init_scores = layers.fill_constant_batch_size_like( + input=start_tokens, shape=[-1, 1], dtype="float32", value=0) + # array states + ids = layers.array_write(start_tokens, step_idx) + scores = layers.array_write(init_scores, step_idx) + # cell states (can be overwrited) + caches = [{ + "k": layers.fill_constant_batch_size_like( + input=start_tokens, + shape=[-1, 0, d_model], + dtype="float32", + value=0), + "v": layers.fill_constant_batch_size_like( + input=start_tokens, + shape=[-1, 0, d_model], + dtype="float32", + value=0) + } for i in range(n_layer)] + + with while_op.block(): + pre_ids = layers.array_read(array=ids, i=step_idx) + pre_scores = layers.array_read(array=scores, i=step_idx) + pre_pos = layers.elementwise_mul( + x=layers.fill_constant_batch_size_like( + input=pre_ids, value=1, shape=[-1, 1], dtype='int32'), + y=layers.increment( + x=step_idx, value=1.0, in_place=False)) + pre_src_attn_bias = layers.sequence_expand( + x=trg_src_attn_bias, y=pre_ids) + pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_ids) + print caches[0]["k"].shape + pre_caches = [{ + "k": layers.sequence_expand( + x=cache["k"], y=pre_ids), + "v": layers.sequence_expand( + x=cache["v"], y=pre_ids), + } for cache in caches] + print pre_caches[0]["k"].shape + logits = wrap_decoder( + trg_vocab_size, + max_in_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + 
+                d_inner_hid,
+                dropout_rate,
+                dec_inputs=(
+                    pre_ids, pre_pos, None, pre_src_attn_bias, trg_data_shape,
+                    slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape,
+                    src_attn_pre_softmax_shape, src_attn_post_softmax_shape),
+                enc_output=pre_enc_output,
+                caches=pre_caches)
+            topk_scores, topk_indices = layers.topk(logits, k=beam_size)
+            accu_scores = layers.elementwise_add(
+                x=pre_scores, y=layers.log(x=layers.softmax(topk_scores)))
+            selected_ids, selected_scores = layers.beam_search(
+                pre_ids=pre_ids,
+                ids=topk_indices,
+                scores=accu_scores,
+                beam_size=beam_size,
+                end_id=eos_idx)
+
+            layers.increment(x=step_idx, value=1.0, in_place=True)
+            # update states
+            layers.array_write(selected_ids, i=step_idx)
+            layers.array_write(selected_scores, i=step_idx)
+            layers.assign(pre_src_attn_bias, trg_src_attn_bias)
+            layers.assign(pre_enc_output, enc_output)
+            for i in range(n_layer):
+                layers.assign(pre_caches[i]["k"], caches[i]["k"])
+                layers.assign(pre_caches[i]["v"], caches[i]["v"])
+
+            max_len_cond = layers.less_than(x=step_idx, y=max_len)
+            all_finish_cond = layers.less_than(x=step_idx, y=max_len)
+            layers.logical_or(x=max_len_cond, y=all_finish_cond, out=cond)
+
+    beam_search()
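
Note: the cache argument threaded through multi_head_attention above implements incremental decoding. At each step only the current token is projected, and its keys and values are concatenated onto the cached ones along the time axis (layers.concat(..., axis=1)), starting from the zero-length caches created with fill_constant_batch_size_like(shape=[-1, 0, d_model]). Below is a minimal NumPy sketch of that idea only; it is independent of the Fluid API, uses a single head for brevity, and every name in it (step_attention, w_q, and so on) is illustrative rather than taken from the patch.

import numpy as np

np.random.seed(0)
d_model = 8
# Single projection matrices (one attention head only, for brevity).
w_q = np.random.randn(d_model, d_model)
w_k = np.random.randn(d_model, d_model)
w_v = np.random.randn(d_model, d_model)


def step_attention(x_t, cache):
    # x_t: [batch, 1, d_model], the representation of the token decoded this step.
    q = x_t.dot(w_q)
    k = x_t.dot(w_k)
    v = x_t.dot(w_v)
    # Append this step's keys/values to the cache along the time axis,
    # mirroring k = cache["k"] = layers.concat([cache["k"], k], axis=1).
    cache["k"] = np.concatenate([cache["k"], k], axis=1)
    cache["v"] = np.concatenate([cache["v"], v], axis=1)
    # Attend from the current token over every cached time step.
    scores = np.matmul(q, cache["k"].transpose(0, 2, 1)) / np.sqrt(d_model)
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return np.matmul(weights, cache["v"])


batch = 2
# Start from a zero-length time axis, like the
# fill_constant_batch_size_like(shape=[-1, 0, d_model]) initialization above.
cache = {"k": np.zeros((batch, 0, d_model)),
         "v": np.zeros((batch, 0, d_model))}
for t in range(3):
    x_t = np.random.randn(batch, 1, d_model)
    out = step_attention(x_t, cache)
print(cache["k"].shape)  # (2, 3, 8): one cached entry per decoded step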
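Similarly, the sequence_expand calls on enc_output, trg_src_attn_bias and the per-layer caches at the top of the while block keep those tensors aligned with the hypotheses that layers.beam_search keeps alive: each surviving beam may descend from a different parent, so the cached rows have to follow their parents before the next step. The NumPy sketch below shows only that realignment idea, using a plain gather-by-index in place of the LoD-based sequence_expand; the parent_idx values are illustrative, not a real decoder trace.

import numpy as np

beam_size, cached_len, d_model = 3, 4, 8
# One layer's cached keys: one row of time steps per live hypothesis.
cache_k = np.arange(beam_size * cached_len * d_model, dtype="float32").reshape(
    beam_size, cached_len, d_model)

# Suppose the beam-search step decides that the new beams 0, 1, 2 descend
# from parent beams 2, 0, 0 respectively (purely illustrative).
parent_idx = np.array([2, 0, 0])
cache_k = cache_k[parent_idx]  # gather the cached rows of each parent
print(cache_k.shape)           # (3, 4, 8): the cache now follows the survivors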