提交 6fc33ac5 作者: Yu Yang

Merge branch 'speed_up_transformer_python_reader' of...

Merge branch 'speed_up_transformer_python_reader' of https://github.com/reyoung/models into speed_up_transformer_python_reader
...@@ -202,7 +202,6 @@ class DataReader(object): ...@@ -202,7 +202,6 @@ class DataReader(object):
self._max_length = max_length self._max_length = max_length
self._field_delimiter = field_delimiter self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter self._token_delimiter = token_delimiter
self._epoch_batches = []
self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname, self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
unk_mark) unk_mark)
self._random = random.Random(x=seed) self._random = random.Random(x=seed)
...@@ -248,17 +247,23 @@ class DataReader(object): ...@@ -248,17 +247,23 @@ class DataReader(object):
if tar_fname is None: if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.") raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], 'r') f = tarfile.open(fpaths[0], "r")
for line in f.extractfile(tar_fname): for line in f.extractfile(tar_fname):
yield line.split(self._field_delimiter) fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
else: else:
for fpath in fpaths: for fpath in fpaths:
if not os.path.isfile(fpath): if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath) raise IOError("Invalid file: %s" % fpath)
with open(fpath, 'r') as f: with open(fpath, "r") as f:
for line in f: for line in f:
yield line.split(self._field_delimiter) fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
@staticmethod @staticmethod
def load_dict(dict_path, reverse=False): def load_dict(dict_path, reverse=False):
...@@ -266,9 +271,9 @@ class DataReader(object): ...@@ -266,9 +271,9 @@ class DataReader(object):
with open(dict_path, "r") as fdict: with open(dict_path, "r") as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if reverse: if reverse:
word_dict[idx] = line.strip('\n') word_dict[idx] = line.strip("\n")
else: else:
word_dict[line.strip('\n')] = idx word_dict[line.strip("\n")] = idx
return word_dict return word_dict
def batch_generator(self): def batch_generator(self):
...@@ -285,12 +290,8 @@ class DataReader(object): ...@@ -285,12 +290,8 @@ class DataReader(object):
if self._sort_type == SortType.POOL: if self._sort_type == SortType.POOL:
for i in range(0, len(infos), self._pool_size): for i in range(0, len(infos), self._pool_size):
infos[i * self._pool_size:(i + 1) * infos[i:i + self._pool_size] = sorted(
self._pool_size] = sorted( infos[i:i + self._pool_size], key=lambda x: x.max_len)
infos[i * self._pool_size:(i + 1) *
self._pool_size],
key=lambda x: x.max_len,
reverse=True)
# concat batch # concat batch
batches = [] batches = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册