Commit 33be7a97 authored by guosheng

Make reader.py and infer.py support en-fr wordpiece data in Transformer

Parent 6b81d938
import argparse
import ast
import numpy as np
import paddle
@@ -11,6 +12,7 @@ from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
import util
def parse_args():
@@ -46,6 +48,14 @@ def parse_args():
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--use_wordpiece",
type=ast.literal_eval,
default=False,
help="The flag indicating if the data is wordpiece data. The EN-FR data we "
"provided is wordpiece data. For wordpiece data, converting ids to "
"original words is a little different and some special codes are "
"provided in util.py to do this.")
parser.add_argument(
'opts',
help='See config.py for all options',
@@ -320,7 +330,7 @@ def post_process_seq(seq,
seq)
def py_infer(test_data, trg_idx2word):
def py_infer(test_data, trg_idx2word, use_wordpiece):
"""
Inference by beam search implemented in Python, while the calculations
from symbols to probabilities are executed by Fluid operators.
@@ -399,7 +409,10 @@ def py_infer(test_data, trg_idx2word):
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
if use_wordpiece:
print(util.subword_ids_to_str(seq, trg_idx2word))
else:
print(" ".join([trg_idx2word[idx] for idx in seq]))
def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
@@ -465,7 +478,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
return input_dict
def fast_infer(test_data, trg_idx2word):
def fast_infer(test_data, trg_idx2word, use_wordpiece):
"""
Inference by the beam search decoder based solely on Fluid operators.
"""
@@ -520,7 +533,9 @@ def fast_infer(test_data, trg_idx2word):
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
]) if not use_wordpiece else util.subword_ids_to_str(
post_process_seq(np.array(seq_ids)[sub_start:sub_end]),
trg_idx2word))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1])
if len(hyps[i]) >= InferTaskConfig.n_best:
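The inline conditional added above packs both branches into a single
expression; an equivalent, more readable form is sketched below (not part
of the commit, and it assumes the expression feeds hyps[i].append, as the
surrounding code suggests):

token_ids = post_process_seq(np.array(seq_ids)[sub_start:sub_end])
if use_wordpiece:
    hyps[i].append(util.subword_ids_to_str(token_ids, trg_idx2word))
else:
    hyps[i].append(" ".join([trg_idx2word[idx] for idx in token_ids]))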
@@ -548,7 +563,7 @@ def infer(args, inferencer=fast_infer):
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
inferencer(test_data, trg_idx2word)
inferencer(test_data, trg_idx2word, args.use_wordpiece)
if __name__ == "__main__":
......
@@ -116,9 +116,12 @@ class DataReader(object):
:param use_token_batch: Whether to produce batch data according to
token number.
:type use_token_batch: bool
:param delimiter: The delimiter used to split source and target in each
line of data file.
:type delimiter: basestring
:param field_delimiter: The delimiter used to split source and target in
each line of data file.
:type field_delimiter: basestring
:param token_delimiter: The delimiter used to split tokens in source or
target sentences.
:type token_delimiter: basestring
:param start_mark: The token representing the beginning of
sentences in the dictionary.
:type start_mark: basestring
@@ -145,7 +148,8 @@
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
delimiter="\t",
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
@@ -164,7 +168,8 @@
self._shuffle_batch = shuffle_batch
self._min_length = min_length
self._max_length = max_length
self._delimiter = delimiter
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self._epoch_batches = []
src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
@@ -196,7 +201,7 @@
trg_seq_words = []
for line in f_obj:
fields = line.strip().split(self._delimiter)
fields = line.strip().split(self._field_delimiter)
if (not self._only_src and len(fields) != 2) or (self._only_src and
len(fields) != 1):
@@ -207,7 +212,7 @@
max_len = -1
for i, seq in enumerate(fields):
seq_words = seq.split()
seq_words = seq.split(self._token_delimiter)
max_len = max(max_len, len(seq_words))
if len(seq_words) == 0 or \
len(seq_words) < self._min_length or \
......
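To illustrate the two delimiters with their defaults from this change:
each data line holds the source and target separated by field_delimiter,
and each of those holds tokens separated by token_delimiter (the sentence
content below is made up):

line = "a source sentence\ta target sentence"
fields = line.strip().split("\t")   # field_delimiter: source vs. target
src_words = fields[0].split(" ")    # token_delimiter: individual tokens
trg_words = fields[1].split(" ")
print(src_words)  # ['a', 'source', 'sentence']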