diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 091ea175291c56d63e1d8b42a874516d9733f1cf..a1e2fd903b1f0042be0f440c2f6f34da7af22909 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -15,6 +15,23 @@ class TrainTaskConfig(object):
     # the params for learning rate scheduling
     warmup_steps = 4000
 
+    # the directory for saving inference models
+    model_dir = "transformer_model"
+
+
+class InferTaskConfig(object):
+    use_gpu = False
+    # number of sequences contained in a mini-batch
+    batch_size = 1
+
+    # the params for beam search
+    beam_size = 5
+    max_length = 30
+    n_best = 1
+
+    # the directory for loading the inference model
+    model_path = "transformer_model/pass_1.infer.model"
+
 
 class ModelHyperParams(object):
     # Dictionary size for source and target language. This model directly uses
@@ -33,6 +50,11 @@ class ModelHyperParams(object):
     # index for <pad> token in target language.
     trg_pad_idx = trg_vocab_size
 
+    # index for <bos> token
+    bos_idx = 0
+    # index for <eos> token
+    eos_idx = 1
+
     # position value corresponding to the <pad> token.
     pos_pad_idx = 0
 
@@ -64,14 +86,21 @@ pos_enc_param_names = (
     "src_pos_enc_table",
     "trg_pos_enc_table", )
 
