From f8e03bf62c5e338b575b40a1c2a0ebc32e5d9589 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Mon, 2 Jul 2018 11:18:32 +0800
Subject: [PATCH] Fix the max length in Transformer to count start and end
 tokens in

---
 fluid/neural_machine_translation/transformer/config.py | 6 ++----
 fluid/neural_machine_translation/transformer/infer.py  | 3 ++-
 fluid/neural_machine_translation/transformer/train.py  | 6 ++++--
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index a4e588c6..e68ab17e 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -62,10 +62,8 @@ class ModelHyperParams(object):
     eos_idx = 1
     # index for <unk> token
     unk_idx = 2
-    # max length of sequences.
-    # The size of position encoding table should at least plus 1, since the
-    # sinusoid position encoding starts from 1 and 0 can be used as the padding
-    # token for position encoding.
+    # max length of sequences, deciding the size of the position encoding
+    # table (indexed from 1); it counts the start and end tokens in.
     max_length = 256
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 87402808..505bf0b0 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -543,7 +543,8 @@ def infer(args, inferencer=fast_infer):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
+        # leave room for the start and end tokens, which max_length counts in
+        max_length=ModelHyperParams.max_length - 2,
         clip_last_batch=False)
     trg_idx2word = test_data.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index e3c9b62d..cdd7dfed 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -290,7 +290,8 @@ def train(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
+        # leave room for the start and end tokens, which max_length counts in
+        max_length=ModelHyperParams.max_length - 2,
         clip_last_batch=False)
     train_data = read_multiple(
         reader=train_data.batch_generator,
@@ -326,7 +327,8 @@ def train(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
+        # leave room for the start and end tokens, which max_length counts in
+        max_length=ModelHyperParams.max_length - 2,
         clip_last_batch=False,
         shuffle=False,
         shuffle_batch=False)
--
GitLab
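
For clarity, a minimal sketch of the length bookkeeping this patch enforces: ModelHyperParams.max_length counts the start and end tokens, so the cutoff handed to the data reader must be two tokens smaller. The names MAX_LENGTH, READER_MAX_LENGTH, and fits below are illustrative only, not identifiers from the repo.

    # Illustrative sketch -- hypothetical names, not the repo's actual reader.
    # max_length counts the <s> and <e> tokens in, so the raw-sentence cutoff
    # passed to the data reader is two tokens smaller.

    MAX_LENGTH = 256                    # sizes the position encoding table
    READER_MAX_LENGTH = MAX_LENGTH - 2  # limit on raw tokens, before <s>/<e>

    def fits(raw_tokens):
        # True if the sentence still fits once <s> and <e> are appended.
        return len(raw_tokens) + 2 <= MAX_LENGTH

    assert fits(["w"] * READER_MAX_LENGTH)            # 254 + 2 == 256, kept
    assert not fits(["w"] * (READER_MAX_LENGTH + 1))  # 255 + 2 > 256, dropped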