提交 9e39359b 编写于 作者: G guosheng

Make train.py support en-fr wordpiece data in Transformer

上级 88142779
...@@ -52,10 +52,17 @@ def parse_args(): ...@@ -52,10 +52,17 @@ def parse_args():
"--use_wordpiece", "--use_wordpiece",
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=False,
help="The flag indicating if the data is wordpiece data. The EN-FR data we " help="The flag indicating if the data is wordpiece data. The EN-FR data "
"provided is wordpiece data. For wordpiece data, converting ids to " "we provided is wordpiece data. For wordpiece data, converting ids to "
"original words is a little different and some special codes are " "original words is a little different and some special codes are "
"provided in util.py to do this.") "provided in util.py to do this.")
parser.add_argument(
"--token_delimiter",
type=str,
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter.; "
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
parser.add_argument( parser.add_argument(
'opts', 'opts',
help='See config.py for all options', help='See config.py for all options',
...@@ -549,8 +556,9 @@ def infer(args, inferencer=fast_infer): ...@@ -549,8 +556,9 @@ def infer(args, inferencer=fast_infer):
src_vocab_fpath=args.src_vocab_fpath, src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern, fpattern=args.test_file_pattern,
batch_size=args.batch_size, token_delimiter=args.token_delimiter,
use_token_batch=False, use_token_batch=False,
batch_size=args.batch_size,
pool_size=args.pool_size, pool_size=args.pool_size,
sort_type=reader.SortType.NONE, sort_type=reader.SortType.NONE,
shuffle=False, shuffle=False,
......
...@@ -76,6 +76,19 @@ def parse_args(): ...@@ -76,6 +76,19 @@ def parse_args():
default=["<s>", "<e>", "<unk>"], default=["<s>", "<e>", "<unk>"],
nargs=3, nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.") help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--use_wordpiece",
type=ast.literal_eval,
default=False,
help="The flag indicating if the data is wordpiece data. The EN-FR "
"data we provided is wordpiece data.")
parser.add_argument(
"--token_delimiter",
type=str,
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter.; "
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
parser.add_argument( parser.add_argument(
'opts', 'opts',
help='See config.py for all options', help='See config.py for all options',
...@@ -273,6 +286,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names, ...@@ -273,6 +286,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
src_vocab_fpath=args.src_vocab_fpath, src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.val_file_pattern, fpattern=args.val_file_pattern,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch, use_token_batch=args.use_token_batch,
batch_size=args.batch_size * (1 if args.use_token_batch else dev_count), batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
pool_size=args.pool_size, pool_size=args.pool_size,
...@@ -335,6 +349,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -335,6 +349,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
src_vocab_fpath=args.src_vocab_fpath, src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath, trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.train_file_pattern, fpattern=args.train_file_pattern,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch, use_token_batch=args.use_token_batch,
batch_size=args.batch_size * (1 if args.use_token_batch else dev_count), batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
pool_size=args.pool_size, pool_size=args.pool_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册