Commit 91a0b7c6 authored by guosheng

Refine the pad_idx in Transformer

Parent: d2d973d2
@@ -60,11 +60,10 @@ class ModelHyperParams(object):
     # index for <unk> token
     unk_idx = 2
-    # position value corresponding to the <pad> token.
-    pos_pad_idx = 0
-    # max length of sequences. It should plus 1 to include position
-    # padding token for position encoding.
+    # max length of sequences.
+    # The size of position encoding table should plus 1, since the sinusoid
+    # position encoding start from 1 and 0 can be used as the padding token
+    # for position encoding.
     max_length = 50
     # the dimension for word embeddings, which is also the last dimension of
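The new comment explains the off-by-one used throughout this commit: with the sinusoid encoding starting at position 1, row 0 of the table can stay all-zero and serve as the position padding entry, so the table needs max_length + 1 rows. A minimal NumPy sketch of such a table builder, given as an illustration of that layout rather than the repo's position_encoding_init, assuming an even d_model:

import numpy as np


def build_position_encoding(max_length, d_model):
    """Return a (max_length + 1, d_model) sinusoid table.

    Row 0 is left as zeros and reserved for the position padding token;
    real positions are encoded from index 1 upward, which is why one
    extra row beyond max_length is needed.
    """
    table = np.zeros((max_length + 1, d_model), dtype="float32")
    # Positions 1..max_length as a column vector.
    position = np.arange(1, max_length + 1, dtype="float32")[:, None]
    # Inverse frequencies shared by the sine/cosine dimensions.
    inv_freq = np.power(
        10000.0, -np.arange(0, d_model, 2, dtype="float32") / d_model)
    table[1:, 0::2] = np.sin(position * inv_freq)  # even dims: sine
    table[1:, 1::2] = np.cos(position * inv_freq)  # odd dims: cosine
    return table

With this layout, padded slots can simply use position index 0 and pick up a zero vector, while real tokens use indices 1..max_length.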
@@ -39,7 +39,6 @@ def translate_batch(exe,
     enc_in_data = pad_batch_data(
         src_words,
-        src_pad_idx,
-        eos_idx,
+        eos_idx,
         n_head,
         is_target=False,
         is_label=False,
@@ -251,7 +250,7 @@ def main():
     encoder_program = fluid.Program()
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
-            ModelHyperParams.src_vocab_size, ModelHyperParams.max_length,
+            ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
             ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
@@ -259,12 +258,12 @@ def main():
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
-        predict = decoder(ModelHyperParams.trg_vocab_size,
-                          ModelHyperParams.max_length, ModelHyperParams.n_layer,
-                          ModelHyperParams.n_head, ModelHyperParams.d_key,
-                          ModelHyperParams.d_value, ModelHyperParams.d_model,
-                          ModelHyperParams.d_inner_hid,
-                          ModelHyperParams.dropout)
+        predict = decoder(
+            ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
+            ModelHyperParams.n_layer, ModelHyperParams.n_head,
+            ModelHyperParams.d_key, ModelHyperParams.d_value,
+            ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
+            ModelHyperParams.dropout)
     # Load model parameters of encoder and decoder separately from the saved
     # transformer model.
@@ -300,9 +299,6 @@ def main():
     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
-    # Append the <pad> token since the dict provided by dataset.wmt16 does
-    # not include it.
-    trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"
     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
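With <eos> now doing double duty as the pad id, the target dictionary no longer needs an explicit "<pad>" entry: everything from the first <eos> onward is cut away before ids are mapped back to words. A simplified, hypothetical sketch of such a post-processing step (the repo's post_process_seq may differ in detail):

def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
    """Truncate a decoded id sequence at the first <eos> and optionally drop
    <bos>/<eos>, so eos-based padding never shows up in the output text."""
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = seq[:eos_pos + 1]
    return [
        idx for idx in seq
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]

# Example: post_process_seq([0, 5, 7, 1, 1, 1], bos_idx=0, eos_idx=1) -> [5, 7]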
@@ -326,19 +322,22 @@ def main():
     for batch_id, data in enumerate(test_data()):
         batch_seqs, batch_scores = translate_batch(
-            exe, [item[0] for item in data],
+            exe,
+            [item[0] for item in data],
             encoder_program,
-            encoder_input_data_names, [enc_output.name],
+            encoder_input_data_names,
+            [enc_output.name],
             decoder_program,
-            decoder_input_data_names, [predict.name],
+            decoder_input_data_names,
+            [predict.name],
             InferTaskConfig.beam_size,
             InferTaskConfig.max_length,
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
             ModelHyperParams.d_model,
-            ModelHyperParams.src_pad_idx,
-            ModelHyperParams.trg_pad_idx,
+            ModelHyperParams.eos_idx,  # Use eos_idx to pad.
+            ModelHyperParams.eos_idx,  # Use eos_idx to pad.
             ModelHyperParams.bos_idx,
             ModelHyperParams.eos_idx,
             ModelHyperParams.unk_idx,
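Passing ModelHyperParams.eos_idx where the dedicated pad indices used to go works because padding here just means filling short sequences up to the batch max length with some id and masking it out. A self-contained sketch of that idea, not the repo's pad_batch_data, with a made-up helper name:

import numpy as np


def pad_to_batch_max(insts, pad_idx):
    """Pad each instance to the longest length in the batch with pad_idx and
    return the padded ids plus a 0/1 mask marking the real tokens."""
    max_len = max(len(inst) for inst in insts)
    padded = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64")
    mask = np.array(
        [[1.0] * len(inst) + [0.0] * (max_len - len(inst)) for inst in insts],
        dtype="float32")
    return padded, mask


# Usage: pad with the <eos> id instead of a dedicated <pad> id.
# src_ids, src_mask = pad_to_batch_max(batch_of_id_lists, ModelHyperParams.eos_idx)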
@@ -117,7 +117,7 @@ def main():
     sum_cost, avg_cost, predict, token_num = transformer(
         ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
-        ModelHyperParams.max_length, ModelHyperParams.n_layer,
+        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
         ModelHyperParams.n_head, ModelHyperParams.d_key,
         ModelHyperParams.d_value, ModelHyperParams.d_model,
         ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
@@ -174,7 +174,7 @@ def main():
     pos_enc_param = fluid.global_scope().find_var(
         pos_enc_param_name).get_tensor()
     pos_enc_param.set(
-        position_encoding_init(ModelHyperParams.max_length,
+        position_encoding_init(ModelHyperParams.max_length + 1,
                                ModelHyperParams.d_model), place)
     for pass_id in xrange(TrainTaskConfig.pass_num):