Commit 177d4910 authored by guoshengCS

Fix txt encoding by using byte data.

Parent 1214fdfd
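In short: the changes below drop the Python 2 `reload(sys); sys.setdefaultencoding("utf-8")` hack, handle delimiters, special tokens, and translation hypotheses as byte strings throughout, and write inference results to a file given by a new `--output_file` argument instead of printing them to standard output.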
@@ -55,7 +55,7 @@
 3. Model prediction

-Using the data and models provided above, prediction can be run with the following commands; the translation results will be printed to standard output:
+Using the data and models provided above, prediction can be run with the following commands; the translation results will be written to the file specified by `output_file`:

 ```sh
 # base model
 python -u infer.py \
@@ -63,6 +63,7 @@
         --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
         --special_token '<s>' '<e>' '<unk>' \
         --test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+        --output_file predict.txt \
         --token_delimiter ' ' \
         --batch_size 32 \
         model_path trained_models/iter_100000.infer.model \
@@ -76,6 +77,7 @@
         --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
         --special_token '<s>' '<e>' '<unk>' \
         --test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+        --output_file predict.txt \
         --token_delimiter ' ' \
         --batch_size 32 \
         model_path trained_models/iter_100000.infer.model \
@@ -4,9 +4,6 @@ import multiprocessing
 import numpy as np
 import os
 import sys
-if sys.version[0] == '2':
-    reload(sys)
-    sys.setdefaultencoding("utf-8")
 sys.path.append("../../")
 sys.path.append("../../models/neural_machine_translation/transformer/")
 from functools import partial
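The import hunks above and the `parse_args` hunks below appear to come from the inference script (presumably `infer.py`, matching the `python -u infer.py` commands in the README). The removed block only papered over implicit str/unicode mixing on Python 2; once all text is handled as bytes, no default-encoding hack is needed. A minimal sketch (not part of the commit) of the bytes-only approach:

```python
# Hypothetical illustration: binary I/O yields bytes on both Python 2 and 3,
# so splitting and joining with byte delimiters never triggers the implicit
# str<->unicode conversion that sys.setdefaultencoding was abused to patch.
line = b"Gesch\xc3\xa4ft und Haus"       # UTF-8 bytes, as read from open(path, "rb")
tokens = line.split(b" ")                # byte delimiter applied to byte data
print(b" ".join(tokens).decode("utf8"))  # decode only at the display boundary
```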
@@ -23,7 +20,7 @@ from train import pad_batch_data, prepare_data_generator

 def parse_args():
-    parser = argparse.ArgumentParser("Training for Transformer.")
+    parser = argparse.ArgumentParser("Inference for Transformer.")
     parser.add_argument(
         "--src_vocab_fpath",
         type=str,
@@ -39,6 +36,11 @@ def parse_args():
         type=str,
         required=True,
         help="The pattern to match test data files.")
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        default="predict.txt",
+        help="The file to output the translation results to.")
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -51,14 +53,14 @@ def parse_args():
         help="The buffer size to pool data.")
     parser.add_argument(
         "--special_token",
-        type=lambda x: x.encode(),
-        default=[b"<s>", b"<e>", b"<unk>"],
+        type=lambda x: x.encode("utf8"),
+        default=["<s>", "<e>", "<unk>"],
         nargs=3,
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
     parser.add_argument(
         "--token_delimiter",
-        type=lambda x: x.encode(),
-        default=b" ",
+        type=lambda x: x.encode("utf8"),
+        default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
         "For EN-DE BPE data we provided, use spaces as token delimiter. ")
     parser.add_argument(
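A subtlety behind switching these defaults from bytes literals back to plain str: argparse runs the `type` callable on a default only when the default is itself a string, so values can reach the data reader as either str or bytes, which is why `DataReader` (further below) normalizes everything to bytes itself. A small standalone sketch, assuming nothing beyond stock argparse:

```python
import argparse

# Hypothetical demo: argparse applies `type` to string defaults but passes
# non-string defaults (such as the 3-element list) through unconverted.
parser = argparse.ArgumentParser()
parser.add_argument("--token_delimiter", type=lambda x: x.encode("utf8"),
                    default=" ")
parser.add_argument("--special_token", type=lambda x: x.encode("utf8"),
                    nargs=3, default=["<s>", "<e>", "<unk>"])
args = parser.parse_args([])
print(type(args.token_delimiter))   # <class 'bytes'> -- str default converted
print(type(args.special_token[0]))  # <class 'str'>   -- list default untouched
```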
@@ -271,6 +273,7 @@ def fast_infer(args):
     trg_idx2word = reader.DataReader.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)

+    f = open(args.output_file, "wb")
     while True:
         try:
             feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
@@ -316,7 +319,7 @@ def fast_infer(args):
                                 np.array(seq_ids)[sub_start:sub_end])
                         ]))
                     scores[i].append(np.array(seq_scores)[sub_end - 1])
-                    print(hyps[i][-1].decode("utf8"))
+                    f.write(hyps[i][-1] + b"\n")
                     if len(hyps[i]) >= InferTaskConfig.n_best:
                         break
         except (StopIteration, fluid.core.EOFException):
@@ -324,6 +327,7 @@ def fast_infer(args):
             if args.use_py_reader:
                 pyreader.reset()
             break
+    f.close()


 if __name__ == "__main__":
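Writing each hypothesis as raw bytes to a file opened in `"wb"` mode sidesteps the locale-dependent encoding step that `print()` needs for stdout, a common source of `UnicodeEncodeError` under an ASCII locale. A self-contained sketch of the new output path, with made-up hypothesis data standing in for `hyps[i][-1]`:

```python
# Hypothetical data; the real values come from the beam-search ids mapped
# through trg_idx2word.
hyps = [b"ein Haus", b"sch\xc3\xb6ne Katze"]
with open("predict.txt", "wb") as f:  # binary mode: bytes in, bytes out
    for h in hyps:
        f.write(h + b"\n")            # no codec involved, locale-independent
```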
@@ -183,12 +183,23 @@ class DataReader(object):
                  shuffle_seed=None,
                  shuffle_batch=False,
                  use_token_batch=False,
-                 field_delimiter=b"\t",
-                 token_delimiter=b" ",
-                 start_mark=b"<s>",
-                 end_mark=b"<e>",
-                 unk_mark=b"<unk>",
+                 field_delimiter="\t",
+                 token_delimiter=" ",
+                 start_mark="<s>",
+                 end_mark="<e>",
+                 unk_mark="<unk>",
                  seed=0):
+        # convert str to bytes, and use byte data
+        field_delimiter = field_delimiter if isinstance(
+            field_delimiter, bytes) else field_delimiter.encode("utf8")
+        token_delimiter = token_delimiter if isinstance(
+            token_delimiter, bytes) else token_delimiter.encode("utf8")
+        start_mark = start_mark if isinstance(
+            start_mark, bytes) else start_mark.encode("utf8")
+        end_mark = end_mark if isinstance(end_mark,
+                                          bytes) else end_mark.encode("utf8")
+        unk_mark = unk_mark if isinstance(unk_mark,
+                                          bytes) else unk_mark.encode("utf8")
         self._src_vocab = self.load_dict(src_vocab_fpath)
         self._only_src = True
         if trg_vocab_fpath is not None:
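The `DataReader` hunks above (presumably `reader.py`) accept either str or bytes for every delimiter and marker and normalize to bytes on entry. The repeated `isinstance` pattern amounts to a tiny helper, sketched here as a hypothetical `to_bytes` (the commit inlines the checks instead):

```python
def to_bytes(s):
    """Return s as UTF-8 bytes, passing through values already encoded."""
    return s if isinstance(s, bytes) else s.encode("utf8")

print(to_bytes(" "))     # b' '    -- a str default coming from argparse
print(to_bytes(b"<s>"))  # b'<s>'  -- an already-encoded CLI value
```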
@@ -11,9 +11,6 @@ if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
 import six
 import sys
-if sys.version[0] == '2':
-    reload(sys)
-    sys.setdefaultencoding("utf-8")
 sys.path.append("../../")
 sys.path.append("../../models/neural_machine_translation/transformer/")

 import time
@@ -89,14 +86,14 @@ def parse_args():
         help="The flag indicating whether to shuffle the data batches.")
     parser.add_argument(
         "--special_token",
-        type=lambda x: x.encode(),
-        default=[b"<s>", b"<e>", b"<unk>"],
+        type=lambda x: x.encode("utf8"),
+        default=["<s>", "<e>", "<unk>"],
         nargs=3,
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
     parser.add_argument(
         "--token_delimiter",
-        type=lambda x: x.encode(),
-        default=b" ",
+        type=lambda x: x.encode("utf8"),
+        default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
         "For EN-DE BPE data we provided, use spaces as token delimiter. ")
     parser.add_argument(
@@ -139,6 +139,7 @@ python -u infer.py \
         --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
         --special_token '<s>' '<e>' '<unk>' \
         --test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
+        --output_file predict.txt \
         --token_delimiter ' ' \
         --batch_size 32 \
         model_path trained_models/iter_100000.infer.model \
@@ -152,6 +153,7 @@ python -u infer.py \
         --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
         --special_token '<s>' '<e>' '<unk>' \
         --test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
+        --output_file predict.txt \
         --token_delimiter ' ' \
         --batch_size 32 \
         model_path trained_models/iter_100000.infer.model \
@@ -164,7 +166,7 @@ python -u infer.py \
 ```

 In addition to the training parameters, prediction takes a few extra ones: `model_path` must be set to the directory containing the trained model, and `beam_size` and `max_out_len` can be set to specify the beam width and the maximum search depth (translation length) of the beam search algorithm. These parameters can also be reviewed and changed via the commented `InferTaskConfig` in `config.py`.

-Running the above prediction command prints the translation results to standard output, where each output line is the highest-scoring translation of the corresponding input line. For the EN-DE BPE data, the predicted translations are themselves in BPE form and must be restored to the original (i.e. tokenized) data before they can be evaluated correctly. The following command restores the translations in `predict.txt` into `predict.tok.txt` (no further tokenization is needed):
+Running the above prediction command writes the translation results to the file specified by `output_file`, where each output line is the highest-scoring translation of the corresponding input line. For the EN-DE BPE data, the predicted translations are themselves in BPE form and must be restored to the original (i.e. tokenized) data before they can be evaluated correctly. The following command restores the translations in `predict.txt` into `predict.tok.txt` (no further tokenization is needed):

 ```sh
 sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
 ```
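For reference, the `sed` expression deletes the BPE continuation markers (`@@ ` in the middle of a word, or a trailing `@@`), merging subword pieces back into whole tokens. An equivalent in Python, on a made-up BPE line:

```python
import re

bpe_line = "The Trans@@ former mod@@ el trans@@ lates sentences"
print(re.sub(r"(@@ )|(@@ ?$)", "", bpe_line))
# -> "The Transformer model translates sentences"
```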
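The remaining hunks repeat the same edits in what appears to be a second copy of the scripts kept elsewhere in the repository (another `infer.py`/`reader.py`/`train.py` tree); they mirror the changes shown above hunk for hunk.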
@@ -4,9 +4,6 @@ import multiprocessing
 import numpy as np
 import os
 import sys
-if sys.version[0] == '2':
-    reload(sys)
-    sys.setdefaultencoding("utf-8")
 from functools import partial

 import paddle
@@ -22,7 +19,7 @@ from train import pad_batch_data, prepare_data_generator

 def parse_args():
-    parser = argparse.ArgumentParser("Training for Transformer.")
+    parser = argparse.ArgumentParser("Inference for Transformer.")
     parser.add_argument(
         "--src_vocab_fpath",
         type=str,
@@ -38,6 +35,11 @@ def parse_args():
         type=str,
         required=True,
         help="The pattern to match test data files.")
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        default="predict.txt",
+        help="The file to output the translation results to.")
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -50,14 +52,14 @@ def parse_args():
         help="The buffer size to pool data.")
     parser.add_argument(
         "--special_token",
-        type=lambda x: x.encode(),
-        default=[b"<s>", b"<e>", b"<unk>"],
+        type=lambda x: x.encode("utf8"),
+        default=["<s>", "<e>", "<unk>"],
         nargs=3,
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
     parser.add_argument(
         "--token_delimiter",
-        type=lambda x: x.encode(),
-        default=b" ",
+        type=lambda x: x.encode("utf8"),
+        default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
         "For EN-DE BPE data we provided, use spaces as token delimiter. ")
     parser.add_argument(
@@ -268,6 +270,7 @@ def fast_infer(args):
     trg_idx2word = reader.DataReader.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)

+    f = open(args.output_file, "wb")
     while True:
         try:
             feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
@@ -313,7 +316,7 @@ def fast_infer(args):
                                 np.array(seq_ids)[sub_start:sub_end])
                         ]))
                     scores[i].append(np.array(seq_scores)[sub_end - 1])
-                    print(hyps[i][-1].decode("utf8"))
+                    f.write(hyps[i][-1] + b"\n")
                     if len(hyps[i]) >= InferTaskConfig.n_best:
                         break
         except (StopIteration, fluid.core.EOFException):
@@ -321,6 +324,7 @@ def fast_infer(args):
             if args.use_py_reader:
                 pyreader.reset()
             break
+    f.close()


 if __name__ == "__main__":
@@ -182,12 +182,23 @@ class DataReader(object):
                  shuffle=True,
                  shuffle_batch=False,
                  use_token_batch=False,
-                 field_delimiter=b"\t",
-                 token_delimiter=b" ",
-                 start_mark=b"<s>",
-                 end_mark=b"<e>",
-                 unk_mark=b"<unk>",
+                 field_delimiter="\t",
+                 token_delimiter=" ",
+                 start_mark="<s>",
+                 end_mark="<e>",
+                 unk_mark="<unk>",
                  seed=0):
+        # convert str to bytes, and use byte data
+        field_delimiter = field_delimiter if isinstance(
+            field_delimiter, bytes) else field_delimiter.encode("utf8")
+        token_delimiter = token_delimiter if isinstance(
+            token_delimiter, bytes) else token_delimiter.encode("utf8")
+        start_mark = start_mark if isinstance(
+            start_mark, bytes) else start_mark.encode("utf8")
+        end_mark = end_mark if isinstance(end_mark,
+                                          bytes) else end_mark.encode("utf8")
+        unk_mark = unk_mark if isinstance(unk_mark,
+                                          bytes) else unk_mark.encode("utf8")
         self._src_vocab = self.load_dict(src_vocab_fpath)
         self._only_src = True
         if trg_vocab_fpath is not None:
@@ -6,9 +6,6 @@ import multiprocessing
 import os
 import six
 import sys
-if sys.version[0] == '2':
-    reload(sys)
-    sys.setdefaultencoding("utf-8")
 import time

 import numpy as np
@@ -77,14 +74,14 @@ def parse_args():
         help="The flag indicating whether to shuffle the data batches.")
     parser.add_argument(
         "--special_token",
-        type=lambda x: x.encode(),
-        default=[b"<s>", b"<e>", b"<unk>"],
+        type=lambda x: x.encode("utf8"),
+        default=["<s>", "<e>", "<unk>"],
         nargs=3,
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
     parser.add_argument(
         "--token_delimiter",
-        type=lambda x: x.encode(),
-        default=b" ",
+        type=lambda x: x.encode("utf8"),
+        default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
         "For EN-DE BPE data we provided, use spaces as token delimiter. ")
     parser.add_argument(