Commit 91a0b7c6 authored by: guosheng

Refine the pad_idx in Transformer

Parent: d2d973d2
@@ -60,11 +60,10 @@ class ModelHyperParams(object):
     # index for <unk> token
     unk_idx = 2
-    # position value corresponding to the <pad> token.
-    pos_pad_idx = 0
-    # max length of sequences. It should plus 1 to include position
-    # padding token for position encoding.
+    # max length of sequences.
+    # The size of the position encoding table should be max_length + 1,
+    # since the sinusoid position encoding starts from 1 and 0 can be used
+    # as the padding token for position encoding.
     max_length = 50
     # the dimension for word embeddings, which is also the last dimension of
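The new comment is the heart of this change: positions are numbered from 1, and row 0 of the sinusoid table is reserved for padded positions, so the table needs max_length + 1 rows. Below is a rough sketch of what the position_encoding_init helper (referenced in the train.py hunk further down) can look like under that convention; it is an illustration under these assumptions, not the repo's exact implementation.

import numpy as np

def position_encoding_init(n_position, d_pos_vec):
    # Build an (n_position, d_pos_vec) sinusoid table. Row 0 is left as
    # zeros so it can serve as the encoding of padded positions; real
    # tokens use rows 1..n_position-1, hence the max_length + 1 sizing.
    position_enc = np.array([
        np.zeros(d_pos_vec) if pos == 0 else
        [pos / np.power(10000., 2. * (j // 2) / d_pos_vec)
         for j in range(d_pos_vec)]
        for pos in range(n_position)
    ])
    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # even dims
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # odd dims
    return position_enc.astype("float32")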
@@ -39,7 +39,6 @@ def translate_batch(exe,
     enc_in_data = pad_batch_data(
         src_words,
-        src_pad_idx,
+        eos_idx,
         n_head,
         is_target=False,
         is_label=False,
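With this change the encoder input batch is padded with eos_idx instead of a dedicated src_pad_idx. A minimal, hypothetical illustration of that padding step follows; the repo's actual pad_batch_data also builds position and attention-bias arrays, which are omitted here.

import numpy as np

def pad_with_eos(insts, eos_idx):
    # Hypothetical minimal padding: extend every instance in the batch to
    # the longest length, using eos_idx as the filler token.
    max_len = max(len(inst) for inst in insts)
    return np.array(
        [inst + [eos_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64")

# pad_with_eos([[4, 7, 9], [4, 7]], eos_idx=1) -> [[4, 7, 9], [4, 7, 1]]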
@@ -251,7 +250,7 @@ def main():
     encoder_program = fluid.Program()
     with fluid.program_guard(main_program=encoder_program):
         enc_output = encoder(
-            ModelHyperParams.src_vocab_size, ModelHyperParams.max_length,
+            ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
             ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
@@ -259,11 +258,11 @@ def main():
     decoder_program = fluid.Program()
     with fluid.program_guard(main_program=decoder_program):
-        predict = decoder(ModelHyperParams.trg_vocab_size,
-                          ModelHyperParams.max_length, ModelHyperParams.n_layer,
-                          ModelHyperParams.n_head, ModelHyperParams.d_key,
-                          ModelHyperParams.d_value, ModelHyperParams.d_model,
-                          ModelHyperParams.d_inner_hid,
+        predict = decoder(
+            ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
+            ModelHyperParams.n_layer, ModelHyperParams.n_head,
+            ModelHyperParams.d_key, ModelHyperParams.d_value,
+            ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
             ModelHyperParams.dropout)
     # Load model parameters of encoder and decoder separately from the saved
@@ -300,9 +299,6 @@ def main():
     trg_idx2word = paddle.dataset.wmt16.get_dict(
         "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
-    # Append the <pad> token since the dict provided by dataset.wmt16 does
-    # not include it.
-    trg_idx2word[ModelHyperParams.trg_pad_idx] = "<pad>"

     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
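The trg_idx2word "<pad>" entry can be dropped because padded slots now hold the real <eos> index, and decoded output is truncated at the first <eos> anyway. A hedged sketch of such a post-processing step; the repo's post_process_seq (whose signature appears above) may differ in its details.

def post_process_seq(seq, bos_idx, eos_idx):
    # Sketch: strip a leading <bos> and cut at the first <eos>, which also
    # discards any eos-based padding, so no "<pad>" lookup is ever needed.
    seq = list(seq)
    if seq and seq[0] == bos_idx:
        seq = seq[1:]
    if eos_idx in seq:
        seq = seq[:seq.index(eos_idx)]
    return seq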
@@ -326,19 +322,22 @@ def main():
     for batch_id, data in enumerate(test_data()):
         batch_seqs, batch_scores = translate_batch(
-            exe, [item[0] for item in data],
+            exe,
+            [item[0] for item in data],
             encoder_program,
-            encoder_input_data_names, [enc_output.name],
+            encoder_input_data_names,
+            [enc_output.name],
             decoder_program,
-            decoder_input_data_names, [predict.name],
+            decoder_input_data_names,
+            [predict.name],
             InferTaskConfig.beam_size,
             InferTaskConfig.max_length,
             InferTaskConfig.n_best,
             len(data),
             ModelHyperParams.n_head,
             ModelHyperParams.d_model,
-            ModelHyperParams.src_pad_idx,
-            ModelHyperParams.trg_pad_idx,
+            ModelHyperParams.eos_idx,  # Use eos_idx to pad.
+            ModelHyperParams.eos_idx,  # Use eos_idx to pad.
             ModelHyperParams.bos_idx,
             ModelHyperParams.eos_idx,
             ModelHyperParams.unk_idx,
@@ -117,7 +117,7 @@ def main():
     sum_cost, avg_cost, predict, token_num = transformer(
         ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
-        ModelHyperParams.max_length, ModelHyperParams.n_layer,
+        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
         ModelHyperParams.n_head, ModelHyperParams.d_key,
         ModelHyperParams.d_value, ModelHyperParams.d_model,
         ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
@@ -174,7 +174,7 @@ def main():
     pos_enc_param = fluid.global_scope().find_var(
         pos_enc_param_name).get_tensor()
     pos_enc_param.set(
-        position_encoding_init(ModelHyperParams.max_length,
+        position_encoding_init(ModelHyperParams.max_length + 1,
                                ModelHyperParams.d_model), place)
     for pass_id in xrange(TrainTaskConfig.pass_num):
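As a quick sanity check on the recurring max_length + 1: with positions numbered from 1 and position 0 reserved for padding, a sequence of max_length real tokens occupies positions 1..max_length, so the lookup table needs max_length + 1 rows. The values below are illustrative only, assuming eos_idx == 1 doubles as the pad token.

# word indices, padded with eos_idx == 1 to the batch max length of 4
# (the first sequence ends with a real <eos>; the second has one pad):
#   words     = [[4, 7, 9, 1],
#                [4, 7, 1, 1]]
# position indices fed to the position-encoding lookup (0 marks padding):
#   positions = [[1, 2, 3, 4],
#                [1, 2, 3, 0]]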