Commit f8e03bf6 authored by guosheng

Fix the max length in Transformer to count start and end tokens in

Parent 8222284e
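With this change, ModelHyperParams.max_length counts the start and end tokens, so the data readers are capped at max_length - 2 raw tokens; after prepending <s> and appending <e>, the padded sequence exactly fills the position encoding table. The sketch below illustrates that arithmetic with a generic NumPy sinusoid table; the position_encoding_init helper and the <s>/<e> literals here are illustrative, not the repo's exact code.

```python
import numpy as np

def position_encoding_init(n_position, d_pos_vec):
    # Generic sinusoid position encoding table: one row per position.
    position_enc = np.array([
        [pos / np.power(10000.0, 2 * (j // 2) / d_pos_vec) for j in range(d_pos_vec)]
        for pos in range(n_position)])
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # even dimensions
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # odd dimensions
    return position_enc.astype("float32")

max_length = 256   # ModelHyperParams.max_length: counts <s> and <e> in
d_model = 512      # ModelHyperParams.d_model

# The readers now pass max_length - 2, so the longest accepted sentence
# plus the start/end tokens still fits the position encoding table.
reader_max_length = max_length - 2
tokens = ["w"] * reader_max_length        # longest raw sequence the reader keeps
padded = ["<s>"] + tokens + ["<e>"]       # what the model actually sees
pos_enc_table = position_encoding_init(max_length, d_model)
assert len(padded) == pos_enc_table.shape[0] == max_length
```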
@@ -62,10 +62,8 @@ class ModelHyperParams(object):
     eos_idx = 1
     # index for <unk> token
     unk_idx = 2
-    # max length of sequences.
-    # The size of position encoding table should at least plus 1, since the
-    # sinusoid position encoding starts from 1 and 0 can be used as the padding
-    # token for position encoding.
+    # max length of sequences deciding the size of position encoding table.
+    # Start from 1 and count start and end tokens in.
     max_length = 256
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward

@@ -543,7 +543,8 @@ def infer(args, inferencer=fast_infer):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
+        # count start and end tokens out
+        max_length=ModelHyperParams.max_length - 2,
         clip_last_batch=False)
     trg_idx2word = test_data.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)

@@ -290,7 +290,8 @@ def train(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
+        # count start and end tokens out
+        max_length=ModelHyperParams.max_length - 2,
         clip_last_batch=False)
     train_data = read_multiple(
         reader=train_data.batch_generator,

@@ -326,7 +327,8 @@ def train(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
-        max_length=ModelHyperParams.max_length,
+        # count start and end tokens out
+        max_length=ModelHyperParams.max_length - 2,
         clip_last_batch=False,
         shuffle=False,
         shuffle_batch=False)