提交 f3c92f11 编写于 作者: G guoshengCS

Fix text decoding in Transformer under python3.

上级 03dbae4c
...@@ -48,14 +48,14 @@ def parse_args(): ...@@ -48,14 +48,14 @@ def parse_args():
help="The buffer size to pool data.") help="The buffer size to pool data.")
parser.add_argument( parser.add_argument(
"--special_token", "--special_token",
type=str, type=lambda x: x.encode(),
default=["<s>", "<e>", "<unk>"], default=[b"<s>", b"<e>", b"<unk>"],
nargs=3, nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.") help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument( parser.add_argument(
"--token_delimiter", "--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")), type=lambda x: x.encode(),
default=" ", default=b" ",
help="The delimiter used to split tokens in source or target sentences. " help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ") "For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument( parser.add_argument(
......
...@@ -183,11 +183,11 @@ class DataReader(object): ...@@ -183,11 +183,11 @@ class DataReader(object):
shuffle_seed=None, shuffle_seed=None,
shuffle_batch=False, shuffle_batch=False,
use_token_batch=False, use_token_batch=False,
field_delimiter="\t", field_delimiter=b"\t",
token_delimiter=" ", token_delimiter=b" ",
start_mark="<s>", start_mark=b"<s>",
end_mark="<e>", end_mark=b"<e>",
unk_mark="<unk>", unk_mark=b"<unk>",
seed=0): seed=0):
self._src_vocab = self.load_dict(src_vocab_fpath) self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True self._only_src = True
...@@ -254,9 +254,9 @@ class DataReader(object): ...@@ -254,9 +254,9 @@ class DataReader(object):
if tar_fname is None: if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.") raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "r") f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname): for line in f.extractfile(tar_fname):
fields = line.strip("\n").split(self._field_delimiter) fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1): self._only_src and len(fields) == 1):
yield fields yield fields
...@@ -267,9 +267,7 @@ class DataReader(object): ...@@ -267,9 +267,7 @@ class DataReader(object):
with open(fpath, "rb") as f: with open(fpath, "rb") as f:
for line in f: for line in f:
if six.PY3: fields = line.strip(b"\n").split(self._field_delimiter)
line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1): self._only_src and len(fields) == 1):
yield fields yield fields
...@@ -279,12 +277,10 @@ class DataReader(object): ...@@ -279,12 +277,10 @@ class DataReader(object):
word_dict = {} word_dict = {}
with open(dict_path, "rb") as fdict: with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if six.PY3:
line = line.decode("utf8", errors="ignore")
if reverse: if reverse:
word_dict[idx] = line.strip("\n") word_dict[idx] = line.strip(b"\n")
else: else:
word_dict[line.strip("\n")] = idx word_dict[line.strip(b"\n")] = idx
return word_dict return word_dict
def batch_generator(self): def batch_generator(self):
......
...@@ -86,14 +86,14 @@ def parse_args(): ...@@ -86,14 +86,14 @@ def parse_args():
help="The flag indicating whether to shuffle the data batches.") help="The flag indicating whether to shuffle the data batches.")
parser.add_argument( parser.add_argument(
"--special_token", "--special_token",
type=str, type=lambda x: x.encode(),
default=["<s>", "<e>", "<unk>"], default=[b"<s>", b"<e>", b"<unk>"],
nargs=3, nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.") help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument( parser.add_argument(
"--token_delimiter", "--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")), type=lambda x: x.encode(),
default=" ", default=b" ",
help="The delimiter used to split tokens in source or target sentences. " help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ") "For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument( parser.add_argument(
......
...@@ -46,14 +46,14 @@ def parse_args(): ...@@ -46,14 +46,14 @@ def parse_args():
help="The buffer size to pool data.") help="The buffer size to pool data.")
parser.add_argument( parser.add_argument(
"--special_token", "--special_token",
type=str, type=lambda x: x.encode(),
default=["<s>", "<e>", "<unk>"], default=[b"<s>", b"<e>", b"<unk>"],
nargs=3, nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.") help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument( parser.add_argument(
"--token_delimiter", "--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")), type=lambda x: x.encode(),
default=" ", default=b" ",
help="The delimiter used to split tokens in source or target sentences. " help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ") "For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument( parser.add_argument(
......
...@@ -182,11 +182,11 @@ class DataReader(object): ...@@ -182,11 +182,11 @@ class DataReader(object):
shuffle=True, shuffle=True,
shuffle_batch=False, shuffle_batch=False,
use_token_batch=False, use_token_batch=False,
field_delimiter="\t", field_delimiter=b"\t",
token_delimiter=" ", token_delimiter=b" ",
start_mark="<s>", start_mark=b"<s>",
end_mark="<e>", end_mark=b"<e>",
unk_mark="<unk>", unk_mark=b"<unk>",
seed=0): seed=0):
self._src_vocab = self.load_dict(src_vocab_fpath) self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True self._only_src = True
...@@ -252,9 +252,9 @@ class DataReader(object): ...@@ -252,9 +252,9 @@ class DataReader(object):
if tar_fname is None: if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.") raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], "r") f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname): for line in f.extractfile(tar_fname):
fields = line.strip("\n").split(self._field_delimiter) fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1): self._only_src and len(fields) == 1):
yield fields yield fields
...@@ -265,9 +265,7 @@ class DataReader(object): ...@@ -265,9 +265,7 @@ class DataReader(object):
with open(fpath, "rb") as f: with open(fpath, "rb") as f:
for line in f: for line in f:
if six.PY3: fields = line.strip(b"\n").split(self._field_delimiter)
line = line.decode("utf8", errors="ignore")
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1): self._only_src and len(fields) == 1):
yield fields yield fields
...@@ -277,12 +275,10 @@ class DataReader(object): ...@@ -277,12 +275,10 @@ class DataReader(object):
word_dict = {} word_dict = {}
with open(dict_path, "rb") as fdict: with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if six.PY3:
line = line.decode("utf8", errors="ignore")
if reverse: if reverse:
word_dict[idx] = line.strip("\n") word_dict[idx] = line.strip(b"\n")
else: else:
word_dict[line.strip("\n")] = idx word_dict[line.strip(b"\n")] = idx
return word_dict return word_dict
def batch_generator(self): def batch_generator(self):
......
...@@ -74,14 +74,14 @@ def parse_args(): ...@@ -74,14 +74,14 @@ def parse_args():
help="The flag indicating whether to shuffle the data batches.") help="The flag indicating whether to shuffle the data batches.")
parser.add_argument( parser.add_argument(
"--special_token", "--special_token",
type=str, type=lambda x: x.encode(),
default=["<s>", "<e>", "<unk>"], default=[b"<s>", b"<e>", b"<unk>"],
nargs=3, nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.") help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument( parser.add_argument(
"--token_delimiter", "--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")), type=lambda x: x.encode(),
default=" ", default=b" ",
help="The delimiter used to split tokens in source or target sentences. " help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ") "For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册