Commit 33be7a97 authored by guosheng

Make reader.py and infer.py support en-fr wordpiece data in Transformer

Parent 6b81d938

infer.py

 import argparse
+import ast
 import numpy as np
 import paddle
...
@@ -11,6 +12,7 @@ from model import fast_decode as fast_decoder
 from config import *
 from train import pad_batch_data
 import reader
+import util


 def parse_args():
...
@@ -46,6 +48,14 @@ def parse_args():
         default=["<s>", "<e>", "<unk>"],
         nargs=3,
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
+    parser.add_argument(
+        "--use_wordpiece",
+        type=ast.literal_eval,
+        default=False,
+        help="The flag indicating if the data is wordpiece data. The EN-FR data we "
+        "provided is wordpiece data. For wordpiece data, converting ids to "
+        "original words is a little different and some special codes are "
+        "provided in util.py to do this.")
     parser.add_argument(
         'opts',
         help='See config.py for all options',
...
@@ -320,7 +330,7 @@ def post_process_seq(seq,
         seq)


-def py_infer(test_data, trg_idx2word):
+def py_infer(test_data, trg_idx2word, use_wordpiece):
     """
     Inference by beam search implented by python, while the calculations from
     symbols to probilities execute by Fluid operators.
...
@@ -399,7 +409,10 @@ def py_infer(test_data, trg_idx2word):
         seqs = map(post_process_seq, batch_seqs[i])
         scores = batch_scores[i]
         for seq in seqs:
-            print(" ".join([trg_idx2word[idx] for idx in seq]))
+            if use_wordpiece:
+                print(util.subword_ids_to_str(seq, trg_idx2word))
+            else:
+                print(" ".join([trg_idx2word[idx] for idx in seq]))


 def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
...
@@ -465,7 +478,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
     return input_dict


-def fast_infer(test_data, trg_idx2word):
+def fast_infer(test_data, trg_idx2word, use_wordpiece):
     """
     Inference by beam search decoder based solely on Fluid operators.
     """
...
@@ -520,7 +533,9 @@ def fast_infer(test_data, trg_idx2word):
                     trg_idx2word[idx]
                     for idx in post_process_seq(
                         np.array(seq_ids)[sub_start:sub_end])
-                ]))
+                ]) if not use_wordpiece else util.subword_ids_to_str(
+                    post_process_seq(np.array(seq_ids)[sub_start:sub_end]),
+                    trg_idx2word))
                 scores[i].append(np.array(seq_scores)[sub_end - 1])
                 print hyps[i][-1]
                 if len(hyps[i]) >= InferTaskConfig.n_best:
...
@@ -548,7 +563,7 @@ def infer(args, inferencer=fast_infer):
         clip_last_batch=False)
     trg_idx2word = test_data.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)
-    inferencer(test_data, trg_idx2word)
+    inferencer(test_data, trg_idx2word, args.use_wordpiece)


 if __name__ == "__main__":
...
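
The new --use_wordpiece flag routes decoded ids through util.subword_ids_to_str in both py_infer and fast_infer, because wordpiece ids cannot simply be mapped to words and joined with spaces. util.py itself is not part of this diff, so the snippet below is only an editor's illustrative sketch of what such a conversion typically looks like, assuming tensor2tensor-style wordpieces in which a trailing underscore marks the end of a word; the real implementation in util.py may differ.

# Illustrative sketch only: the actual logic lives in util.py, which this
# commit does not show. Assumes each wordpiece that ends a word carries a
# trailing "_" marker (tensor2tensor convention).
def subword_ids_to_str(ids, idx2word):
    """Convert a sequence of wordpiece ids back into a plain-text sentence."""
    pieces = [idx2word[idx] for idx in ids]
    # Concatenate the pieces, then turn every end-of-word marker back into a
    # space between words.
    return "".join(pieces).replace("_", " ").strip()
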
reader.py

@@ -116,9 +116,12 @@ class DataReader(object):
     :param use_token_batch: Whether to produce batch data according to
         token number.
     :type use_token_batch: bool
-    :param delimiter: The delimiter used to split source and target in each
-        line of data file.
-    :type delimiter: basestring
+    :param field_delimiter: The delimiter used to split source and target in
+        each line of data file.
+    :type field_delimiter: basestring
+    :param token_delimiter: The delimiter used to split tokens in source or
+        target sentences.
+    :type token_delimiter: basestring
     :param start_mark: The token representing for the beginning of
         sentences in dictionary.
     :type start_mark: basestring
...
@@ -145,7 +148,8 @@ class DataReader(object):
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
-                delimiter="\t",
+                field_delimiter="\t",
+                token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
...
@@ -164,7 +168,8 @@ class DataReader(object):
         self._shuffle_batch = shuffle_batch
         self._min_length = min_length
         self._max_length = max_length
-        self._delimiter = delimiter
+        self._field_delimiter = field_delimiter
+        self._token_delimiter = token_delimiter
         self._epoch_batches = []

         src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
...
@@ -196,7 +201,7 @@ class DataReader(object):
         trg_seq_words = []
         for line in f_obj:
-            fields = line.strip().split(self._delimiter)
+            fields = line.strip().split(self._field_delimiter)
             if (not self._only_src and len(fields) != 2) or (self._only_src and
                                                              len(fields) != 1):
...
@@ -207,7 +212,7 @@ class DataReader(object):
             max_len = -1
             for i, seq in enumerate(fields):
-                seq_words = seq.split()
+                seq_words = seq.split(self._token_delimiter)
                 max_len = max(max_len, len(seq_words))
                 if len(seq_words) == 0 or \
                    len(seq_words) < self._min_length or \
...
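
In reader.py the single delimiter argument is split into field_delimiter, which separates the source and target fields within one line, and token_delimiter, which separates the tokens inside each field, so the provided space-delimited EN-FR wordpiece files can be read without extra preprocessing. Below is a minimal sketch of the parsing this change performs; the sample line is hypothetical, while the delimiters match the new DataReader defaults.

# Minimal sketch of the line parsing introduced above. The sample line is
# hypothetical; the delimiters are the new DataReader defaults.
field_delimiter = "\t"   # separates the source field from the target field
token_delimiter = " "    # separates tokens inside each field

line = "Hello_ world_\tBonjour_ le_ monde_\n"
fields = line.strip().split(field_delimiter)
assert len(fields) == 2  # source and target; a single field when only_src is used
src_words, trg_words = [seq.split(token_delimiter) for seq in fields]
print(src_words)  # ['Hello_', 'world_']
print(trg_words)  # ['Bonjour_', 'le_', 'monde_']
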