提交 f3c92f11 编写于 作者: G guoshengCS

Fix text decoding in Transformer under python3.

上级 03dbae4c
......@@ -48,14 +48,14 @@ def parse_args():
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
type=lambda x: x.encode(),
default=b" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
......
......@@ -183,11 +183,11 @@ class DataReader(object):
shuffle_seed=None,
shuffle_batch=False,
use_token_batch=False,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
field_delimiter=b"\t",
token_delimiter=b" ",
start_mark=b"<s>",
end_mark=b"<e>",
unk_mark=b"<unk>",
seed=0):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True
......@@ -254,9 +254,9 @@ class DataReader(object):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "r")
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip("\n").split(self._field_delimiter)
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
......@@ -267,9 +267,7 @@ class DataReader(object):
with open(fpath, "rb") as f:
for line in f:
if six.PY3:
line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter)
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
......@@ -279,12 +277,10 @@ class DataReader(object):
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if six.PY3:
line = line.decode("utf8", errors="ignore")
if reverse:
word_dict[idx] = line.strip("\n")
word_dict[idx] = line.strip(b"\n")
else:
word_dict[line.strip("\n")] = idx
word_dict[line.strip(b"\n")] = idx
return word_dict
def batch_generator(self):
......
......@@ -86,14 +86,14 @@ def parse_args():
help="The flag indicating whether to shuffle the data batches.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
type=lambda x: x.encode(),
default=b" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
......
......@@ -46,14 +46,14 @@ def parse_args():
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
type=lambda x: x.encode(),
default=b" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
......
......@@ -182,11 +182,11 @@ class DataReader(object):
shuffle=True,
shuffle_batch=False,
use_token_batch=False,
field_delimiter="\t",
token_delimiter=" ",
start_mark="<s>",
end_mark="<e>",
unk_mark="<unk>",
field_delimiter=b"\t",
token_delimiter=b" ",
start_mark=b"<s>",
end_mark=b"<e>",
unk_mark=b"<unk>",
seed=0):
self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True
......@@ -252,9 +252,9 @@ class DataReader(object):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "r")
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip("\n").split(self._field_delimiter)
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
......@@ -265,9 +265,7 @@ class DataReader(object):
with open(fpath, "rb") as f:
for line in f:
if six.PY3:
line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter)
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
......@@ -277,12 +275,10 @@ class DataReader(object):
word_dict = {}
with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict):
if six.PY3:
line = line.decode("utf8", errors="ignore")
if reverse:
word_dict[idx] = line.strip("\n")
word_dict[idx] = line.strip(b"\n")
else:
word_dict[line.strip("\n")] = idx
word_dict[line.strip(b"\n")] = idx
return word_dict
def batch_generator(self):
......
......@@ -74,14 +74,14 @@ def parse_args():
help="The flag indicating whether to shuffle the data batches.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
type=lambda x: x.encode(),
default=[b"<s>", b"<e>", b"<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
type=lambda x: x.encode(),
default=b" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册