From f3c247d33baef18570c7981e048d46b20d3a44a1 Mon Sep 17 00:00:00 2001 From: guosheng Date: Wed, 28 Mar 2018 03:39:49 +0800 Subject: [PATCH] Decouple the program desc with batch_size in Transformer. --- .../transformer/config.py | 4 +- .../transformer/infer.py | 87 +++++++----- .../transformer/model.py | 131 ++++++++++-------- .../transformer/train.py | 23 ++- 4 files changed, 144 insertions(+), 101 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py index 71e43149..1bf3f8d8 100644 --- a/fluid/neural_machine_translation/transformer/config.py +++ b/fluid/neural_machine_translation/transformer/config.py @@ -92,7 +92,8 @@ pos_enc_param_names = ( encoder_input_data_names = ( "src_word", "src_pos", - "src_slf_attn_bias", ) + "src_slf_attn_bias", + "src_data_shape", ) # Names of all data layers in decoder listed in order. decoder_input_data_names = ( @@ -100,6 +101,7 @@ decoder_input_data_names = ( "trg_pos", "trg_slf_attn_bias", "trg_src_attn_bias", + "trg_data_shape", "enc_output", ) # Names of label related data layers listed in order. diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py index e4dee220..f24b7f6e 100644 --- a/fluid/neural_machine_translation/transformer/infer.py +++ b/fluid/neural_machine_translation/transformer/infer.py @@ -13,8 +13,8 @@ from train import pad_batch_data def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, decoder, dec_in_names, dec_out_names, beam_size, max_length, - n_best, batch_size, n_head, src_pad_idx, trg_pad_idx, - bos_idx, eos_idx): + n_best, batch_size, n_head, d_model, src_pad_idx, + trg_pad_idx, bos_idx, eos_idx): """ Run the encoder program once and run the decoder program multiple times to implement beam search externally. @@ -28,6 +28,10 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, return_pos=True, return_attn_bias=True, return_max_len=True) + enc_in_data = enc_in_data[:-1] + [ + np.array( + [batch_size, enc_in_data[-1], d_model], dtype="int32") + ] # Append the data shape input. enc_output = exe.run(encoder, feed=dict(zip(enc_in_names, enc_in_data)), fetch_list=enc_out_names)[0] @@ -35,11 +39,16 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, # Beam Search. # To store the beam info. scores = np.zeros((batch_size, beam_size), dtype="float32") - prev_branchs = [[]] * batch_size - next_ids = [[]] * batch_size - # Use beam_map to map the instance idx in batch to beam idx, since the + prev_branchs = [[] for i in range(batch_size)] + next_ids = [[] for i in range(batch_size)] + # Use beam_inst_map to map beam idx to the instance idx in batch, since the # size of feeded batch is changing. - beam_map = range(batch_size) + beam_inst_map = { + beam_idx: inst_idx + for inst_idx, beam_idx in enumerate(range(batch_size)) + } + # Use active_beams to recode the alive. 
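The move from `[[]] * batch_size` to list comprehensions above is not cosmetic: list multiplication reuses one inner list object, so appending to one beam's history would show up in every beam. A minimal, self-contained illustration in plain Python (not part of the patch):

    batch_size = 3

    shared = [[]] * batch_size                   # every slot is the SAME list object
    shared[0].append(7)
    print(shared)                                # [[7], [7], [7]] -- all beams polluted

    separate = [[] for i in range(batch_size)]   # independent list per beam
    separate[0].append(7)
    print(separate)                              # [[7], [], []]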
+ active_beams = range(batch_size) def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True): """ @@ -64,8 +73,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, trg_words = np.array( [[bos_idx]] * batch_size * beam_size, dtype="int64") trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64") - src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[ - -1], enc_in_data[-2], 1 + src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[-1][ + 1], enc_in_data[-2], 1 # This is used to remove attention on subsequent words. trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len, trg_max_len)) @@ -77,16 +86,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, trg_src_attn_bias = np.tile( src_slf_attn_bias[:, :, ::src_max_length, :], [beam_size, 1, trg_max_len, 1]) - enc_output = np.tile(enc_output, [beam_size, 1, 1]) - return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output + trg_data_shape = np.array( + [batch_size * beam_size, trg_max_len, d_model], dtype="int32") + enc_output = np.tile( + enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape( + [-1, enc_output.shape[-2], enc_output.shape[-1]]) + return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output - def update_dec_in_data(dec_in_data, next_ids, active_beams): + def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map): """ Update the input data of decoder mainly by slicing from the previous input data and dropping the finished instance beams. """ - trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data - trg_cur_len = len(next_ids[0]) + 1 # include the + trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output = dec_in_data + trg_cur_len = trg_slf_attn_bias.shape[-1] + 1 trg_words = np.array( [ beam_backtrace( @@ -98,6 +111,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, trg_pos = np.array( [range(1, trg_cur_len + 1)] * len(active_beams) * beam_size, dtype="int64").reshape([-1, 1]) + active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams] active_beams_indice = ( (np.array(active_beams) * beam_size)[:, np.newaxis] + np.array(range(beam_size))[np.newaxis, :]).flatten() @@ -112,8 +126,11 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, trg_src_attn_bias = np.tile(trg_src_attn_bias[ active_beams_indice, :, ::trg_src_attn_bias.shape[2], :], [1, 1, trg_cur_len, 1]) + trg_data_shape = np.array( + [len(active_beams) * beam_size, trg_cur_len, d_model], + dtype="int32") enc_output = enc_output[active_beams_indice, :, :] - return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output + return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output) @@ -122,13 +139,16 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, feed=dict(zip(dec_in_names, dec_in_data)), fetch_list=dec_out_names)[0] predict_all = np.log( - predict_all.reshape([len(beam_map) * beam_size, i + 1, -1])[:, - -1, :]) - predict_all = (predict_all + scores[beam_map].reshape( - [len(beam_map) * beam_size, -1])).reshape( - [len(beam_map), beam_size, -1]) + predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1]) + [:, -1, :]) + predict_all = (predict_all + scores[active_beams].reshape( + [len(beam_inst_map) * beam_size, -1])).reshape( + 
[len(beam_inst_map), beam_size, -1]) active_beams = [] - for inst_idx, beam_idx in enumerate(beam_map): + for beam_idx in range(batch_size): + if not beam_inst_map.has_key(beam_idx): + continue + inst_idx = beam_inst_map[beam_idx] predict = (predict_all[inst_idx, :, :] if i != 0 else predict_all[inst_idx, 0, :]).flatten() top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:] @@ -141,13 +161,20 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1]) if next_ids[beam_idx][-1][0] != eos_idx: active_beams.append(beam_idx) - beam_map = active_beams - if len(beam_map) == 0: + if len(active_beams) == 0: break - dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams) + dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams, + beam_inst_map) + beam_inst_map = { + beam_idx: inst_idx + for inst_idx, beam_idx in enumerate(active_beams) + } # Decode beams and select n_best sequences for each instance by backtrace. - seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)] + seqs = [ + beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best) + for beam_idx in range(batch_size) + ] return seqs, scores[:, :n_best].tolist() @@ -155,10 +182,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, def main(): place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) - # The current program desc is coupled with batch_size and the only - # supported batch size is 1 currently. + encoder_program = fluid.Program() - model.batch_size = InferTaskConfig.batch_size with fluid.program_guard(main_program=encoder_program): enc_output = encoder( ModelHyperParams.src_vocab_size + 1, @@ -168,7 +193,6 @@ def main(): ModelHyperParams.d_inner_hid, ModelHyperParams.dropout, ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx) - model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size decoder_program = fluid.Program() with fluid.program_guard(main_program=decoder_program): predict = decoder( @@ -213,16 +237,15 @@ def main(): trg_idx2word = paddle.dataset.wmt16.get_dict( "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True) - for batch_id, data in enumerate(test_data()): batch_seqs, batch_scores = translate_batch( exe, [item[0] for item in data], encoder_program, encoder_input_data_names, [enc_output.name], decoder_program, decoder_input_data_names, [predict.name], InferTaskConfig.beam_size, InferTaskConfig.max_length, InferTaskConfig.n_best, - len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx, - ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx, - ModelHyperParams.eos_idx) + len(data), ModelHyperParams.n_head, ModelHyperParams.d_model, + ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx, + ModelHyperParams.bos_idx, ModelHyperParams.eos_idx) for i in range(len(batch_seqs)): seqs = batch_seqs[i] scores = batch_scores[i] diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py index ba5ba447..ff339c31 100644 --- a/fluid/neural_machine_translation/transformer/model.py +++ b/fluid/neural_machine_translation/transformer/model.py @@ -7,9 +7,6 @@ import paddle.fluid.layers as layers from config import TrainTaskConfig, pos_enc_param_names, \ encoder_input_data_names, decoder_input_data_names, label_data_names -# FIXME(guosheng): Remove out the batch_size from the model. 
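The model changes that follow all rely on one mechanism: reshape targets keep a compile-time placeholder shape, while the real run-time shape arrives as a small int32 tensor (src_data_shape / trg_data_shape) consumed through reshape's actual_shape input. A rough sketch of that mechanism, using only APIs that already appear in this patch; the names x and x_shape and the sizes are made up for illustration:

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    # Placeholder shape only has to satisfy compile-time shape inference.
    x = layers.data(name="x", shape=[8, 4], dtype="float32",
                    append_batch_size=False)
    # The true [batch, seq, width] of the current batch is fed at run time.
    x_shape = layers.data(name="x_shape", shape=[3], dtype="int32",
                          append_batch_size=False)
    y = layers.reshape(x=x, shape=[-1, 8, 4], actual_shape=x_shape)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    for batch_size, seq_len in [(2, 5), (7, 3)]:
        out = exe.run(
            feed={
                "x": np.ones([batch_size * seq_len, 4], dtype="float32"),
                "x_shape": np.array([batch_size, seq_len, 4], dtype="int32"),
            },
            fetch_list=[y])[0]
        print(out.shape)    # (2, 5, 4) on the first pass, (7, 3, 4) on the second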
-batch_size = TrainTaskConfig.batch_size - def position_encoding_init(n_position, d_pos_vec): """ @@ -83,9 +80,10 @@ def multi_head_attention(queries, return x hidden_size = x.shape[-1] - # FIXME(guosheng): Decouple the program desc with batch_size. + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. reshaped = layers.reshape( - x=x, shape=[batch_size, -1, n_head, hidden_size // n_head]) + x=x, shape=[0, -1, n_head, hidden_size // n_head]) # permuate the dimensions into: # [batch_size, n_head, max_sequence_len, hidden_size_per_head] @@ -101,26 +99,20 @@ def multi_head_attention(queries, raise ValueError("Input(x) should be a 4-D Tensor.") trans_x = layers.transpose(x, perm=[0, 2, 1, 3]) - # FIXME(guosheng): Decouple the program desc with batch_size. + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. return layers.reshape( x=trans_x, - shape=map(int, - [batch_size, -1, trans_x.shape[2] * trans_x.shape[3]])) + shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]])) def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate): """ Scaled Dot-Product Attention """ - # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op. - - # The current implementation of softmax_op only supports 2D tensor, - # consequently it cannot be directly used here. - # If to use the reshape_op, Besides, the shape of product inferred in - # compile-time is not the actual shape in run-time. It cann't be used - # to set the attribute of reshape_op. - # So, here define the softmax for temporary solution. - + # FIXME(guosheng): Remove __softmax when softmax_op supporting high + # rank tensors. softmax_op only supports 2D tensor currently. + # Otherwise, add extra input data to reshape. def __softmax(x, eps=1e-9): exp_out = layers.exp(x=x) sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False) @@ -131,6 +123,7 @@ def multi_head_attention(queries, weights = __softmax( layers.elementwise_add( x=product, y=attn_bias) if attn_bias else product) + # weights = __softmax(product) if dropout_rate: weights = layers.dropout( weights, dropout_prob=dropout_rate, is_test=False) @@ -177,7 +170,7 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid): return out -def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.): +def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.): """ Add residual connection, layer normalization and droput to the out tensor optionally according to the value of process_cmd. @@ -195,8 +188,9 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.): param_attr=fluid.initializer.Constant(1.), bias_attr=fluid.initializer.Constant(0.)) elif cmd == "d": # add dropout - if dropout: - out = layers.dropout(out, dropout_prob=dropout, is_test=False) + if dropout_rate: + out = layers.dropout( + out, dropout_prob=dropout_rate, is_test=False) return out @@ -210,8 +204,9 @@ def prepare_encoder(src_word, src_emb_dim, src_pad_idx, src_max_len, - dropout=0., + dropout_rate=0., pos_pad_idx=0, + src_data_shape=None, pos_enc_param_name=None): """Add word embeddings and position encodings. The output tensor has a shape of: @@ -231,12 +226,13 @@ def prepare_encoder(src_word, param_attr=fluid.ParamAttr( name=pos_enc_param_name, trainable=False)) enc_input = src_word_emb + src_pos_enc - - # FIXME(guosheng): Decouple the program desc with batch_size. 
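The 0 in the reshape targets of the multi_head_attention hunks above is Fluid shorthand for "copy this dimension from the input", and -1 is inferred, so neither batch size nor sequence length has to be baked into the program. In plain numpy the same head split and combine looks like this (sizes are illustrative):

    import numpy as np

    batch_size, seq_len, n_head, d_model = 2, 5, 4, 16
    x = np.random.rand(batch_size, seq_len, d_model).astype("float32")

    # __split_heads: [batch, seq, d_model] -> [batch, n_head, seq, d_model / n_head]
    split = x.reshape(batch_size, -1, n_head, d_model // n_head)
    split = split.transpose(0, 2, 1, 3)
    print(split.shape)                   # (2, 4, 5, 4)

    # __combine_heads: the exact inverse
    combined = split.transpose(0, 2, 1, 3).reshape(batch_size, -1, d_model)
    print(np.allclose(combined, x))      # True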
- enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim]) + enc_input = layers.reshape( + x=enc_input, + shape=[-1, src_max_len, src_emb_dim], + actual_shape=src_data_shape) return layers.dropout( - enc_input, dropout_prob=dropout, - is_test=False) if dropout else enc_input + enc_input, dropout_prob=dropout_rate, + is_test=False) if dropout_rate else enc_input prepare_encoder = partial( @@ -386,18 +382,21 @@ def decoder(dec_input, def make_inputs(input_data_names, n_head, d_model, - batch_size, max_length, - is_pos, - slf_attn_bias_flag, - src_attn_bias_flag, - enc_output_flag=False): + is_pos=True, + slf_attn_bias_flag=True, + src_attn_bias_flag=True, + enc_output_flag=False, + data_shape_flag=True): """ Define the input data layers for the transformer model. """ input_layers = [] - # The shapes here act as placeholder. - # The shapes set here is to pass the infer-shape in compile time. + batch_size = 1 # Only for the infer-shape in compile time. + # The shapes here act as placeholder and are set to pass the infer-shape in + # compile time. + # The actual data shape of word is: + # [batch_size * max_len_in_batch, 1] word = layers.data( name=input_data_names[len(input_layers)], shape=[batch_size * max_length, 1], @@ -405,6 +404,8 @@ def make_inputs(input_data_names, append_batch_size=False) input_layers += [word] # This is used for position data or label weight. + # The actual data shape of pos is: + # [batch_size * max_len_in_batch, 1] pos = layers.data( name=input_data_names[len(input_layers)], shape=[batch_size * max_length, 1], @@ -415,6 +416,8 @@ def make_inputs(input_data_names, # This input is used to remove attention weights on paddings for the # encoder and to remove attention weights on subsequent words for the # decoder. + # The actual data shape of slf_attn_bias_flag is: + # [batch_size, n_head, max_len_in_batch, max_len_in_batch] slf_attn_bias = layers.data( name=input_data_names[len(input_layers)], shape=[batch_size, n_head, max_length, max_length], @@ -423,13 +426,26 @@ def make_inputs(input_data_names, input_layers += [slf_attn_bias] if src_attn_bias_flag: # This input is used to remove attention weights on paddings. + # The actual data shape of slf_attn_bias_flag is: + # [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch] src_attn_bias = layers.data( name=input_data_names[len(input_layers)], shape=[batch_size, n_head, max_length, max_length], dtype="float32", append_batch_size=False) input_layers += [src_attn_bias] + if data_shape_flag: + # This input is used to reshape. + data_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [data_shape] if enc_output_flag: + # This input is used in independent decoder program for inference. 
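The [batch_size, n_head, max_len_in_batch, max_len_in_batch] shape documented above is the self-attention bias that blanks out padding. One common way to build such a bias in numpy, shown as a sketch rather than the exact pad_batch_data code:

    import numpy as np

    src_pad_idx, n_head = 0, 2
    # Two source sentences padded to the batch max length of 4.
    src_word = np.array([[3, 7, 9, src_pad_idx],
                         [5, 8, src_pad_idx, src_pad_idx]], dtype="int64")
    batch_size, max_len = src_word.shape

    # A large negative bias on pad positions drives their softmaxed
    # attention weights towards zero.
    pad_bias = (src_word == src_pad_idx).astype("float32") * -1e9    # [batch, len]
    slf_attn_bias = np.tile(pad_bias[:, np.newaxis, np.newaxis, :],
                            [1, n_head, max_len, 1])
    print(slf_attn_bias.shape)           # (2, 2, 4, 4), i.e. [batch, n_head, len, len]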
+ # The actual data shape of slf_attn_bias_flag is: + # [batch_size, max_len_in_batch, d_model] enc_output = layers.data( name=input_data_names[len(input_layers)], shape=[batch_size, max_length, d_model], @@ -453,8 +469,8 @@ def transformer( src_pad_idx, trg_pad_idx, pos_pad_idx, ): - enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model, - batch_size, max_length, True, True, False) + enc_inputs = make_inputs(encoder_input_data_names, n_head, d_model, + max_length, True, True, False) enc_output = wrap_encoder( src_vocab_size, @@ -468,10 +484,10 @@ def transformer( dropout_rate, src_pad_idx, pos_pad_idx, - enc_input_layers, ) + enc_inputs, ) - dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model, - batch_size, max_length, True, True, True) + dec_inputs = make_inputs(decoder_input_data_names, n_head, d_model, + max_length, True, True, True) predict = wrap_decoder( trg_vocab_size, @@ -485,13 +501,13 @@ def transformer( dropout_rate, trg_pad_idx, pos_pad_idx, - dec_input_layers, + dec_inputs, enc_output, ) # Padding index do not contribute to the total loss. The weights is used to # cancel padding index in calculating the loss. - gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size, - max_length, False, False, False) + gold, weights = make_inputs(label_data_names, n_head, d_model, max_length, + False, False, False, False, False) cost = layers.cross_entropy(input=predict, label=gold) weighted_cost = cost * weights return layers.reduce_sum(weighted_cost), predict @@ -508,17 +524,18 @@ def wrap_encoder(src_vocab_size, dropout_rate, src_pad_idx, pos_pad_idx, - enc_input_layers=None): + enc_inputs=None): """ The wrapper assembles together all needed layers for the encoder. """ - if enc_input_layers is None: + if enc_inputs is None: # This is used to implement independent encoder program in inference. - src_word, src_pos, src_slf_attn_bias = make_inputs( - encoder_input_data_names, n_head, d_model, batch_size, max_length, - True, True, False) + src_word, src_pos, src_slf_attn_bias, src_data_shape = make_inputs( + encoder_input_data_names, n_head, d_model, max_length, True, True, + False) else: - src_word, src_pos, src_slf_attn_bias = enc_input_layers + src_word, src_pos, src_slf_attn_bias, src_data_shape = enc_inputs + enc_input = prepare_encoder( src_word, src_pos, @@ -526,7 +543,9 @@ def wrap_encoder(src_vocab_size, d_model, src_pad_idx, max_length, - dropout_rate, ) + dropout_rate, + pos_pad_idx, + src_data_shape, ) enc_output = encoder( enc_input, src_slf_attn_bias, @@ -551,18 +570,18 @@ def wrap_decoder(trg_vocab_size, dropout_rate, trg_pad_idx, pos_pad_idx, - dec_input_layers=None, + dec_inputs=None, enc_output=None): """ The wrapper assembles together all needed layers for the decoder. """ - if dec_input_layers is None: + if dec_inputs is None: # This is used to implement independent decoder program in inference. 
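The gold/weights pair fed into the cost above is how padding is kept out of the loss: weights is 1.0 on real target tokens and 0.0 on pads, so the per-token cross entropy at pad positions is multiplied away before reduce_sum. Toy numbers make the effect visible (a sketch, not the training code):

    import numpy as np

    trg_pad_idx = 0
    lbl_word = np.array([[3], [7], [trg_pad_idx], [5], [trg_pad_idx]], dtype="int64")
    lbl_weight = (lbl_word != trg_pad_idx).astype("float32")   # [[1], [1], [0], [1], [0]]

    per_token_cost = np.array([[0.9], [1.2], [2.5], [0.4], [3.1]], dtype="float32")
    weighted = per_token_cost * lbl_weight
    print(weighted.sum())    # 2.5; the two padded positions contribute nothing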
- trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs( - decoder_input_data_names, n_head, d_model, batch_size, max_length, - True, True, True, True) + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, enc_output = make_inputs( + decoder_input_data_names, n_head, d_model, max_length, True, True, + True, True) else: - trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape = dec_inputs dec_input = prepare_decoder( trg_word, @@ -571,7 +590,9 @@ def wrap_decoder(trg_vocab_size, d_model, trg_pad_idx, max_length, - dropout_rate, ) + dropout_rate, + pos_pad_idx, + trg_data_shape, ) dec_output = decoder( dec_input, enc_output, diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index 65de8ef7..b0055233 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -56,7 +56,7 @@ def pad_batch_data(insts, def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, - max_length, n_head): + n_head, d_model): """ Put all padded data needed by training into a dict. """ @@ -69,10 +69,13 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head, False, False, False, False) lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1]) + src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32") + trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32") input_dict = dict( zip(input_data_names, [ - src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, - trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight + src_word, src_pos, src_slf_attn_bias, src_data_shape, trg_word, + trg_pos, trg_slf_attn_bias, trg_src_attn_bias, trg_data_shape, + lbl_word, lbl_weight ])) return input_dict @@ -119,13 +122,11 @@ def main(): def test(exe): test_costs = [] for batch_id, data in enumerate(val_data()): - if len(data) != TrainTaskConfig.batch_size: - continue data_input = prepare_batch_input( data, encoder_input_data_names + decoder_input_data_names[:-1] + label_data_names, ModelHyperParams.src_pad_idx, - ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length, - ModelHyperParams.n_head) + ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head, + ModelHyperParams.d_model) test_cost = exe.run(test_program, feed=data_input, fetch_list=[cost])[0] @@ -143,15 +144,11 @@ def main(): for pass_id in xrange(TrainTaskConfig.pass_num): for batch_id, data in enumerate(train_data()): - # The current program desc is coupled with batch_size, thus all - # mini-batches must have the same number of instances currently. - if len(data) != TrainTaskConfig.batch_size: - continue data_input = prepare_batch_input( data, encoder_input_data_names + decoder_input_data_names[:-1] + label_data_names, ModelHyperParams.src_pad_idx, - ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length, - ModelHyperParams.n_head) + ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head, + ModelHyperParams.d_model) lr_scheduler.update_learning_rate(data_input) outs = exe.run(fluid.framework.default_main_program(), feed=data_input, -- GitLab
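With the shape tensors supplied by prepare_batch_input, the `if len(data) != TrainTaskConfig.batch_size: continue` guards could be dropped, so the short final mini-batch of an epoch is no longer thrown away. A small sketch of the shape feed, assuming the same (src, trg, label) instance layout train.py uses; the helper name here is made up for illustration:

    import numpy as np

    def batch_data_shapes(insts, d_model):
        # Mirrors the src_data_shape / trg_data_shape arrays built in
        # prepare_batch_input (this helper itself is not part of the patch).
        src_max_len = max(len(inst[0]) for inst in insts)
        trg_max_len = max(len(inst[1]) for inst in insts)
        src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
        trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
        return src_data_shape, trg_data_shape

    # Any batch size works, including the last batch of an epoch.
    toy_batch = [([3, 7, 9], [4, 6], [6, 2, 1]),
                 ([5, 8], [9, 1, 2], [1, 2, 3])]
    src_shape, trg_shape = batch_data_shapes(toy_batch, d_model=512)
    print(src_shape)    # [  2   3 512]
    print(trg_shape)    # [  2   3 512]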