-# Names of all data layers listed in order.
-input_data_names = (
+# Names of all data layers in encoder listed in order.
+encoder_input_data_names = (
     "src_word",
     "src_pos",
+    "src_slf_attn_bias", )
+
+# Names of all data layers in decoder listed in order.
+decoder_input_data_names = (
     "trg_word",
     "trg_pos",
-    "src_slf_attn_bias",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "enc_output", )
+
+# Names of label related data layers listed in order.
+label_data_names = (
     "lbl_word",
     "lbl_weight", )
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ffa4e77171c91c6e06ec6a71725135acb4f4dc2
--- /dev/null
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -0,0 +1,220 @@
+import numpy as np
+
+import paddle.v2 as paddle
+import paddle.fluid as fluid
+
+import model
+from model import wrap_encoder as encoder
+from model import wrap_decoder as decoder
+from config import InferTaskConfig, ModelHyperParams, \
+    encoder_input_data_names, decoder_input_data_names
+from train import pad_batch_data
+
+
+def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
+                    decoder, dec_in_names, dec_out_names, beam_size, max_length,
+                    n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
+                    bos_idx, eos_idx):
+    """
+    Run the encoder program once and run the decoder program multiple times to
+    implement beam search externally.
+    """
+    # Prepare data for the encoder and run the encoder.
+    enc_in_data = pad_batch_data(
+        src_words,
+        src_pad_idx,
+        n_head,
+        is_target=False,
+        return_pos=True,
+        return_attn_bias=True,
+        return_max_len=True)
+    enc_output = exe.run(encoder,
+                         feed=dict(zip(enc_in_names, enc_in_data)),
+                         fetch_list=enc_out_names)[0]
+
+    # Beam Search.
+    # To store the beam info.
+    scores = np.zeros((batch_size, beam_size), dtype="float32")
+    prev_branchs = [[] for _ in range(batch_size)]
+    next_ids = [[] for _ in range(batch_size)]
+    # Use beam_map to map the instance idx in the batch to the beam idx, since
+    # the size of the fed batch keeps changing.
+    beam_map = range(batch_size)
+
+    def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
+        """
+        Decode and select n_best sequences for one instance by backtrace.
+        """
+        seqs = []
+        for i in range(n_best):
+            k = i
+            seq = []
+            for j in range(len(prev_branchs) - 1, -1, -1):
+                seq.append(next_ids[j][k])
+                k = prev_branchs[j][k]
+            seq = seq[::-1]
+            seq = [bos_idx] + seq if add_bos else seq
+            seqs.append(seq)
+        return seqs
+
+    def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
+        """
+        Initialize the input data for decoder.
+        """
+        trg_words = np.array(
+            [[bos_idx]] * batch_size * beam_size, dtype="int64")
+        trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
+        src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
+            -1], enc_in_data[-2], 1
+        trg_src_attn_bias = np.tile(
+            src_slf_attn_bias[:, :, ::src_max_length, :],
+            [beam_size, 1, trg_max_len, 1])
+        enc_output = np.tile(enc_output, [beam_size, 1, 1])
+        # No need for trg_slf_attn_bias because of no paddings.
+        return trg_words, trg_pos, None, trg_src_attn_bias, enc_output
+
+    def update_dec_in_data(dec_in_data, next_ids, active_beams):
+        """
+        Update the input data of decoder mainly by slicing from the previous
+        input data and dropping the finished instance beams.
+        """
+        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
+        trg_words = np.array(
+            [
+                beam_backtrace(
+                    prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
+                for beam_idx in active_beams
+            ],
+            dtype="int64")
+        trg_words = trg_words.reshape([-1, 1])
+        trg_pos = np.array(
+            [range(1, len(next_ids[0]) + 2)] * len(active_beams) * beam_size,
+            dtype="int64").reshape([-1, 1])
+        active_beams_indice = (
+            (np.array(active_beams) * beam_size)[:, np.newaxis] +
+            np.array(range(beam_size))[np.newaxis, :]).flatten()
+        trg_src_attn_bias = np.tile(trg_src_attn_bias[
+            active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
+            [1, 1, len(next_ids[0]) + 1, 1])
+        enc_output = enc_output[active_beams_indice, :, :]
+        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output
+
+    dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
+                                   enc_output)
+    for i in range(max_length):
+        predict_all = exe.run(decoder,
+                              feed=dict(
+                                  filter(lambda item: item[1] is not None,
+                                         zip(dec_in_names, dec_in_data))),
+                              fetch_list=dec_out_names)[0]
+        predict_all = np.log(predict_all)
+        predict_all = (
+            predict_all.reshape(
+                [len(beam_map) * beam_size, i + 1, -1])[:, -1, :] +
+            scores[beam_map].reshape([len(beam_map) * beam_size, -1])).reshape(
+                [len(beam_map), beam_size, -1])
+        active_beams = []
+        for inst_idx, beam_idx in enumerate(beam_map):
+            predict = (predict_all[inst_idx, :, :]
+                       if i != 0 else predict_all[inst_idx, 0, :]).flatten()
+            top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
+            top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
+                                                                            -1]]
+            top_scores = predict[top_scores_ids]
+            scores[beam_idx] = top_scores
+            prev_branchs[beam_idx].append(top_scores_ids /
+                                          predict_all.shape[-1])
+            next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
+            if next_ids[beam_idx][-1][0] != eos_idx:
+                active_beams.append(beam_idx)
+        beam_map = active_beams
+        if len(beam_map) == 0:
+            break
+        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)
+
+    # Decode beams and select n_best sequences for each instance by backtrace.
+    seqs = [
+        beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
+        for beam_idx in range(batch_size)
+    ]
+
+    return seqs, scores[:, :n_best].tolist()
+
+
+def main():
+    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    # The current program desc is coupled with batch_size and the only
+    # supported batch size is 1 currently.
+    encoder_program = fluid.Program()
+    model.batch_size = InferTaskConfig.batch_size
+    with fluid.program_guard(main_program=encoder_program):
+        enc_output = encoder(
+            ModelHyperParams.src_vocab_size + 1,
+            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+            ModelHyperParams.n_head, ModelHyperParams.d_key,
+            ModelHyperParams.d_value, ModelHyperParams.d_model,
+            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+            ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
+
+    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
+    decoder_program = fluid.Program()
+    with fluid.program_guard(main_program=decoder_program):
+        predict = decoder(
+            ModelHyperParams.trg_vocab_size + 1,
+            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+            ModelHyperParams.n_head, ModelHyperParams.d_key,
+            ModelHyperParams.d_value, ModelHyperParams.d_model,
+            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+            ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
+
+    # Load the model parameters of encoder and decoder separately from the
+    # saved transformer model.
+    encoder_var_names = []
+    for op in encoder_program.block(0).ops:
+        encoder_var_names += op.input_arg_names
+    encoder_param_names = filter(
+        lambda var_name: isinstance(encoder_program.block(0).var(var_name),
+                                    fluid.framework.Parameter),
+        encoder_var_names)
+    encoder_params = map(encoder_program.block(0).var, encoder_param_names)
+    decoder_var_names = []
+    for op in decoder_program.block(0).ops:
+        decoder_var_names += op.input_arg_names
+    decoder_param_names = filter(
+        lambda var_name: isinstance(decoder_program.block(0).var(var_name),
+                                    fluid.framework.Parameter),
+        decoder_var_names)
+    decoder_params = map(decoder_program.block(0).var, decoder_param_names)
+    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
+    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
+
+    # This is used here to set dropout to the test mode.
+    encoder_program = fluid.io.get_inference_program(
+        target_vars=[enc_output], main_program=encoder_program)
+    decoder_program = fluid.io.get_inference_program(
+        target_vars=[predict], main_program=decoder_program)
+
+    test_data = paddle.batch(
+        paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
+                                  ModelHyperParams.trg_vocab_size),
+        batch_size=InferTaskConfig.batch_size)
+
+    trg_idx2word = paddle.dataset.wmt16.get_dict(
+        "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+
+    for batch_id, data in enumerate(test_data()):
+        batch_seqs, batch_scores = translate_batch(
+            exe, [item[0] for item in data], encoder_program,
+            encoder_input_data_names, [enc_output.name], decoder_program,
+            decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
+            InferTaskConfig.max_length, InferTaskConfig.n_best,
+            InferTaskConfig.batch_size, ModelHyperParams.n_head,
+            ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
+            ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
+        for i in range(len(batch_seqs)):
+            seqs = batch_seqs[i]
+            scores = batch_scores[i]
+            for seq in seqs:
+                print(" ".join([trg_idx2word[idx] for idx in seq]))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 379a17221c3aaa4daf7f530f9553bcef89b42de6..b814419ea02ea5360079c9808569db601bfb30e8 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -4,7 +4,8 @@ import numpy as np
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 
-from config import TrainTaskConfig, input_data_names, pos_enc_param_names
+from config import TrainTaskConfig, pos_enc_param_names, \
+    encoder_input_data_names, decoder_input_data_names, label_data_names
 
 # FIXME(guosheng): Remove out the batch_size from the model.
 batch_size = TrainTaskConfig.batch_size
@@ -127,7 +128,9 @@ def multi_head_attention(queries,
         scaled_q = layers.scale(x=q, scale=d_model**-0.5)
         product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
-        weights = __softmax(layers.elementwise_add(x=product, y=attn_bias))
+        weights = __softmax(
+            layers.elementwise_add(
+                x=product, y=attn_bias) if attn_bias else product)
         if dropout_rate:
             weights = layers.dropout(
                 weights, dropout_prob=dropout_rate, is_test=False)
@@ -373,6 +376,53 @@ def decoder(dec_input,
     return dec_output
 
 
+def make_inputs(input_data_names,
+                n_head,
+                d_model,
+                batch_size,
+                max_length,
+                slf_attn_bias_flag,
+                src_attn_bias_flag,
+                pos_flag=1):
+    """
+    Define the input data layers for the transformer model.
+    """
+    input_layers = []
+    # The shapes here act as placeholders and are only set to pass the
+    # infer-shape in compile time.
+    word = layers.data(
+        name=input_data_names[0],
+        shape=[batch_size * max_length, 1],
+        dtype="int64",
+        append_batch_size=False)
+    input_layers += [word]
+    # This is used for position data or label weight.
+    pos = layers.data(
+        name=input_data_names[1],
+        shape=[batch_size * max_length, 1],
+        dtype="int64" if pos_flag else "float32",
+        append_batch_size=False)
+    input_layers += [pos]
+    if slf_attn_bias_flag:
+        # This is used for the attention bias or the encoder output.
+        slf_attn_bias = layers.data(
+            name=input_data_names[2]
+            if slf_attn_bias_flag == 1 else input_data_names[-1],
+            shape=[batch_size, n_head, max_length, max_length]
+            if slf_attn_bias_flag == 1 else [batch_size, max_length, d_model],
+            dtype="float32",
+            append_batch_size=False)
+        input_layers += [slf_attn_bias]
+    if src_attn_bias_flag:
+        src_attn_bias = layers.data(
+            name=input_data_names[3],
+            shape=[batch_size, n_head, max_length, max_length],
+            dtype="float32",
+            append_batch_size=False)
+        input_layers += [src_attn_bias]
+    return input_layers
+
+
 def transformer(
         src_vocab_size,
         trg_vocab_size,
@@ -387,61 +437,72 @@ def transformer(
         src_pad_idx,
         trg_pad_idx,
         pos_pad_idx, ):
-    # The shapes here act as placeholder.
-    # The shapes set here is to pass the infer-shape in compile time. The actual
-    # shape of src_word in run time is:
-    # [batch_size * max_src_length_in_a_batch, 1].
-    src_word = layers.data(
-        name=input_data_names[0],
-        shape=[batch_size * max_length, 1],
-        dtype="int64",
-        append_batch_size=False)
-    # The actual shape of src_pos in runtime is:
-    # [batch_size * max_src_length_in_a_batch, 1].
-    src_pos = layers.data(
-        name=input_data_names[1],
-        shape=[batch_size * max_length, 1],
-        dtype="int64",
-        append_batch_size=False)
-    # The actual shape of trg_word is in runtime is:
-    # [batch_size * max_trg_length_in_a_batch, 1].
-    trg_word = layers.data(
-        name=input_data_names[2],
-        shape=[batch_size * max_length, 1],
-        dtype="int64",
-        append_batch_size=False)
-    # The actual shape of trg_pos in runtime is:
-    # [batch_size * max_trg_length_in_a_batch, 1].
-    trg_pos = layers.data(
-        name=input_data_names[3],
-        shape=[batch_size * max_length, 1],
-        dtype="int64",
-        append_batch_size=False)
-    # The actual shape of src_slf_attn_bias in runtime is:
-    # [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch].
-    # This input is used to remove attention weights on paddings.
-    src_slf_attn_bias = layers.data(
-        name=input_data_names[4],
-        shape=[batch_size, n_head, max_length, max_length],
-        dtype="float32",
-        append_batch_size=False)
-    # The actual shape of trg_slf_attn_bias in runtime is:
-    # [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch].
-    # This is used to remove attention weights on paddings and subsequent words.
-    trg_slf_attn_bias = layers.data(
-        name=input_data_names[5],
-        shape=[batch_size, n_head, max_length, max_length],
-        dtype="float32",
-        append_batch_size=False)
-    # The actual shape of trg_src_attn_bias in runtime is:
-    # [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch].
-    # This is used to remove attention weights on paddings.
-    trg_src_attn_bias = layers.data(
-        name=input_data_names[6],
-        shape=[batch_size, n_head, max_length, max_length],
-        dtype="float32",
-        append_batch_size=False)
+    enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model,
+                                   batch_size, max_length, 1, 0)
+
+    enc_output = wrap_encoder(
+        src_vocab_size,
+        max_length,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        src_pad_idx,
+        pos_pad_idx,
+        enc_input_layers, )
+
+    dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model,
+                                   batch_size, max_length, 1, 1)
+
+    predict = wrap_decoder(
+        trg_vocab_size,
+        max_length,
+        n_layer,
+        n_head,
+        d_key,
+        d_value,
+        d_model,
+        d_inner_hid,
+        dropout_rate,
+        trg_pad_idx,
+        pos_pad_idx,
+        dec_input_layers,
+        enc_output, )
+
+    # Padding indices do not contribute to the total loss. The weights are used
+    # to mask out padding indices when calculating the loss.
+    gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size,
+                                max_length, 0, 0, 0)
+    cost = layers.cross_entropy(input=predict, label=gold)
+    weighted_cost = cost * weights
+    return layers.reduce_sum(weighted_cost), predict
+
+
+def wrap_encoder(src_vocab_size,
+                 max_length,
+                 n_layer,
+                 n_head,
+                 d_key,
+                 d_value,
+                 d_model,
+                 d_inner_hid,
+                 dropout_rate,
+                 src_pad_idx,
+                 pos_pad_idx,
+                 enc_input_layers=None):
+    """
+    The wrapper assembles together all needed layers for the encoder.
+    """
+    if enc_input_layers is None:
+        # This is used to implement an independent encoder program in inference.
+        src_word, src_pos, src_slf_attn_bias = make_inputs(
+            encoder_input_data_names, n_head, d_model, batch_size, max_length,
+            True, False)
+    else:
+        src_word, src_pos, src_slf_attn_bias = enc_input_layers
     enc_input = prepare_encoder(
         src_word,
         src_pos,
@@ -460,6 +521,34 @@ def transformer(
         d_model,
         d_inner_hid,
         dropout_rate, )
+    return enc_output
+
+
+def wrap_decoder(trg_vocab_size,
+                 max_length,
+                 n_layer,
+                 n_head,
+                 d_key,
+                 d_value,
+                 d_model,
+                 d_inner_hid,
+                 dropout_rate,
+                 trg_pad_idx,
+                 pos_pad_idx,
+                 dec_input_layers=None,
+                 enc_output=None):
+    """
+    The wrapper assembles together all needed layers for the decoder.
+    """
+    if dec_input_layers is None:
+        # This is used to implement an independent decoder program in inference.
+        # No need for trg_slf_attn_bias because of no paddings in inference.
+        trg_word, trg_pos, enc_output, trg_src_attn_bias = make_inputs(
+            decoder_input_data_names, n_head, d_model, batch_size, max_length,
+            2, 1)
+        trg_slf_attn_bias = None
+    else:
+        trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers
 
     dec_input = prepare_decoder(
         trg_word,
@@ -482,32 +571,11 @@ def transformer(
         d_inner_hid,
         dropout_rate, )
 
-    # TODO(guosheng): Share the weight matrix between the embedding layers and
-    # the pre-softmax linear transformation.
     predict = layers.reshape(
         x=layers.fc(input=dec_output,
                     size=trg_vocab_size,
-                    param_attr=fluid.initializer.Xavier(uniform=False),
                     bias_attr=False,
                     num_flatten_dims=2),
         shape=[-1, trg_vocab_size],
         act="softmax")
-    # The actual shape of gold in runtime is:
-    # [batch_size * max_trg_length_in_a_batch, 1].
-    gold = layers.data(
-        name=input_data_names[7],
-        shape=[batch_size * max_length, 1],
-        dtype="int64",
-        append_batch_size=False)
-    cost = layers.cross_entropy(input=predict, label=gold)
-    # The actual shape of weights in runtime is:
-    # [batch_size * max_trg_length_in_a_batch, 1].
-    # Padding index do not contribute to the total loss. This Weight is used to
-    # cancel padding index in calculating the loss.
-    weights = layers.data(
-        name=input_data_names[8],
-        shape=[batch_size * max_length, 1],
-        dtype="float32",
-        append_batch_size=False)
-    weighted_cost = cost * weights
-    return layers.reduce_sum(weighted_cost)
+    return predict
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 19835c486e06e1a3a955bdebfc552fc397ff7d9b..bb1b506bfa163c0f0a02b548c7de9c225afd39a4 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -1,3 +1,4 @@
+import os
 import numpy as np
 
 import paddle.v2 as paddle
@@ -5,86 +6,74 @@ import paddle.fluid as fluid
 
 from model import transformer, position_encoding_init
 from optim import LearningRateScheduler
-from config import TrainTaskConfig, ModelHyperParams, \
-    pos_enc_param_names, input_data_names
+from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
+    encoder_input_data_names, decoder_input_data_names, label_data_names
 
 
-def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
-                        max_length, n_head, place):
+def pad_batch_data(insts,
+                   pad_idx,
+                   n_head,
+                   is_target=False,
+                   return_pos=True,
+                   return_attn_bias=True,
+                   return_max_len=True):
     """
     Pad the instances to the max sequence length in batch, and generate the
-    corresponding position data and attention bias. Then, convert the numpy
-    data to tensors and return a dict mapping names to tensors.
+    corresponding position data and attention bias.
+    """
+    return_list = []
+    max_len = max(len(inst) for inst in insts)
+    inst_data = np.array(
+        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
+    return_list += [inst_data.astype("int64").reshape([-1, 1])]
+    if return_pos:
+        inst_pos = np.array([[
+            pos_i + 1 if w_i != pad_idx else 0 for pos_i, w_i in enumerate(inst)
+        ] for inst in inst_data])
+
+        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
+    if return_attn_bias:
+        if is_target:
+            # This is used to avoid attention on paddings and subsequent
+            # words.
+            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
+            slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
+                [-1, 1, max_len, max_len])
+            slf_attn_bias_data = np.tile(slf_attn_bias_data,
+                                         [1, n_head, 1, 1]) * [-1e9]
+        else:
+            # This is used to avoid attention on paddings.
+            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
+                                           (max_len - len(inst))
+                                           for inst in insts])
+            slf_attn_bias_data = np.tile(
+                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
+                [1, n_head, max_len, 1])
+        return_list += [slf_attn_bias_data.astype("float32")]
+    if return_max_len:
+        return_list += [max_len]
+    return return_list if len(return_list) > 1 else return_list[0]
+
+
+def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
+                        max_length, n_head):
+    """
+    Put all padded data needed by training into a dict.
     """
-    input_dict = {}
-
-    def __pad_batch_data(insts,
-                         pad_idx,
-                         is_target=False,
-                         return_pos=True,
-                         return_attn_bias=True,
-                         return_max_len=True):
-        """
-        Pad the instances to the max sequence length in batch, and generate the
-        corresponding position data and attention bias.
-        """
-        return_list = []
-        max_len = max(len(inst) for inst in insts)
-        inst_data = np.array(
-            [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
-        return_list += [inst_data.astype("int64").reshape([-1, 1])]
-        if return_pos:
-            inst_pos = np.array([[
-                pos_i + 1 if w_i != pad_idx else 0
-                for pos_i, w_i in enumerate(inst)
-            ] for inst in inst_data])
-
-            return_list += [inst_pos.astype("int64").reshape([-1, 1])]
-        if return_attn_bias:
-            if is_target:
-                # This is used to avoid attention on paddings and subsequent
-                # words.
-                slf_attn_bias_data = np.ones((inst_data.shape[0], max_len,
-                                              max_len))
-                slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
-                    [-1, 1, max_len, max_len])
-                slf_attn_bias_data = np.tile(slf_attn_bias_data,
-                                             [1, n_head, 1, 1]) * [-1e9]
-            else:
-                # This is used to avoid attention on paddings.
-                slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
-                                               (max_len - len(inst))
-                                               for inst in insts])
-                slf_attn_bias_data = np.tile(
-                    slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
-                    [1, n_head, max_len, 1])
-            return_list += [slf_attn_bias_data.astype("float32")]
-        if return_max_len:
-            return_list += [max_len]
-        return return_list if len(return_list) > 1 else return_list[0]
-
-    def data_to_tensor(data_list, name_list, input_dict, place):
-        assert len(data_list) == len(name_list)
-        for i in range(len(name_list)):
-            tensor = fluid.LoDTensor()
-            tensor.set(data_list[i], place)
-            input_dict[name_list[i]] = tensor
-
-    src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, is_target=False)
-    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
-        [inst[1] for inst in insts], trg_pad_idx, is_target=True)
+    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
+        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
+        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, trg_max_len, 1]).astype("float32")
-    lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
-                                False, False, False)
+    lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head,
+                              False, False, False, False)
     lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1])
