diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 71e4314953383b8f89b40fdfd8cc4274f954fed1..8bfdf6461bdbfae92afe36520b3b056dddb4836c 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -92,7 +92,9 @@ pos_enc_param_names = (
 encoder_input_data_names = (
     "src_word",
     "src_pos",
-    "src_slf_attn_bias", )
+    "src_slf_attn_bias",
+    "src_slf_attn_pre_softmax_shape",
+    "src_slf_attn_post_softmax_shape", )
 
 # Names of all data layers in decoder listed in order.
 decoder_input_data_names = (
@@ -100,6 +102,10 @@ decoder_input_data_names = (
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "trg_slf_attn_pre_softmax_shape",
+    "trg_slf_attn_post_softmax_shape",
+    "trg_src_attn_pre_softmax_shape",
+    "trg_src_attn_post_softmax_shape",
     "enc_output", )
 
 # Names of label related data layers listed in order.
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index e4dee220cedf856633ee626b762804e49a10cfe8..b8b002dc0757481137d452400f276af4342a8af9 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -27,7 +27,14 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         is_target=False,
         return_pos=True,
         return_attn_bias=True,
-        return_max_len=True)
+        return_max_len=False)
+    # Append the shape inputs to reshape before and after softmax in encoder
+    # self attention.
+    enc_in_data = enc_in_data + [
+        np.array(
+            [-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array(
+                enc_in_data[2].shape, dtype="int32")
+    ]
     enc_output = exe.run(encoder,
                          feed=dict(zip(enc_in_names, enc_in_data)),
                          fetch_list=enc_out_names)[0]
@@ -35,8 +42,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
     # Beam Search.
     # To store the beam info.
     scores = np.zeros((batch_size, beam_size), dtype="float32")
-    prev_branchs = [[]] * batch_size
-    next_ids = [[]] * batch_size
+    prev_branchs = [[] for i in range(batch_size)]
+    next_ids = [[] for i in range(batch_size)]
     # Use beam_map to map the instance idx in batch to beam idx, since the
     # size of feeded batch is changing.
     beam_map = range(batch_size)
@@ -64,8 +71,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_words = np.array(
             [[bos_idx]] * batch_size * beam_size, dtype="int64")
         trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
-        src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
-            -1], enc_in_data[-2], 1
+        src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[
+            -1], enc_in_data[2], 1
         # This is used to remove attention on subsequent words.
         trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
                                      trg_max_len))
@@ -77,15 +84,33 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_src_attn_bias = np.tile(
             src_slf_attn_bias[:, :, ::src_max_length, :],
             [beam_size, 1, trg_max_len, 1])
+        # Append the shape inputs to reshape before and after softmax in
+        # decoder self attention.
+        trg_slf_attn_pre_softmax_shape = np.array(
+            [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
+        trg_slf_attn_post_softmax_shape = np.array(
+            trg_slf_attn_bias.shape, dtype="int32")
+        # Append the shape inputs to reshape before and after softmax in
+        # encoder-decoder attention.
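
(Illustration alongside the patch, not part of it.) A standalone NumPy sketch, with toy sizes assumed, of the two arrays appended to enc_in_data above: enc_in_data[2] is the encoder self-attention bias of shape [batch, n_head, max_len, max_len], the pre-softmax shape collapses all leading dimensions so the attention logits become a 2-D tensor for softmax, and the post-softmax shape restores the original 4-D layout.

import numpy as np

batch_size, n_head, max_len = 2, 8, 5  # toy sizes, not taken from the real config
# Stand-in for enc_in_data[2], the encoder self-attention bias.
src_slf_attn_bias = np.zeros(
    (batch_size, n_head, max_len, max_len), dtype="float32")
# Shape fed to the reshape before softmax: collapse all leading dimensions.
pre_softmax_shape = np.array([-1, src_slf_attn_bias.shape[-1]], dtype="int32")
# Shape fed to the reshape after softmax: the original 4-D attention shape.
post_softmax_shape = np.array(src_slf_attn_bias.shape, dtype="int32")
print(pre_softmax_shape)   # [-1  5]
print(post_softmax_shape)  # [2 8 5 5]
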
+        trg_src_attn_pre_softmax_shape = np.array(
+            [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
+        trg_src_attn_post_softmax_shape = np.array(
+            trg_src_attn_bias.shape, dtype="int32")
         enc_output = np.tile(enc_output, [beam_size, 1, 1])
-        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
+            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
+            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
+            enc_output
 
     def update_dec_in_data(dec_in_data, next_ids, active_beams):
         """
         Update the input data of decoder mainly by slicing from the previous
         input data and dropping the finished instance beams.
         """
-        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
+        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
+            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
+            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
+            enc_output = dec_in_data
         trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
         trg_words = np.array(
             [
@@ -112,8 +137,23 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
                                     [1, 1, trg_cur_len, 1])
+        # Append the shape inputs to reshape before and after softmax in
+        # decoder self attention.
+        trg_slf_attn_pre_softmax_shape = np.array(
+            [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
+        trg_slf_attn_post_softmax_shape = np.array(
+            trg_slf_attn_bias.shape, dtype="int32")
+        # Append the shape inputs to reshape before and after softmax in
+        # encoder-decoder attention.
+        trg_src_attn_pre_softmax_shape = np.array(
+            [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
+        trg_src_attn_post_softmax_shape = np.array(
+            trg_src_attn_bias.shape, dtype="int32")
         enc_output = enc_output[active_beams_indice, :, :]
-        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
+            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
+            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
+            enc_output
 
     dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                    enc_output)
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index ba5ba4470759da5fd2c6dd3b3d61b88c3468bd27..57575103b34bc565db5c219b85b184178dfe9700 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -32,7 +32,9 @@ def multi_head_attention(queries,
                          d_value,
                          d_model,
                          n_head=1,
-                         dropout_rate=0.):
+                         dropout_rate=0.,
+                         pre_softmax_shape=None,
+                         post_softmax_shape=None):
     """
     Multi-Head Attention. Note that attn_bias is added to the logit before
     computing softmax activiation to mask certain selected positions so that
@@ -111,26 +113,16 @@ def multi_head_attention(queries,
         """
         Scaled Dot-Product Attention
         """
-
-        # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
-
-        # The current implementation of softmax_op only supports 2D tensor,
-        # consequently it cannot be directly used here.
-        # If to use the reshape_op, Besides, the shape of product inferred in
-        # compile-time is not the actual shape in run-time. It cann't be used
-        # to set the attribute of reshape_op.
-        # So, here define the softmax for temporary solution.
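
(Illustration, not part of the patch.) As an aside on the prev_branchs/next_ids change in translate_batch above: a small self-contained demonstration, in plain Python with invented values, of why `[[]] * batch_size` is replaced by a list comprehension. The multiplied form repeats one shared list object, so appending a candidate for one beam would corrupt the bookkeeping of every beam.

batch_size = 3  # toy value

shared = [[]] * batch_size          # three references to the same inner list
shared[0].append("word")
print(shared)                       # [['word'], ['word'], ['word']] -- all aliased

independent = [[] for i in range(batch_size)]  # the fixed form from the patch
independent[0].append("word")
print(independent)                  # [['word'], [], []] -- one list per instance
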
-
-        def __softmax(x, eps=1e-9):
-            exp_out = layers.exp(x=x)
-            sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
-            return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)
-
         scaled_q = layers.scale(x=q, scale=d_model**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
-        weights = __softmax(
-            layers.elementwise_add(
-                x=product, y=attn_bias) if attn_bias else product)
+        weights = layers.reshape(
+            x=layers.elementwise_add(
+                x=product, y=attn_bias) if attn_bias else product,
+            shape=[-1, product.shape[-1]],
+            actual_shape=pre_softmax_shape,
+            act="softmax")
+        weights = layers.reshape(
+            x=weights, shape=product.shape, actual_shape=post_softmax_shape)
         if dropout_rate:
             weights = layers.dropout(
                 weights, dropout_prob=dropout_rate, is_test=False)
@@ -252,7 +244,9 @@ def encoder_layer(enc_input,
                   d_value,
                   d_model,
                   d_inner_hid,
-                  dropout_rate=0.):
+                  dropout_rate=0.,
+                  pre_softmax_shape=None,
+                  post_softmax_shape=None):
     """The encoder layers that can be stacked to form a deep encoder.
 
     This module consits of a multi-head (self) attention followed by
@@ -260,9 +254,9 @@ def encoder_layer(enc_input,
     with the post_process_layer to add residual connection, layer normalization
     and droput.
     """
-    attn_output = multi_head_attention(enc_input, enc_input, enc_input,
-                                       attn_bias, d_key, d_value, d_model,
-                                       n_head, dropout_rate)
+    attn_output = multi_head_attention(
+        enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
+        n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
     attn_output = post_process_layer(enc_input, attn_output, "dan",
                                      dropout_rate)
     ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
@@ -277,7 +271,9 @@ def encoder(enc_input,
             d_value,
             d_model,
             d_inner_hid,
-            dropout_rate=0.):
+            dropout_rate=0.,
+            pre_softmax_shape=None,
+            post_softmax_shape=None):
     """
     The encoder is composed of a stack of identical layers returned by calling
     encoder_layer.
@@ -291,7 +287,9 @@ def encoder(enc_input,
             d_value,
             d_model,
             d_inner_hid,
-            dropout_rate, )
+            dropout_rate,
+            pre_softmax_shape,
+            post_softmax_shape, )
         enc_input = enc_output
     return enc_output
 
@@ -305,7 +303,11 @@ def decoder_layer(dec_input,
                   d_value,
                   d_model,
                   d_inner_hid,
-                  dropout_rate=0.):
+                  dropout_rate=0.,
+                  slf_attn_pre_softmax_shape=None,
+                  slf_attn_post_softmax_shape=None,
+                  src_attn_pre_softmax_shape=None,
+                  src_attn_post_softmax_shape=None):
     """ The layer to be stacked in decoder part.
 
     The structure of this module is similar to that in the encoder part except
@@ -320,7 +322,9 @@ def decoder_layer(dec_input,
         d_value,
         d_model,
         n_head,
-        dropout_rate, )
+        dropout_rate,
+        slf_attn_pre_softmax_shape,
+        slf_attn_post_softmax_shape, )
     slf_attn_output = post_process_layer(
         dec_input,
         slf_attn_output,
@@ -335,7 +339,9 @@ def decoder_layer(dec_input,
         d_value,
         d_model,
         n_head,
-        dropout_rate, )
+        dropout_rate,
+        src_attn_pre_softmax_shape,
+        src_attn_post_softmax_shape, )
     enc_attn_output = post_process_layer(
         slf_attn_output,
         enc_attn_output,
@@ -363,7 +369,11 @@ def decoder(dec_input,
             d_value,
             d_model,
             d_inner_hid,
-            dropout_rate=0.):
+            dropout_rate=0.,
+            slf_attn_pre_softmax_shape=None,
+            slf_attn_post_softmax_shape=None,
+            src_attn_pre_softmax_shape=None,
+            src_attn_post_softmax_shape=None):
     """
     The decoder is composed of a stack of identical decoder_layer layers.
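
(Illustration, not part of the patch.) To see what the reshape/softmax/reshape pair introduced in multi_head_attention above computes, here is a NumPy-only sketch with toy shapes assumed: flattening the attention logits to [-1, last_dim], applying a row-wise softmax, and restoring the original shape gives the same result as a softmax over the last axis of the 4-D tensor, which is what the old hand-rolled __softmax approximated.

import numpy as np

np.random.seed(0)
batch, n_head, q_len, k_len = 2, 4, 3, 5          # toy sizes
product = np.random.rand(batch, n_head, q_len, k_len).astype("float32")

def softmax_last_axis(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Mirror of the values fed as the pre/post softmax shape inputs.
pre_softmax_shape = (-1, product.shape[-1])
post_softmax_shape = product.shape

# What the patched code does: reshape to 2-D, softmax row-wise, reshape back.
flat = product.reshape(pre_softmax_shape)
weights = softmax_last_axis(flat).reshape(post_softmax_shape)

# Direct softmax over the last axis of the 4-D tensor gives the same result.
assert np.allclose(weights, softmax_last_axis(product), atol=1e-6)
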
""" @@ -378,7 +388,11 @@ def decoder(dec_input, d_value, d_model, d_inner_hid, - dropout_rate, ) + dropout_rate, + slf_attn_pre_softmax_shape, + slf_attn_post_softmax_shape, + src_attn_pre_softmax_shape, + src_attn_post_softmax_shape, ) dec_input = dec_output return dec_output @@ -391,7 +405,9 @@ def make_inputs(input_data_names, is_pos, slf_attn_bias_flag, src_attn_bias_flag, - enc_output_flag=False): + enc_output_flag=False, + slf_attn_shape_flag=True, + src_attn_shape_flag=True): """ Define the input data layers for the transformer model. """ @@ -429,6 +445,32 @@ def make_inputs(input_data_names, dtype="float32", append_batch_size=False) input_layers += [src_attn_bias] + if slf_attn_shape_flag: + slf_attn_pre_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [slf_attn_pre_softmax_shape] + slf_attn_post_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [slf_attn_post_softmax_shape] + if src_attn_shape_flag: + src_attn_pre_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [src_attn_pre_softmax_shape] + src_attn_post_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [src_attn_post_softmax_shape] if enc_output_flag: enc_output = layers.data( name=input_data_names[len(input_layers)], @@ -436,6 +478,7 @@ def make_inputs(input_data_names, dtype="float32", append_batch_size=False) input_layers += [enc_output] + return input_layers @@ -453,8 +496,18 @@ def transformer( src_pad_idx, trg_pad_idx, pos_pad_idx, ): - enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model, - batch_size, max_length, True, True, False) + enc_input_layers = make_inputs( + encoder_input_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos=True, + slf_attn_bias_flag=True, + src_attn_bias_flag=False, + enc_output_flag=False, + slf_attn_shape_flag=True, + src_attn_shape_flag=False) enc_output = wrap_encoder( src_vocab_size, @@ -470,8 +523,18 @@ def transformer( pos_pad_idx, enc_input_layers, ) - dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model, - batch_size, max_length, True, True, True) + dec_input_layers = make_inputs( + decoder_input_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos=True, + slf_attn_bias_flag=True, + src_attn_bias_flag=True, + enc_output_flag=False, + slf_attn_shape_flag=True, + src_attn_shape_flag=True) predict = wrap_decoder( trg_vocab_size, @@ -490,8 +553,18 @@ def transformer( # Padding index do not contribute to the total loss. The weights is used to # cancel padding index in calculating the loss. 
-    gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
-                                max_length, False, False, False)
+    gold, weights = make_inputs(
+        label_data_names,
+        n_head,
+        d_model,
+        batch_size,
+        max_length,
+        is_pos=False,
+        slf_attn_bias_flag=False,
+        src_attn_bias_flag=False,
+        enc_output_flag=False,
+        slf_attn_shape_flag=False,
+        src_attn_shape_flag=False)
     cost = layers.cross_entropy(input=predict, label=gold)
     weighted_cost = cost * weights
     return layers.reduce_sum(weighted_cost), predict
@@ -514,11 +587,22 @@ def wrap_encoder(src_vocab_size,
     """
     if enc_input_layers is None:
         # This is used to implement independent encoder program in inference.
-        src_word, src_pos, src_slf_attn_bias = make_inputs(
-            encoder_input_data_names, n_head, d_model, batch_size, max_length,
-            True, True, False)
+        src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
+            slf_attn_post_softmax_shape = make_inputs(
+                encoder_input_data_names,
+                n_head,
+                d_model,
+                batch_size,
+                max_length,
+                is_pos=True,
+                slf_attn_bias_flag=True,
+                src_attn_bias_flag=False,
+                enc_output_flag=False,
+                slf_attn_shape_flag=True,
+                src_attn_shape_flag=False)
     else:
-        src_word, src_pos, src_slf_attn_bias = enc_input_layers
+        src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
+            slf_attn_post_softmax_shape = enc_input_layers
     enc_input = prepare_encoder(
         src_word,
         src_pos,
@@ -536,7 +620,9 @@ def wrap_encoder(src_vocab_size,
         d_value,
         d_model,
         d_inner_hid,
-        dropout_rate, )
+        dropout_rate,
+        slf_attn_pre_softmax_shape,
+        slf_attn_post_softmax_shape, )
     return enc_output
 
 
@@ -558,11 +644,26 @@ def wrap_decoder(trg_vocab_size,
     """
     if dec_input_layers is None:
         # This is used to implement independent decoder program in inference.
-        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs(
-            decoder_input_data_names, n_head, d_model, batch_size, max_length,
-            True, True, True, True)
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
+            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
+            src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
+            enc_output = make_inputs(
+                decoder_input_data_names,
+                n_head,
+                d_model,
+                batch_size,
+                max_length,
+                is_pos=True,
+                slf_attn_bias_flag=True,
+                src_attn_bias_flag=True,
+                enc_output_flag=True,
+                slf_attn_shape_flag=True,
+                src_attn_shape_flag=True)
     else:
-        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
+            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
+            src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
+            dec_input_layers
 
     dec_input = prepare_decoder(
         trg_word,
@@ -583,7 +684,11 @@ def wrap_decoder(trg_vocab_size,
         d_value,
         d_model,
         d_inner_hid,
-        dropout_rate, )
+        dropout_rate,
+        slf_attn_pre_softmax_shape,
+        slf_attn_post_softmax_shape,
+        src_attn_pre_softmax_shape,
+        src_attn_post_softmax_shape, )
 
     predict = layers.reshape(
         x=layers.fc(input=dec_output,
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 65de8ef7fa8421bd72175175f1cf421a4237ddd5..13e4fe7a4aa787f6e59ceb15d40dbd1f1477c86c 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -66,13 +66,29 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
         [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], [1, 1,
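
(Illustration, not part of the patch.) With wrap_encoder above now unpacking five input layers, the standalone encoder program at inference is fed with feed=dict(zip(enc_in_names, enc_in_data)) in infer.py, so the arrays must line up positionally with the five names in encoder_input_data_names. A NumPy sketch of that pairing, with toy shapes assumed rather than the exact layout pad_batch_data produces:

import numpy as np

encoder_input_data_names = (
    "src_word", "src_pos", "src_slf_attn_bias",
    "src_slf_attn_pre_softmax_shape", "src_slf_attn_post_softmax_shape", )

batch_size, n_head, max_len = 2, 8, 4                  # toy sizes
src_word = np.zeros((batch_size, max_len, 1), dtype="int64")
src_pos = np.zeros((batch_size, max_len, 1), dtype="int64")
src_slf_attn_bias = np.zeros(
    (batch_size, n_head, max_len, max_len), dtype="float32")
enc_in_data = [
    src_word, src_pos, src_slf_attn_bias,
    np.array([-1, src_slf_attn_bias.shape[-1]], dtype="int32"),  # pre-softmax shape
    np.array(src_slf_attn_bias.shape, dtype="int32"),            # post-softmax shape
]
feed = dict(zip(encoder_input_data_names, enc_in_data))
for name, value in feed.items():
    print(name, value.shape)
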
                                 trg_max_len, 1]).astype("float32")
+    src_slf_attn_pre_softmax_shape = np.array(
+        [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
+    src_slf_attn_post_softmax_shape = np.array(
+        src_slf_attn_bias.shape, dtype="int32")
+    trg_slf_attn_pre_softmax_shape = np.array(
+        [-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
+    trg_slf_attn_post_softmax_shape = np.array(
+        trg_slf_attn_bias.shape, dtype="int32")
+    trg_src_attn_pre_softmax_shape = np.array(
+        [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
+    trg_src_attn_post_softmax_shape = np.array(
+        trg_src_attn_bias.shape, dtype="int32")
     lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
                               False, False, False, False)
     lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
     input_dict = dict(
         zip(input_data_names, [
-            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
-            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+            src_word, src_pos, src_slf_attn_bias,
+            src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
+            trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
+            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
+            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
+            lbl_word, lbl_weight
         ]))
     return input_dict
 
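
(Illustration, not part of the patch.) A worked toy example, NumPy only and with invented sizes, of the six shape arrays prepare_batch_input now feeds. The cross-attention case is the interesting one: its bias is [batch, n_head, trg_max_len, src_max_len], so the pre-softmax shape keys on the source length while the post-softmax shape restores the full rectangular tensor.

import numpy as np

batch, n_head, src_max_len, trg_max_len = 2, 8, 6, 4   # invented sizes
src_slf_attn_bias = np.zeros((batch, n_head, src_max_len, src_max_len), "float32")
trg_slf_attn_bias = np.zeros((batch, n_head, trg_max_len, trg_max_len), "float32")
trg_src_attn_bias = np.zeros((batch, n_head, trg_max_len, src_max_len), "float32")

for name, bias in [("src_slf", src_slf_attn_bias), ("trg_slf", trg_slf_attn_bias),
                   ("trg_src", trg_src_attn_bias)]:
    pre = np.array([-1, bias.shape[-1]], dtype="int32")    # rows collapsed for softmax
    post = np.array(bias.shape, dtype="int32")             # restored afterwards
    print(name, "pre:", pre, "post:", post)
# e.g. trg_src  pre: [-1  6]  post: [2 8 4 6]
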