diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 5024fc5f996df0c74cfe5e319d2ab9c332a1025f..6fdfaf07ddca6b06eea6bdcc1379a3b6937f05f1 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -25,8 +25,7 @@ class TrainTaskConfig(object):
 class InferTaskConfig(object):
     use_gpu = False
     # the number of examples in one run for sequence generation.
-    # currently the batch size can only be set to 1.
-    batch_size = 1
+    batch_size = 10
 
     # the parameters for beam search.
     beam_size = 5
@@ -103,6 +102,7 @@ encoder_input_data_names = (
     "src_word",
     "src_pos",
     "src_slf_attn_bias",
+    "src_data_shape",
     "src_slf_attn_pre_softmax_shape",
     "src_slf_attn_post_softmax_shape", )
@@ -112,6 +112,7 @@ decoder_input_data_names = (
     "trg_word",
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "trg_data_shape",
     "trg_slf_attn_pre_softmax_shape",
     "trg_slf_attn_post_softmax_shape",
     "trg_src_attn_pre_softmax_shape",
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 14d476105df63ee6b054ca1b5405f2f7f789bc9b..5634f160a99f83dd1b35e83251c303604f7e78a1 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -24,6 +24,7 @@ def translate_batch(exe,
                     n_best,
                     batch_size,
                     n_head,
+                    d_model,
                     src_pad_idx,
                     trg_pad_idx,
                     bos_idx,
@@ -43,6 +44,11 @@
         return_pos=True,
         return_attn_bias=True,
         return_max_len=False)
+    # Append the data shape input to reshape the output of embedding layer.
+    enc_in_data = enc_in_data + [
+        np.array(
+            [-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
+    ]
     # Append the shape inputs to reshape before and after softmax in encoder
     # self attention.
     enc_in_data = enc_in_data + [
@@ -59,9 +65,14 @@
     scores = np.zeros((batch_size, beam_size), dtype="float32")
     prev_branchs = [[] for i in range(batch_size)]
     next_ids = [[] for i in range(batch_size)]
-    # Use beam_map to map the instance idx in batch to beam idx, since the
+    # Use beam_inst_map to map beam idx to the instance idx in batch, since the
     # size of feeded batch is changing.
-    beam_map = range(batch_size)
+    beam_inst_map = {
+        beam_idx: inst_idx
+        for inst_idx, beam_idx in enumerate(range(batch_size))
+    }
+    # Use active_beams to record the alive beams.
+    active_beams = range(batch_size)
 
     def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
         """
@@ -98,8 +109,14 @@
             [-1e9]).astype("float32")
         # This is used to remove attention on the paddings of source sequences.
         trg_src_attn_bias = np.tile(
-            src_slf_attn_bias[:, :, ::src_max_length, :],
-            [beam_size, 1, trg_max_len, 1])
+            src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
+            [1, beam_size, 1, trg_max_len, 1]).reshape([
+                -1, src_slf_attn_bias.shape[1], trg_max_len,
+                src_slf_attn_bias.shape[-1]
+            ])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [batch_size * beam_size, trg_max_len, d_model], dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
@@ -112,22 +129,24 @@
             [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
         trg_src_attn_post_softmax_shape = np.array(
             trg_src_attn_bias.shape, dtype="int32")
-        enc_output = np.tile(enc_output, [beam_size, 1, 1])
+        enc_output = np.tile(
+            enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
+                [-1, enc_output.shape[-2], enc_output.shape[-1]])
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output
 
-    def update_dec_in_data(dec_in_data, next_ids, active_beams):
+    def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
         """
         Update the input data of decoder mainly by slicing from the previous
         input data and dropping the finished instance beams.
         """
         trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output = dec_in_data
-        trg_cur_len = len(next_ids[0]) + 1  # include the <bos>
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output = dec_in_data
+        trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
         trg_words = np.array(
             [
                 beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
@@ -138,6 +157,7 @@ def translate_batch(exe,
         trg_pos = np.array(
             [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
             dtype="int64").reshape([-1, 1])
+        active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
         active_beams_indice = (
             (np.array(active_beams) * beam_size)[:, np.newaxis] +
             np.array(range(beam_size))[np.newaxis, :]).flatten()
@@ -152,6 +172,10 @@
         trg_src_attn_bias = np.tile(trg_src_attn_bias[
             active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
                                     [1, 1, trg_cur_len, 1])
+        # Append the shape input to reshape the output of embedding layer.
+        trg_data_shape = np.array(
+            [len(active_beams) * beam_size, trg_cur_len, d_model],
+            dtype="int32")
         # Append the shape inputs to reshape before and after softmax in
         # decoder self attention.
         trg_slf_attn_pre_softmax_shape = np.array(
@@ -166,9 +190,9 @@
             trg_src_attn_bias.shape, dtype="int32")
         enc_output = enc_output[active_beams_indice, :, :]
         return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \
-            enc_output
+            trg_data_shape, trg_slf_attn_pre_softmax_shape, \
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
+            trg_src_attn_post_softmax_shape, enc_output
 
     dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                    enc_output)
@@ -177,15 +201,18 @@
                               feed=dict(zip(dec_in_names, dec_in_data)),
                               fetch_list=dec_out_names)[0]
         predict_all = np.log(
-            predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:,
-                                                                        -1, :])
-        predict_all = (predict_all + scores[beam_map].reshape(
-            [len(beam_map) * beam_size, -1])).reshape(
-                [len(beam_map), beam_size, -1])
+            predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
+            [:, -1, :])
+        predict_all = (predict_all + scores[active_beams].reshape(
+            [len(beam_inst_map) * beam_size, -1])).reshape(
+                [len(beam_inst_map), beam_size, -1])
         if not output_unk:  # To exclude the <unk> token.
             predict_all[:, :, unk_idx] = -1e9
         active_beams = []
-        for inst_idx, beam_idx in enumerate(beam_map):
+        for beam_idx in range(batch_size):
+            if not beam_inst_map.has_key(beam_idx):
+                continue
+            inst_idx = beam_inst_map[beam_idx]
             predict = (predict_all[inst_idx, :, :] if i != 0 else
                        predict_all[inst_idx, 0, :]).flatten()
             top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
@@ -198,10 +225,14 @@
             next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
             if next_ids[beam_idx][-1][0] != eos_idx:
                 active_beams.append(beam_idx)
-        beam_map = active_beams
-        if len(beam_map) == 0:
+        if len(active_beams) == 0:
             break
-        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
+        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
+                                         beam_inst_map)
+        beam_inst_map = {
+            beam_idx: inst_idx
+            for inst_idx, beam_idx in enumerate(active_beams)
+        }
 
     # Decode beams and select n_best sequences for each instance by backtrace.
     seqs = [
@@ -215,10 +246,8 @@
 def main():
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
-    # The current program desc is coupled with batch_size and the only
-    # supported batch size is 1 currently.
+
     encoder_program = fluid.Program()
-    model.batch_size = InferTaskConfig.batch_size
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
             ModelHyperParams.src_vocab_size + 1,
@@ -228,7 +257,6 @@
             ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
             ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
 
-    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
         predict = decoder(
@@ -273,6 +301,9 @@
 
     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+    # Append the <pad> token since the dict provided by dataset.wmt16 does
+    # not include it.
+    trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"
 
     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
@@ -306,6 +337,7 @@
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
+            ModelHyperParams.d_model,
             ModelHyperParams.src_pad_idx,
             ModelHyperParams.trg_pad_idx,
             ModelHyperParams.bos_idx,
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 73420a234edf74fbe913bb719719ee9f3ade39b0..eabd17a703747951fcf0a027f47688aaaa204703 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -7,9 +7,6 @@ import paddle.fluid.layers as layers
 from config import TrainTaskConfig, pos_enc_param_names, \
     encoder_input_data_names, decoder_input_data_names, label_data_names
 
-# FIXME(guosheng): Remove out the batch_size from the model.
-batch_size = TrainTaskConfig.batch_size
-
 
 def position_encoding_init(n_position, d_pos_vec):
     """
@@ -85,9 +82,10 @@ def multi_head_attention(queries,
             return x
 
         hidden_size = x.shape[-1]
-        # FIXME(guosheng): Decouple the program desc with batch_size.
+        # The value 0 in shape attr means copying the corresponding dimension
+        # size of the input as the output dimension size.
         reshaped = layers.reshape(
-            x=x, shape=[batch_size, -1, n_head, hidden_size // n_head])
+            x=x, shape=[0, -1, n_head, hidden_size // n_head])
 
         # permuate the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
@@ -103,11 +101,11 @@
             raise ValueError("Input(x) should be a 4-D Tensor.")
 
         trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
-        # FIXME(guosheng): Decouple the program desc with batch_size.
+        # The value 0 in shape attr means copying the corresponding dimension
+        # size of the input as the output dimension size.
         return layers.reshape(
             x=trans_x,
-            shape=map(int,
-                      [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]]))
+            shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
 
     def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
@@ -205,6 +203,7 @@ def prepare_encoder(src_word,
                     src_max_len,
                     dropout_rate=0.,
                     pos_pad_idx=0,
+                    src_data_shape=None,
                     pos_enc_param_name=None):
     """Add word embeddings and position encodings.
     The output tensor has a shape of:
@@ -224,9 +223,10 @@ def prepare_encoder(src_word,
         param_attr=fluid.ParamAttr(
             name=pos_enc_param_name, trainable=False))
     enc_input = src_word_emb + src_pos_enc
-
-    # FIXME(guosheng): Decouple the program desc with batch_size.
-    enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
+    enc_input = layers.reshape(
+        x=enc_input,
+        shape=[-1, src_max_len, src_emb_dim],
+        actual_shape=src_data_shape)
     return layers.dropout(
         enc_input, dropout_prob=dropout_rate,
         is_test=False) if dropout_rate else enc_input
@@ -401,20 +401,23 @@ def decoder(dec_input,
 def make_inputs(input_data_names,
                 n_head,
                 d_model,
-                batch_size,
                 max_length,
                 is_pos,
                 slf_attn_bias_flag,
                 src_attn_bias_flag,
                 enc_output_flag=False,
+                data_shape_flag=True,
                 slf_attn_shape_flag=True,
                 src_attn_shape_flag=True):
     """
     Define the input data layers for the transformer model.
     """
     input_layers = []
-    # The shapes here act as placeholder.
-    # The shapes set here is to pass the infer-shape in compile time.
+    batch_size = 1  # Only for the infer-shape in compile time.
+    # The shapes here act as placeholders and are set to pass the infer-shape in
+    # compile time.
+    # The actual data shape of word is:
+    # [batch_size * max_len_in_batch, 1]
     word = layers.data(
         name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
@@ -422,6 +425,8 @@
         append_batch_size=False)
     input_layers += [word]
     # This is used for position data or label weight.
+    # The actual data shape of pos is:
+    # [batch_size * max_len_in_batch, 1]
     pos = layers.data(
         name=input_data_names[len(input_layers)],
         shape=[batch_size * max_length, 1],
@@ -432,6 +437,8 @@
         # This input is used to remove attention weights on paddings for the
         # encoder and to remove attention weights on subsequent words for the
         # decoder.
+        # The actual data shape of slf_attn_bias is:
+        # [batch_size, n_head, max_len_in_batch, max_len_in_batch]
         slf_attn_bias = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
@@ -439,40 +446,56 @@
             append_batch_size=False)
         input_layers += [slf_attn_bias]
     if src_attn_bias_flag:
-        # This input is used to remove attention weights on paddings.
+        # This input is used to remove attention weights on paddings. It's used
+        # in encoder-decoder attention.
+        # The actual data shape of src_attn_bias is:
+        # [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
         src_attn_bias = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, n_head, max_length, max_length],
             dtype="float32",
             append_batch_size=False)
         input_layers += [src_attn_bias]
+    if data_shape_flag:
+        # This input is used to reshape the output of embedding layer.
+        data_shape = layers.data(
+            name=input_data_names[len(input_layers)],
+            shape=[3],
+            dtype="int32",
+            append_batch_size=False)
+        input_layers += [data_shape]
     if slf_attn_shape_flag:
+        # This shape input is used to reshape before softmax in self attention.
        slf_attn_pre_softmax_shape = layers.data(
            name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[2],
             dtype="int32",
             append_batch_size=False)
         input_layers += [slf_attn_pre_softmax_shape]
+        # This shape input is used to reshape after softmax in self attention.
         slf_attn_post_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[4],
             dtype="int32",
             append_batch_size=False)
         input_layers += [slf_attn_post_softmax_shape]
     if src_attn_shape_flag:
         src_attn_pre_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[2],
             dtype="int32",
             append_batch_size=False)
         input_layers += [src_attn_pre_softmax_shape]
         src_attn_post_softmax_shape = layers.data(
             name=input_data_names[len(input_layers)],
-            shape=[3],
+            shape=[4],
             dtype="int32",
             append_batch_size=False)
         input_layers += [src_attn_post_softmax_shape]
     if enc_output_flag:
+        # This input is used in independent decoder program for inference.
+        # The actual data shape of enc_output is:
+        # [batch_size, max_len_in_batch, d_model]
         enc_output = layers.data(
             name=input_data_names[len(input_layers)],
             shape=[batch_size, max_length, d_model],
@@ -497,16 +520,16 @@
         src_pad_idx,
         trg_pad_idx,
         pos_pad_idx, ):
-    enc_input_layers = make_inputs(
+    enc_inputs = make_inputs(
         encoder_input_data_names,
         n_head,
         d_model,
-        batch_size,
         max_length,
         is_pos=True,
         slf_attn_bias_flag=True,
         src_attn_bias_flag=False,
         enc_output_flag=False,
+        data_shape_flag=True,
         slf_attn_shape_flag=True,
         src_attn_shape_flag=False)
 
@@ -522,18 +545,18 @@
         dropout_rate,
         src_pad_idx,
         pos_pad_idx,
-        enc_input_layers, )
+        enc_inputs, )
 
-    dec_input_layers = make_inputs(
+    dec_inputs = make_inputs(
         decoder_input_data_names,
         n_head,
         d_model,
-        batch_size,
         max_length,
         is_pos=True,
         slf_attn_bias_flag=True,
         src_attn_bias_flag=True,
         enc_output_flag=False,
+        data_shape_flag=True,
         slf_attn_shape_flag=True,
         src_attn_shape_flag=True)
 
@@ -549,7 +572,7 @@
         dropout_rate,
         trg_pad_idx,
         pos_pad_idx,
-        dec_input_layers,
+        dec_inputs,
         enc_output, )
 
     # Padding index do not contribute to the total loss. The weights is used to
@@ -558,12 +581,12 @@
         label_data_names,
         n_head,
         d_model,
-        batch_size,
         max_length,
         is_pos=False,
         slf_attn_bias_flag=False,
         src_attn_bias_flag=False,
         enc_output_flag=False,
+        data_shape_flag=False,
         slf_attn_shape_flag=False,
         src_attn_shape_flag=False)
     cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
@@ -571,7 +594,7 @@
     sum_cost = layers.reduce_sum(weighted_cost)
     token_num = layers.reduce_sum(weights)
     avg_cost = sum_cost / token_num
-    return sum_cost, avg_cost, predict
+    return sum_cost, avg_cost, predict, token_num
 
 
 def wrap_encoder(src_vocab_size,
@@ -585,28 +608,30 @@ def wrap_encoder(src_vocab_size,
                  dropout_rate,
                  src_pad_idx,
                  pos_pad_idx,
-                 enc_input_layers=None):
+                 enc_inputs=None):
     """
     The wrapper assembles together all needed layers for the encoder.
     """
-    if enc_input_layers is None:
+    if enc_inputs is None:
         # This is used to implement independent encoder program in inference.
-        src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
-            slf_attn_post_softmax_shape = make_inputs(
+        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
+            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
+            make_inputs(
                 encoder_input_data_names,
                 n_head,
                 d_model,
-                batch_size,
                 max_length,
                 is_pos=True,
                 slf_attn_bias_flag=True,
                 src_attn_bias_flag=False,
                 enc_output_flag=False,
+                data_shape_flag=True,
                 slf_attn_shape_flag=True,
                 src_attn_shape_flag=False)
     else:
-        src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \
-            slf_attn_post_softmax_shape = enc_input_layers
+        src_word, src_pos, src_slf_attn_bias, src_data_shape, \
+            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
+            enc_inputs
     enc_input = prepare_encoder(
         src_word,
         src_pos,
@@ -614,7 +639,9 @@ def wrap_encoder(src_vocab_size,
         d_model,
         src_pad_idx,
         max_length,
-        dropout_rate, )
+        dropout_rate,
+        pos_pad_idx,
+        src_data_shape, )
     enc_output = encoder(
         enc_input,
         src_slf_attn_bias,
@@ -641,33 +668,33 @@ def wrap_decoder(trg_vocab_size,
                  dropout_rate,
                  trg_pad_idx,
                  pos_pad_idx,
-                 dec_input_layers=None,
+                 dec_inputs=None,
                  enc_output=None):
     """
     The wrapper assembles together all needed layers for the decoder.
     """
-    if dec_input_layers is None:
+    if dec_inputs is None:
         # This is used to implement independent decoder program in inference.
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
-            src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \
-            enc_output = make_inputs(
+            trg_data_shape, slf_attn_pre_softmax_shape, \
+            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
+            src_attn_post_softmax_shape, enc_output = make_inputs(
                 decoder_input_data_names,
                 n_head,
                 d_model,
-                batch_size,
                 max_length,
                 is_pos=True,
                 slf_attn_bias_flag=True,
                 src_attn_bias_flag=True,
                 enc_output_flag=True,
+                data_shape_flag=True,
                 slf_attn_shape_flag=True,
                 src_attn_shape_flag=True)
     else:
         trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
-            slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \
-            src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \
-            dec_input_layers
+            trg_data_shape, slf_attn_pre_softmax_shape, \
+            slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
+            src_attn_post_softmax_shape = dec_inputs
 
     dec_input = prepare_decoder(
         trg_word,
@@ -676,7 +703,9 @@ def wrap_decoder(trg_vocab_size,
         d_model,
         trg_pad_idx,
         max_length,
-        dropout_rate, )
+        dropout_rate,
+        pos_pad_idx,
+        trg_data_shape, )
     dec_output = decoder(
         dec_input,
         enc_output,
@@ -700,5 +729,5 @@ def wrap_decoder(trg_vocab_size,
             bias_attr=False,
             num_flatten_dims=2),
         shape=[-1, trg_vocab_size],
-        act="softmax" if dec_input_layers is None else None)
+        act="softmax" if dec_inputs is None else None)
     return predict
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index e267f209edf03962c41b67b456a3308d2a657145..0887a954342039518f9c973852f6695c9ea5c2d4 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -57,7 +57,7 @@ def pad_batch_data(insts,
 
 
 def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
-                        max_length, n_head):
+                        n_head, d_model):
     """
     Put all padded data needed by training into a dict.
     """
@@ -67,6 +67,10 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
         [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, trg_max_len, 1]).astype("float32")
+
+    # These shape tensors are used in reshape_op.
+    src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
+    trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
     src_slf_attn_pre_softmax_shape = np.array(
         [-1, src_slf_attn_bias.shape[-1]], dtype="int32")
     src_slf_attn_post_softmax_shape = np.array(
@@ -79,17 +83,19 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
         [-1, trg_src_attn_bias.shape[-1]], dtype="int32")
     trg_src_attn_post_softmax_shape = np.array(
         trg_src_attn_bias.shape, dtype="int32")
+
     lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
                               False, False, False, False)
     lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
+
     input_dict = dict(
         zip(input_data_names, [
-            src_word, src_pos, src_slf_attn_bias,
+            src_word, src_pos, src_slf_attn_bias, src_data_shape,
             src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
             trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
-            trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape,
-            trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape,
-            lbl_word, lbl_weight
+            trg_data_shape, trg_slf_attn_pre_softmax_shape,
+            trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape,
+            trg_src_attn_post_softmax_shape, lbl_word, lbl_weight
         ]))
     return input_dict
 
@@ -98,7 +104,7 @@ def main():
     place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
 
-    sum_cost, avg_cost, predict = transformer(
+    sum_cost, avg_cost, predict, token_num = transformer(
         ModelHyperParams.src_vocab_size + 1,
         ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
         ModelHyperParams.n_layer, ModelHyperParams.n_head,
@@ -134,21 +140,24 @@ def main():
         batch_size=TrainTaskConfig.batch_size)
 
     def test(exe):
-        test_sum_costs = []
-        test_avg_costs = []
+        test_total_cost = 0
+        test_total_token = 0
         for batch_id, data in enumerate(val_data()):
-            if len(data) != TrainTaskConfig.batch_size:
-                continue
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1]
                 + label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head)
-            test_sum_cost, test_avg_cost = exe.run(
-                test_program, feed=data_input, fetch_list=[sum_cost, avg_cost])
-            test_sum_costs.append(test_sum_cost)
-            test_avg_costs.append(test_avg_cost)
-        return np.mean(test_sum_costs), np.mean(test_avg_costs)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
+            test_sum_cost, test_token_num = exe.run(
+                test_program,
+                feed=data_input,
+                fetch_list=[sum_cost, token_num],
+                use_program_cache=True)
+            test_total_cost += test_sum_cost
+            test_total_token += test_token_num
+        test_avg_cost = test_total_cost / test_total_token
+        test_ppl = np.exp([min(test_avg_cost, 100)])
+        return test_avg_cost, test_ppl
 
     # Initialize the parameters.
     exe.run(fluid.framework.default_startup_program())
@@ -162,15 +171,11 @@
     for pass_id in xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
-            # The current program desc is coupled with batch_size, thus all
-            # mini-batches must have the same number of instances currently.
-            if len(data) != TrainTaskConfig.batch_size:
-                continue
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1]
                 + label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head)
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
             lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
@@ -181,13 +186,11 @@
                  (pass_id, batch_id, sum_cost_val, avg_cost_val,
                   np.exp([min(avg_cost_val[0], 100)])))
         # Validate and save the model for inference.
-        val_sum_cost, val_avg_cost = test(exe)
+        val_avg_cost, val_ppl = test(exe)
         pass_end_time = time.time()
        time_consumed = pass_end_time - pass_start_time
-        print("epoch: %d, val sum loss: %f, val avg loss: %f, val ppl: %f, "
-              "consumed %fs" %
-              (pass_id, val_sum_cost, val_avg_cost,
-               np.exp([min(val_avg_cost, 100)]), time_consumed))
+        print("epoch: %d, val avg loss: %f, val ppl: %f, "
+              "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
         fluid.io.save_inference_model(
             os.path.join(TrainTaskConfig.model_dir,
                          "pass_" + str(pass_id) + ".infer.model"),