Commit d2d973d2 authored by guosheng

Remove pad_idx in Transformer

Parent 10de2bf3
@@ -45,19 +45,13 @@ class InferTaskConfig(object):
 class ModelHyperParams(object):
     # Dictionary size for source and target language. This model directly uses
     # paddle.dataset.wmt16 in which <bos>, <eos> and <unk> token has
-    # already been added, but the <pad> token is not added. Transformer requires
-    # sequences in a mini-batch are padded to have the same length. A <pad> token is
-    # added into the original dictionary in paddle.dataset.wmt16.
+    # already been added.

     # size of source word dictionary.
     src_vocab_size = 10000
-    # index for <pad> token in source language.
-    src_pad_idx = src_vocab_size

     # size of target word dictionary
     trg_vocab_size = 10000
-    # index for <pad> token in target language.
-    trg_pad_idx = trg_vocab_size

     # index for <bos> token
     bos_idx = 0
...
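The comment above relies on the training-side loss masking: padded label positions get zero weight, so no dedicated <pad> id has to be appended to the 10000-token wmt16 dictionary. A minimal NumPy sketch of that reasoning follows; the ids and probabilities are made up for illustration and are not taken from the model.

import numpy as np

vocab_size, eos_idx = 10000, 1                   # eos_idx value is illustrative only
labels = np.array([5, 42, eos_idx, eos_idx])     # last two positions are padding
weights = np.array([1., 1., 0., 0.])             # zero weight marks the padding

probs = np.full((4, vocab_size), 1.0 / vocab_size)   # dummy predicted probabilities
token_loss = -np.log(probs[np.arange(4), labels])    # per-token cross entropy
sum_cost = (token_loss * weights).sum()   # padded positions contribute exactly 0
avg_cost = sum_cost / weights.sum()       # normalized by the number of real tokens

Any in-vocabulary id can therefore stand in as the padding value, which is why the training script further below simply reuses eos_idx.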
@@ -251,22 +251,20 @@ def main():
     encoder_program = fluid.Program()
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
-            ModelHyperParams.src_vocab_size + 1,
-            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
-            ModelHyperParams.n_head, ModelHyperParams.d_key,
-            ModelHyperParams.d_value, ModelHyperParams.d_model,
-            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
-            ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)
+            ModelHyperParams.src_vocab_size, ModelHyperParams.max_length,
+            ModelHyperParams.n_layer, ModelHyperParams.n_head,
+            ModelHyperParams.d_key, ModelHyperParams.d_value,
+            ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
+            ModelHyperParams.dropout)

     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
-        predict = decoder(
-            ModelHyperParams.trg_vocab_size + 1,
-            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
-            ModelHyperParams.n_head, ModelHyperParams.d_key,
-            ModelHyperParams.d_value, ModelHyperParams.d_model,
-            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
-            ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
+        predict = decoder(ModelHyperParams.trg_vocab_size,
+                          ModelHyperParams.max_length, ModelHyperParams.n_layer,
+                          ModelHyperParams.n_head, ModelHyperParams.d_key,
+                          ModelHyperParams.d_value, ModelHyperParams.d_model,
+                          ModelHyperParams.d_inner_hid,
+                          ModelHyperParams.dropout)

     # Load model parameters of encoder and decoder separately from the saved
     # transformer model.
...
@@ -199,10 +199,8 @@ def prepare_encoder(src_word,
                     src_pos,
                     src_vocab_size,
                     src_emb_dim,
-                    src_pad_idx,
                     src_max_len,
                     dropout_rate=0.,
-                    pos_pad_idx=0,
                     src_data_shape=None,
                     pos_enc_param_name=None):
     """Add word embeddings and position encodings.
@@ -214,12 +212,10 @@ def prepare_encoder(src_word,
     src_word_emb = layers.embedding(
         src_word,
         size=[src_vocab_size, src_emb_dim],
-        padding_idx=src_pad_idx,
         param_attr=fluid.initializer.Normal(0., 1.))
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
-        padding_idx=pos_pad_idx,
         param_attr=fluid.ParamAttr(
             name=pos_enc_param_name, trainable=False))
     enc_input = src_word_emb + src_pos_enc
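For readers without a Fluid setup, a rough NumPy sketch of what prepare_encoder now computes: a plain word-embedding lookup added to a fixed sinusoidal position table, with no padding_idx special-casing on either table. The helper and sample values below are assumptions for illustration, not the repository's exact position_encoding_init.

import numpy as np

def sinusoid_table(max_len, d_model):
    # Standard sin/cos position encoding from "Attention Is All You Need".
    pos = np.arange(max_len)[:, None].astype("float64")
    dim = np.arange(0, d_model, 2)[None, :]
    angle = pos / np.power(10000., dim / float(d_model))
    table = np.zeros((max_len, d_model))
    table[:, 0::2] = np.sin(angle)
    table[:, 1::2] = np.cos(angle)
    return table.astype("float32")

src_vocab_size, d_model, max_len = 10000, 512, 256
word_emb_table = np.random.normal(0., 1., (src_vocab_size, d_model))  # like Normal(0., 1.)
pos_enc_table = sinusoid_table(max_len, d_model)                      # non-trainable

src_word = np.array([3, 17, 1, 1])     # short sentence padded with an ordinary in-vocab id
src_pos = np.arange(len(src_word))
enc_input = word_emb_table[src_word] + pos_enc_table[src_pos]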
@@ -516,10 +512,7 @@ def transformer(
         d_value,
         d_model,
         d_inner_hid,
-        dropout_rate,
-        src_pad_idx,
-        trg_pad_idx,
-        pos_pad_idx, ):
+        dropout_rate, ):
     enc_inputs = make_inputs(
         encoder_input_data_names,
         n_head,
@@ -543,8 +536,6 @@ def transformer(
         d_model,
         d_inner_hid,
         dropout_rate,
-        src_pad_idx,
-        pos_pad_idx,
         enc_inputs, )

     dec_inputs = make_inputs(
@@ -570,8 +561,6 @@ def transformer(
         d_model,
         d_inner_hid,
         dropout_rate,
-        trg_pad_idx,
-        pos_pad_idx,
         dec_inputs,
         enc_output, )
@@ -606,8 +595,6 @@ def wrap_encoder(src_vocab_size,
                  d_model,
                  d_inner_hid,
                  dropout_rate,
-                 src_pad_idx,
-                 pos_pad_idx,
                  enc_inputs=None):
     """
     The wrapper assembles together all needed layers for the encoder.
@@ -637,10 +624,8 @@ def wrap_encoder(src_vocab_size,
         src_pos,
         src_vocab_size,
         d_model,
-        src_pad_idx,
         max_length,
         dropout_rate,
-        pos_pad_idx,
         src_data_shape, )
     enc_output = encoder(
         enc_input,
@@ -666,8 +651,6 @@ def wrap_decoder(trg_vocab_size,
                  d_model,
                  d_inner_hid,
                  dropout_rate,
-                 trg_pad_idx,
-                 pos_pad_idx,
                  dec_inputs=None,
                  enc_output=None):
     """
@@ -701,10 +684,8 @@ def wrap_decoder(trg_vocab_size,
         trg_pos,
         trg_vocab_size,
         d_model,
-        trg_pad_idx,
         max_length,
         dropout_rate,
-        pos_pad_idx,
         trg_data_shape, )
     dec_output = decoder(
         dec_input,
@@ -724,11 +705,10 @@ def wrap_decoder(trg_vocab_size,
         src_attn_post_softmax_shape, )
     # Return logits for training and probs for inference.
     predict = layers.reshape(
-        x=layers.fc(
-            input=dec_output,
-            size=trg_vocab_size - 1,  # To exclude <pad>.
-            bias_attr=False,
-            num_flatten_dims=2),
-        shape=[-1, trg_vocab_size - 1],
+        x=layers.fc(input=dec_output,
+                    size=trg_vocab_size,
+                    bias_attr=False,
+                    num_flatten_dims=2),
+        shape=[-1, trg_vocab_size],
         act="softmax" if dec_inputs is None else None)
     return predict
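Since the output projection now spans the full target vocabulary (no row is carved out for a <pad> entry), the shapes and the training/inference split can be pictured with a small NumPy stand-in. Only the shapes and the softmax are being illustrated here; the Fluid layers above remain the actual implementation, and the sizes are example values.

import numpy as np

trg_vocab_size, d_model, n_tokens = 10000, 512, 6
dec_output = np.random.randn(n_tokens, d_model)      # flattened decoder states
proj_w = np.random.randn(d_model, trg_vocab_size)    # no bias, as bias_attr=False

logits = dec_output.dot(proj_w)                      # shape [-1, trg_vocab_size]
# Inference path: softmax over the whole vocabulary; training keeps raw logits.
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)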
@@ -13,7 +13,6 @@ from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \

 def pad_batch_data(insts,
                    pad_idx,
-                   eos_idx,
                    n_head,
                    is_target=False,
                    is_label=False,
@@ -25,12 +24,10 @@ def pad_batch_data(insts,
     """
     return_list = []
     max_len = max(len(inst) for inst in insts)
-    # Since we restrict the predicted probs excluding the <pad> to avoid
-    # generating the <pad>, also replace the <pad> with others in labels.
-    inst_data = np.array([
-        inst + [eos_idx if is_label else pad_idx] * (max_len - len(inst))
-        for inst in insts
-    ])
+    # Any token included in dict can be used to pad, since the paddings' loss
+    # will be masked out by weights and make no effect on parameter gradients.
+    inst_data = np.array(
+        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
     return_list += [inst_data.astype("int64").reshape([-1, 1])]
     if is_label:  # label weight
         inst_weight = np.array(
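A standalone sketch of the padding and label-weight bookkeeping that the new comment above refers to, using toy sequences (the values are made up): padded positions are filled with an ordinary vocabulary id and are simultaneously zeroed in the weight tensor.

import numpy as np

insts = [[7, 8, 9], [4, 5]]          # toy label sequences of unequal length
pad_idx = 1                          # any in-vocabulary id works, e.g. eos_idx
max_len = max(len(inst) for inst in insts)

inst_data = np.array(
    [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
inst_weight = np.array(
    [[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])

lbl_word = inst_data.astype("int64").reshape([-1, 1])        # feeds the labels
lbl_weight = inst_weight.astype("float32").reshape([-1, 1])  # masks the loss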
@@ -66,22 +63,14 @@ def pad_batch_data(insts,

 def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
-                        eos_idx, n_head, d_model):
+                        n_head, d_model):
     """
     Put all padded data needed by training into a dict.
     """
     src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
-        [inst[0] for inst in insts],
-        src_pad_idx,
-        eos_idx,
-        n_head,
-        is_target=False)
+        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
     trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
-        [inst[1] for inst in insts],
-        trg_pad_idx,
-        eos_idx,
-        n_head,
-        is_target=True)
+        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, trg_max_len, 1]).astype("float32")
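The np.tile call above reuses the source self-attention bias for the decoder's cross attention. A hedged NumPy sketch of the surrounding idea (the large negative constant and the 4-D shape follow this file's pad_batch_data as far as can be read from it, while the toy lengths are invented): padded source positions carry a bias of roughly -1e9 so softmax attention effectively ignores them, and one query row of that bias is repeated across all target steps.

import numpy as np

n_head, src_max_len, trg_max_len = 2, 4, 3
src_lens = [4, 2]                                 # second sentence has 2 pads

pad_bias = np.array([[0.] * l + [-1e9] * (src_max_len - l) for l in src_lens])
# [batch, n_head, src_max_len, src_max_len]; identical for every query position.
src_slf_attn_bias = np.tile(pad_bias[:, None, None, :],
                            [1, n_head, src_max_len, 1]).astype("float32")

# Stride src_max_len keeps only query row 0; repeat it once per target step.
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                            [1, 1, trg_max_len, 1]).astype("float32")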
@@ -104,7 +93,6 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
     lbl_word, lbl_weight = pad_batch_data(
         [inst[2] for inst in insts],
         trg_pad_idx,
-        eos_idx,
         n_head,
         is_target=False,
         is_label=True,
@@ -128,13 +116,11 @@ def main():
     exe = fluid.Executor(place)

     sum_cost, avg_cost, predict, token_num = transformer(
-        ModelHyperParams.src_vocab_size + 1,
-        ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
-        ModelHyperParams.n_layer, ModelHyperParams.n_head,
-        ModelHyperParams.d_key, ModelHyperParams.d_value,
-        ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
-        ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
-        ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
+        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
+        ModelHyperParams.max_length, ModelHyperParams.n_layer,
+        ModelHyperParams.n_head, ModelHyperParams.d_key,
+        ModelHyperParams.d_value, ModelHyperParams.d_model,
+        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)

     lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                          TrainTaskConfig.warmup_steps, place,
@@ -168,9 +154,9 @@ def main():
         for batch_id, data in enumerate(val_data()):
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
-                label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
-                ModelHyperParams.n_head, ModelHyperParams.d_model)
+                label_data_names, ModelHyperParams.eos_idx,
+                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
             test_sum_cost, test_token_num = exe.run(
                 test_program,
                 feed=data_input,
@@ -188,7 +174,7 @@ def main():
         pos_enc_param = fluid.global_scope().find_var(
             pos_enc_param_name).get_tensor()
         pos_enc_param.set(
-            position_encoding_init(ModelHyperParams.max_length + 1,
+            position_encoding_init(ModelHyperParams.max_length,
                                    ModelHyperParams.d_model), place)

     for pass_id in xrange(TrainTaskConfig.pass_num):
@@ -196,9 +182,9 @@ def main():
         for batch_id, data in enumerate(train_data()):
             data_input = prepare_batch_input(
                 data, encoder_input_data_names + decoder_input_data_names[:-1] +
-                label_data_names, ModelHyperParams.src_pad_idx,
-                ModelHyperParams.trg_pad_idx, ModelHyperParams.eos_idx,
-                ModelHyperParams.n_head, ModelHyperParams.d_model)
+                label_data_names, ModelHyperParams.eos_idx,
+                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
             lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
...