Unverified commit 9ab61a38, authored by Guo Sheng and committed by GitHub

Merge pull request #1023 from guoshengCS/fix-transformer-max-len

Fix the max length in Transformer to count start and end tokens in
@@ -62,10 +62,8 @@ class ModelHyperParams(object):
eos_idx = 1
# index for <unk> token
unk_idx = 2
-# max length of sequences.
-# The size of position encoding table should at least plus 1, since the
-# sinusoid position encoding starts from 1 and 0 can be used as the padding
-# token for position encoding.
+# max length of sequences deciding the size of position encoding table.
+# Start from 1 and count start and end tokens in.
max_length = 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
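The rewritten comment ties max_length directly to the size of the position encoding table, with the start and end tokens now counted in. As a rough illustration of that sizing, the sketch below builds a standard sinusoid table in which row 0 is reserved for padding and rows 1..max_length cover every position a wrapped sequence can occupy. The helper name position_encoding_init, the max_length + 1 table size, and the NumPy formulation are assumptions for this sketch, not necessarily the repository's exact code.

```python
import numpy as np

def position_encoding_init(n_position, d_model):
    """Minimal sinusoid position-encoding table (a sketch, not the repo's helper).

    Row 0 is kept all-zero so that index 0 can serve as the padding
    position; rows 1 .. n_position - 1 hold the actual encodings.
    """
    table = np.zeros((n_position, d_model), dtype="float32")
    positions = np.arange(1, n_position).reshape(-1, 1)  # positions start from 1
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    table[1:, 0::2] = np.sin(positions * div_term)
    table[1:, 1::2] = np.cos(positions * div_term)
    return table

# With max_length = 256 counting <s> and <e> in, one extra row for padding
# gives a table that covers position 0 (pad) plus positions 1..256.
pos_enc_table = position_encoding_init(256 + 1, 512)
print(pos_enc_table.shape)  # (257, 512)
```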
@@ -543,7 +543,8 @@ def infer(args, inferencer=fast_infer):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
-max_length=ModelHyperParams.max_length,
+# count start and end tokens out
+max_length=ModelHyperParams.max_length - 2,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
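This hunk, and the two identical ones in train() below, subtract 2 when handing max_length to the data reader: the reader only sees raw sequence lengths, while the start and end marks added afterwards also consume positions in the encoding table. A toy sketch of that arithmetic follows, using eos_idx = 1 from the config above; bos_idx = 0 and the helper name wrap_with_marks are assumptions, since neither appears in this excerpt.

```python
MODEL_MAX_LEN = 256                  # ModelHyperParams.max_length
READER_MAX_LEN = MODEL_MAX_LEN - 2   # the limit now handed to the data reader

def wrap_with_marks(token_ids, bos_idx=0, eos_idx=1):
    """Illustration only: add the <s>/<e> marks to a sequence the reader kept."""
    token_ids = list(token_ids)
    assert len(token_ids) <= READER_MAX_LEN, "the reader should have dropped this one"
    wrapped = [bos_idx] + token_ids + [eos_idx]
    # Even the longest surviving sequence still fits the position-encoding table.
    assert len(wrapped) <= MODEL_MAX_LEN
    return wrapped

# 254 raw tokens -> 256 tokens with marks, exactly ModelHyperParams.max_length.
print(len(wrap_with_marks(range(3, 257))))
```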
@@ -290,7 +290,8 @@ def train(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
-max_length=ModelHyperParams.max_length,
+# count start and end tokens out
+max_length=ModelHyperParams.max_length - 2,
clip_last_batch=False)
train_data = read_multiple(
reader=train_data.batch_generator,
@@ -326,7 +327,8 @@ def train(args):
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
-max_length=ModelHyperParams.max_length,
+# count start and end tokens out
+max_length=ModelHyperParams.max_length - 2,
clip_last_batch=False,
shuffle=False,
shuffle_batch=False)