diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index 505bf0b0062bda27a0299ed7d844e2f05abd95b8..40b3a459d63f95ac36b81f9b00731fcde9a4c36b 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -1,4 +1,5 @@
 import argparse
+import ast
 import numpy as np
 
 import paddle
@@ -11,6 +12,7 @@ from model import fast_decode as fast_decoder
 from config import *
 from train import pad_batch_data
 import reader
+import util
 
 
 def parse_args():
@@ -46,6 +48,14 @@ def parse_args():
         default=["<s>", "<e>", "<unk>"],
         nargs=3,
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
+    parser.add_argument(
+        "--use_wordpiece",
+        type=ast.literal_eval,
+        default=False,
+        help="The flag indicating if the data is wordpiece data. The EN-FR data we "
+        "provided is wordpiece data. For wordpiece data, converting ids to "
+        "original words is a little different and some special codes are "
+        "provided in util.py to do this.")
     parser.add_argument(
         'opts',
         help='See config.py for all options',
@@ -320,7 +330,7 @@ def post_process_seq(seq,
         seq)
 
 
-def py_infer(test_data, trg_idx2word):
+def py_infer(test_data, trg_idx2word, use_wordpiece):
     """
     Inference by beam search implented by python, while the calculations from
     symbols to probilities execute by Fluid operators.
@@ -399,7 +409,10 @@ def py_infer(test_data, trg_idx2word):
         seqs = map(post_process_seq, batch_seqs[i])
         scores = batch_scores[i]
         for seq in seqs:
-            print(" ".join([trg_idx2word[idx] for idx in seq]))
+            if use_wordpiece:
+                print(util.subword_ids_to_str(seq, trg_idx2word))
+            else:
+                print(" ".join([trg_idx2word[idx] for idx in seq]))
 
 
 def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
@@ -465,7 +478,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
     return input_dict
 
 
-def fast_infer(test_data, trg_idx2word):
+def fast_infer(test_data, trg_idx2word, use_wordpiece):
     """
     Inference by beam search decoder based solely on Fluid operators.
     """
@@ -520,7 +533,9 @@ def fast_infer(test_data, trg_idx2word):
                     trg_idx2word[idx]
                     for idx in post_process_seq(
                         np.array(seq_ids)[sub_start:sub_end])
-                ]))
+                ]) if not use_wordpiece else util.subword_ids_to_str(
+                    post_process_seq(np.array(seq_ids)[sub_start:sub_end]),
+                    trg_idx2word))
                 scores[i].append(np.array(seq_scores)[sub_end - 1])
                 print hyps[i][-1]
                 if len(hyps[i]) >= InferTaskConfig.n_best:
@@ -548,7 +563,7 @@ def infer(args, inferencer=fast_infer):
         clip_last_batch=False)
     trg_idx2word = test_data.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)
-    inferencer(test_data, trg_idx2word)
+    inferencer(test_data, trg_idx2word, args.use_wordpiece)
 
 
 if __name__ == "__main__":
diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py
index 27bd82b13a0480e80bdfcdc72eaa670854f4cd3a..d7ec0e36a4ac0fea5bac73524125c448601f5a0f 100644
--- a/fluid/neural_machine_translation/transformer/reader.py
+++ b/fluid/neural_machine_translation/transformer/reader.py
@@ -116,9 +116,12 @@ class DataReader(object):
     :param use_token_batch: Whether to produce batch data according to token
         number.
     :type use_token_batch: bool
-    :param delimiter: The delimiter used to split source and target in each
-        line of data file.
-    :type delimiter: basestring
+    :param field_delimiter: The delimiter used to split source and target in
+        each line of data file.
+    :type field_delimiter: basestring
+    :param token_delimiter: The delimiter used to split tokens in source or
+        target sentences.
+    :type token_delimiter: basestring
     :param start_mark: The token representing for the beginning of sentences
         in dictionary.
     :type start_mark: basestring
@@ -145,7 +148,8 @@ class DataReader(object):
                  shuffle=True,
                  shuffle_batch=False,
                  use_token_batch=False,
-                 delimiter="\t",
+                 field_delimiter="\t",
+                 token_delimiter=" ",
                  start_mark="<s>",
                  end_mark="<e>",
                  unk_mark="<unk>",
@@ -164,7 +168,8 @@ class DataReader(object):
         self._shuffle_batch = shuffle_batch
         self._min_length = min_length
         self._max_length = max_length
-        self._delimiter = delimiter
+        self._field_delimiter = field_delimiter
+        self._token_delimiter = token_delimiter
         self._epoch_batches = []
 
         src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
@@ -196,7 +201,7 @@ class DataReader(object):
             trg_seq_words = []
 
             for line in f_obj:
-                fields = line.strip().split(self._delimiter)
+                fields = line.strip().split(self._field_delimiter)
                 if (not self._only_src and len(fields) != 2) or (self._only_src
                                                                  and
                                                                  len(fields) != 1):
@@ -207,7 +212,7 @@ class DataReader(object):
 
                 max_len = -1
                 for i, seq in enumerate(fields):
-                    seq_words = seq.split()
+                    seq_words = seq.split(self._token_delimiter)
                     max_len = max(max_len, len(seq_words))
                     if len(seq_words) == 0 or \
                             len(seq_words) < self._min_length or \
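
For context on the new --use_wordpiece path: the diff calls util.subword_ids_to_str to turn wordpiece ids back into plain words, but util.py itself is not part of this diff. Below is a minimal sketch of the kind of conversion such a helper might perform; the signature and the trailing "_" word-boundary marker are assumptions for illustration only, not the actual implementation.

    # Hypothetical sketch only; the real conversion logic lives in util.py.
    # Assumes each wordpiece token that ends a word carries a trailing "_" marker.
    def subword_ids_to_str(ids, idx2word, boundary_marker="_"):
        tokens = [idx2word[idx] for idx in ids]
        words, current = [], []
        for token in tokens:
            if token.endswith(boundary_marker):
                # Strip the marker and close out the current word.
                current.append(token[:-len(boundary_marker)])
                words.append("".join(current))
                current = []
            else:
                current.append(token)
        if current:  # flush a trailing word with no boundary marker
            words.append("".join(current))
        return " ".join(words)

Note also that --use_wordpiece is parsed with ast.literal_eval, so it expects a Python literal on the command line, e.g. --use_wordpiece True or --use_wordpiece False.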