Commit 177d4910 authored by guoshengCS

Fix txt encoding by using byte data.

Parent 1214fdfd
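The gist of the change: the example previously relied on Python 2's `reload(sys); sys.setdefaultencoding("utf-8")` hack to paper over mixed `str`/`unicode` data; it now keeps all text as UTF-8 bytes end to end (byte vocab tokens, byte delimiters, binary output file). A minimal sketch of the failure mode this sidesteps, not code from the commit:

```python
# -*- coding: utf-8 -*-
# Under Python 2, mixing unicode with non-ASCII byte strings forces an
# implicit ASCII decode and raises UnicodeDecodeError, e.g.:
#     u" ".join([u"spät", "früh"])   # "früh" is a UTF-8 byte str here
# Keeping every token as bytes involves no codec at all:
tokens = [u"spät".encode("utf8"), u"früh".encode("utf8")]
line = b" ".join(tokens)  # plain byte concatenation, works on 2 and 3
```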
......@@ -55,7 +55,7 @@
3. Model prediction
Using the data and models provided above, you can run prediction with the following code; the translation results will be printed to standard output:
Using the data and models provided above, you can run prediction with the following code; the translation results will be written to the file specified by `output_file`:
```sh
# base model
python -u infer.py \
......@@ -63,6 +63,7 @@
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--output_file predict.txt \
--token_delimiter ' ' \
--batch_size 32 \
model_path trained_models/iter_100000.infer.model \
......@@ -76,6 +77,7 @@
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--output_file predict.txt \
--token_delimiter ' ' \
--batch_size 32 \
model_path trained_models/iter_100000.infer.model \
......
......@@ -4,9 +4,6 @@ import multiprocessing
import numpy as np
import os
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/")
from functools import partial
......@@ -23,7 +20,7 @@ from train import pad_batch_data, prepare_data_generator
def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.")
parser = argparse.ArgumentParser("Inference for Transformer.")
parser.add_argument(
"--src_vocab_fpath",
type=str,
......@@ -39,6 +36,11 @@ def parse_args():
type=str,
required=True,
help="The pattern to match test data files.")
parser.add_argument(
"--output_file",
type=str,
default="predict.txt",
help="The file to output the translation results of to.")
parser.add_argument(
"--batch_size",
type=int,
......@@ -51,14 +53,14 @@ def parse_args():
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
type=lambda x: x.encode("utf8"),
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: x.encode(),
default=b" ",
type=lambda x: x.encode("utf8"),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
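One subtlety behind these defaults: argparse runs the `type` callable on a default only when the default is a string, so the plain-`str` default of `--token_delimiter` is still encoded to bytes automatically, while the list default of `--special_token` is not; that case is covered by the normalization added to `DataReader` below. A small sketch of the behavior (illustrative only, not part of the commit):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--token_delimiter", type=lambda x: x.encode("utf8"),
                    default=" ")
parser.add_argument("--special_token", type=lambda x: x.encode("utf8"),
                    default=["<s>", "<e>", "<unk>"], nargs=3)

args = parser.parse_args([])
print(args.token_delimiter)  # b' '  -- a str default is run through `type`
print(args.special_token)    # ['<s>', '<e>', '<unk>'] -- a list default is not
```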
......@@ -271,6 +273,7 @@ def fast_infer(args):
trg_idx2word = reader.DataReader.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
f = open(args.output_file, "wb")
while True:
try:
feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
......@@ -316,7 +319,7 @@ def fast_infer(args):
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1].decode("utf8"))
f.write(hyps[i][-1] + b"\n")
if len(hyps[i]) >= InferTaskConfig.n_best:
break
except (StopIteration, fluid.core.EOFException):
......@@ -324,6 +327,7 @@ def fast_infer(args):
if args.use_py_reader:
pyreader.reset()
break
f.close()
if __name__ == "__main__":
......
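Opening `output_file` in binary mode and writing the hypotheses as raw bytes avoids the codec that `print` goes through: under Python 2, printing decoded text to a redirected stdout can fall back to the ASCII codec and raise `UnicodeEncodeError` on non-ASCII translations. A minimal sketch of the write path above, with made-up byte hypotheses:

```python
# Stand-ins for hyps[i][-1], which are already UTF-8 encoded bytes.
hyps = [u"früh aufstehen".encode("utf8"), b"plain ascii line"]

with open("predict.txt", "wb") as f:
    for hyp in hyps:
        f.write(hyp + b"\n")  # bytes go to disk verbatim, no codec involved
```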
......@@ -183,12 +183,23 @@ class DataReader(object):
shuffle_seed=None,
shuffle_batch=False,
use_token_batch=False,
field_delimiter=b"\t",
token_delimiter=b" ",
start_mark=b"<s>",
end_mark=b"<e>",
unk_mark=b"<unk>",
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
seed=0):
# convert str to bytes, and use byte data
field_delimiter = field_delimiter if isinstance(
field_delimiter, bytes) else field_delimiter.encode("utf8")
token_delimiter = token_delimiter if isinstance(
token_delimiter, bytes) else token_delimiter.encode("utf8")
start_mark = start_mark if isinstance(
start_mark, bytes) else start_mark.encode("utf8")
end_mark = end_mark if isinstance(end_mark,
bytes) else end_mark.encode("utf8")
unk_mark = unk_mark if isinstance(unk_mark,
bytes) else unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True
if trg_vocab_fpath is not None:
......
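The repeated `isinstance` checks in `DataReader.__init__` could equally be folded into one helper; a compact sketch under that assumption (`as_bytes` is a hypothetical name, not part of the repo):

```python
def as_bytes(value, encoding="utf8"):
    # bytes pass through unchanged; str (unicode) input gets encoded.
    return value if isinstance(value, bytes) else value.encode(encoding)

field_delimiter = as_bytes(field_delimiter)
token_delimiter = as_bytes(token_delimiter)
start_mark, end_mark, unk_mark = map(as_bytes, (start_mark, end_mark, unk_mark))
```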
......@@ -11,9 +11,6 @@ if os.environ.get('FLAGS_eager_delete_tensor_gb', None) is None:
import six
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
sys.path.append("../../")
sys.path.append("../../models/neural_machine_translation/transformer/")
import time
......@@ -89,14 +86,14 @@ def parse_args():
help="The flag indicating whether to shuffle the data batches.")
parser.add_argument(
"--special_token",
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
type=lambda x: x.encode("utf8"),
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: x.encode(),
default=b" ",
type=lambda x: x.encode("utf8"),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
......
......@@ -139,6 +139,7 @@ python -u infer.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
--output_file predict.txt \
--token_delimiter ' ' \
--batch_size 32 \
model_path trained_models/iter_100000.infer.model \
......@@ -152,6 +153,7 @@ python -u infer.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
--output_file predict.txt \
--token_delimiter ' ' \
--batch_size 32 \
model_path trained_models/iter_100000.infer.model \
......@@ -164,7 +166,7 @@ python -u infer.py \
```
Besides the parameters used for training, inference takes a few additional ones: `model_path` must be set to the directory holding the trained model, and `beam_size` and `max_out_len` can be set to specify the beam width and maximum search depth (translation length) of the beam search algorithm. These parameters are also documented by comments in `InferTaskConfig` in `config.py`, where they can likewise be changed.
Running the above prediction command prints the translation results to standard output; each output line is the highest-scoring translation of the corresponding input line. For the EN-DE data using BPE, the predicted translations are also in BPE form; they must be restored to the original (here, tokenized) data before they can be evaluated correctly. The following command restores the translations in `predict.txt` into `predict.tok.txt` (no further tokenization is needed):
Running the above prediction command writes the translation results to the file specified by `output_file`; each output line is the highest-scoring translation of the corresponding input line. For the EN-DE data using BPE, the predicted translations are also in BPE form; they must be restored to the original (here, tokenized) data before they can be evaluated correctly. The following command restores the translations in `predict.txt` into `predict.tok.txt` (no further tokenization is needed):
```sh
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
```
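For reference, a Python 3 equivalent of the `sed` command above (a sketch, assuming `predict.txt` is UTF-8 text) strips the `@@ ` BPE continuation markers the same way:

```python
import re

with open("predict.txt", encoding="utf8") as src, \
        open("predict.tok.txt", "w", encoding="utf8") as dst:
    for line in src:
        # Same pattern as the sed call: drop "@@ " joiners and a trailing
        # "@@" (optionally followed by a space) at end of line.
        dst.write(re.sub(r"(@@ )|(@@ ?$)", "", line.rstrip("\n")) + "\n")
```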
......
......@@ -4,9 +4,6 @@ import multiprocessing
import numpy as np
import os
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
from functools import partial
import paddle
......@@ -22,7 +19,7 @@ from train import pad_batch_data, prepare_data_generator
def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.")
parser = argparse.ArgumentParser("Inference for Transformer.")
parser.add_argument(
"--src_vocab_fpath",
type=str,
......@@ -38,6 +35,11 @@ def parse_args():
type=str,
required=True,
help="The pattern to match test data files.")
parser.add_argument(
"--output_file",
type=str,
default="predict.txt",
help="The file to output the translation results of to.")
parser.add_argument(
"--batch_size",
type=int,
......@@ -50,14 +52,14 @@ def parse_args():
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
type=lambda x: x.encode("utf8"),
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: x.encode(),
default=b" ",
type=lambda x: x.encode("utf8"),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
......@@ -268,6 +270,7 @@ def fast_infer(args):
trg_idx2word = reader.DataReader.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
f = open(args.output_file, "wb")
while True:
try:
feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
......@@ -313,7 +316,7 @@ def fast_infer(args):
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1].decode("utf8"))
f.write(hyps[i][-1] + b"\n")
if len(hyps[i]) >= InferTaskConfig.n_best:
break
except (StopIteration, fluid.core.EOFException):
......@@ -321,6 +324,7 @@ def fast_infer(args):
if args.use_py_reader:
pyreader.reset()
break
f.close()
if __name__ == "__main__":
......
......@@ -182,12 +182,23 @@ class DataReader(object):
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
field_delimiter=b"\t",
token_delimiter=b" ",
start_mark=b"<s>",
end_mark=b"<e>",
unk_mark=b"<unk>",
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
seed=0):
# convert str to bytes, and use byte data
field_delimiter = field_delimiter if isinstance(
field_delimiter, bytes) else field_delimiter.encode("utf8")
token_delimiter = token_delimiter if isinstance(
token_delimiter, bytes) else token_delimiter.encode("utf8")
start_mark = start_mark if isinstance(
start_mark, bytes) else start_mark.encode("utf8")
end_mark = end_mark if isinstance(end_mark,
bytes) else end_mark.encode("utf8")
unk_mark = unk_mark if isinstance(unk_mark,
bytes) else unk_mark.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True
if trg_vocab_fpath is not None:
......
......@@ -6,9 +6,6 @@ import multiprocessing
import os
import six
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
import time
import numpy as np
......@@ -77,14 +74,14 @@ def parse_args():
help="The flag indicating whether to shuffle the data batches.")
parser.add_argument(
"--special_token",
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
type=lambda x: x.encode("utf8"),
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: x.encode(),
default=b" ",
type=lambda x: x.encode("utf8"),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
......