-
-    data_to_tensor([
-        src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
-        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
-    ], input_data_names, input_dict, place)
-
+    input_dict = dict(
+        zip(input_data_names, [
+            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
+            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+        ]))
     return input_dict
 
 
@@ -92,7 +81,7 @@ def main():
     place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
 
-    cost = transformer(
+    cost, predict = transformer(
         ModelHyperParams.src_vocab_size + 1,
         ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
         ModelHyperParams.n_layer, ModelHyperParams.n_head,
@@ -118,6 +107,31 @@ def main():
             buf_size=100000),
         batch_size=TrainTaskConfig.batch_size)
 
+    # Program to do validation.
+    test_program = fluid.default_main_program().clone()
+    with fluid.program_guard(test_program):
+        test_program = fluid.io.get_inference_program([cost])
+    val_data = paddle.batch(
+        paddle.dataset.wmt16.validation(ModelHyperParams.src_vocab_size,
+                                        ModelHyperParams.trg_vocab_size),
+        batch_size=TrainTaskConfig.batch_size)
+
+    def test(exe):
+        test_costs = []
+        for batch_id, data in enumerate(val_data()):
+            if len(data) != TrainTaskConfig.batch_size:
+                continue
+            data_input = prepare_batch_input(
+                data, encoder_input_data_names + decoder_input_data_names[:-1]
+                + label_data_names, ModelHyperParams.src_pad_idx,
+                ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
+                ModelHyperParams.n_head)
+            test_cost = exe.run(test_program,
+                                feed=data_input,
+                                fetch_list=[cost])[0]
+            test_costs.append(test_cost)
+        return np.mean(test_costs)
+
     # Initialize the parameters.
     exe.run(fluid.framework.default_startup_program())
     for pos_enc_param_name in pos_enc_param_names:
@@ -134,9 +148,10 @@
             if len(data) != TrainTaskConfig.batch_size:
                 continue
             data_input = prepare_batch_input(
-                data, input_data_names, ModelHyperParams.src_pad_idx,
+                data, encoder_input_data_names + decoder_input_data_names[:-1]
+                + label_data_names, ModelHyperParams.src_pad_idx,
                 ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
-                ModelHyperParams.n_head, place)
+                ModelHyperParams.n_head)
             lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
@@ -144,6 +159,14 @@
             cost_val = np.array(outs[0])
            print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
                   " cost = " + str(cost_val))
+        # Validate and save the model for inference.
+        val_cost = test(exe)
+        print("pass_id = " + str(pass_id) + " val_cost = " + str(val_cost))
+        fluid.io.save_inference_model(
+            os.path.join(TrainTaskConfig.model_dir,
+                         "pass_" + str(pass_id) + ".infer.model"),
+            encoder_input_data_names + decoder_input_data_names[:-1],
+            [predict], exe)
 
 
 if __name__ == "__main__":