Commit 9e39359b authored by guosheng

Make train.py support en-fr wordpiece data in Transformer

Parent 88142779
@@ -52,10 +52,17 @@ def parse_args():
         "--use_wordpiece",
         type=ast.literal_eval,
         default=False,
-        help="The flag indicating if the data is wordpiece data. The EN-FR data we "
-        "provided is wordpiece data. For wordpiece data, converting ids to "
+        help="The flag indicating if the data is wordpiece data. The EN-FR data "
+        "we provided is wordpiece data. For wordpiece data, converting ids to "
         "original words is a little different and some special codes are "
         "provided in util.py to do this.")
     parser.add_argument(
+        "--token_delimiter",
+        type=str,
+        default=" ",
+        help="The delimiter used to split tokens in source or target sentences. "
+        "For the EN-DE BPE data we provided, use spaces as the token delimiter; "
+        "for the EN-FR wordpiece data we provided, use '\x01' as the token delimiter.")
+    parser.add_argument(
         'opts',
         help='See config.py for all options',
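The use_wordpiece help above notes that converting ids back to words is handled by special code in util.py. The snippet below is only a rough sketch of one common wordpiece convention (a trailing "_" marking a word boundary); the helper name and the convention itself are assumptions for illustration, not the actual util.py code.

def wordpiece_ids_to_text(ids, id_to_token):
    # Rough sketch: join wordpiece tokens; in this assumed scheme a piece
    # ending with "_" closes a word, so replacing "_" with a space
    # reassembles the original sentence.
    pieces = [id_to_token[i] for i in ids]
    return "".join(pieces).replace("_", " ").strip()

# Toy usage with a hypothetical id-to-token mapping.
vocab = {0: "trans", 1: "former_", 2: "model_"}
print(wordpiece_ids_to_text([0, 1, 2], vocab))  # -> "transformer model"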
@@ -549,8 +556,9 @@ def infer(args, inferencer=fast_infer):
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
         fpattern=args.test_file_pattern,
-        batch_size=args.batch_size,
+        token_delimiter=args.token_delimiter,
         use_token_batch=False,
+        batch_size=args.batch_size,
         pool_size=args.pool_size,
         sort_type=reader.SortType.NONE,
         shuffle=False,
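In the reader call above (apparently the test-data pipeline in infer.py), token_delimiter is now forwarded to the data reader. A minimal sketch of how such a reader might split a tab-separated parallel line into tokens is shown below; parse_line and the sample sentence are illustrative assumptions, not code from reader.py.

def parse_line(line, token_delimiter=" "):
    # Split one "source<TAB>target" line into per-side token lists.
    # The EN-DE BPE data uses " " as the delimiter; the EN-FR wordpiece data uses "\x01".
    fields = line.strip("\n").split("\t")
    return [field.split(token_delimiter) for field in fields]

src_tokens, trg_tokens = parse_line("the\x01cat\tle\x01chat", token_delimiter="\x01")
print(src_tokens, trg_tokens)  # ['the', 'cat'] ['le', 'chat']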
@@ -76,6 +76,19 @@ def parse_args():
         default=["<s>", "<e>", "<unk>"],
         nargs=3,
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
+    parser.add_argument(
+        "--use_wordpiece",
+        type=ast.literal_eval,
+        default=False,
+        help="The flag indicating if the data is wordpiece data. The EN-FR "
+        "data we provided is wordpiece data.")
+    parser.add_argument(
+        "--token_delimiter",
+        type=str,
+        default=" ",
+        help="The delimiter used to split tokens in source or target sentences. "
+        "For the EN-DE BPE data we provided, use spaces as the token delimiter; "
+        "for the EN-FR wordpiece data we provided, use '\x01' as the token delimiter.")
     parser.add_argument(
         'opts',
         help='See config.py for all options',
@@ -273,6 +286,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
         fpattern=args.val_file_pattern,
+        token_delimiter=args.token_delimiter,
         use_token_batch=args.use_token_batch,
         batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
         pool_size=args.pool_size,
@@ -335,6 +349,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
         fpattern=args.train_file_pattern,
+        token_delimiter=args.token_delimiter,
         use_token_batch=args.use_token_batch,
         batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
         pool_size=args.pool_size,
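Because the EN-FR wordpiece delimiter is the non-printable character '\x01', it is worth checking that it survives argument parsing. The stand-alone parser below mirrors the two flags added in this commit purely as an illustration; how the character is escaped on a real command line (for example a shell's $'\x01') depends on the user's shell and is not covered by the commit.

import argparse
import ast

# Stand-alone parser mirroring the two flags added above (illustration only).
parser = argparse.ArgumentParser("Demo of --use_wordpiece and --token_delimiter.")
parser.add_argument("--use_wordpiece", type=ast.literal_eval, default=False)
parser.add_argument("--token_delimiter", type=str, default=" ")

args = parser.parse_args(["--use_wordpiece", "True", "--token_delimiter", "\x01"])
assert args.use_wordpiece is True
assert args.token_delimiter == "\x01"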