提交 6fc33ac5 编写于 作者: Y Yu Yang

Merge branch 'speed_up_transformer_python_reader' of...

Merge branch 'speed_up_transformer_python_reader' of https://github.com/reyoung/models into speed_up_transformer_python_reader
......@@ -202,7 +202,6 @@ class DataReader(object):
self._max_length = max_length
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
self._epoch_batches = []
self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
unk_mark)
self._random = random.Random(x=seed)
......@@ -248,17 +247,23 @@ class DataReader(object):
if tar_fname is None:
raise Exception("If tar file provided, please set tar_fname.")
f = tarfile.open(fpaths[0], 'r')
f = tarfile.open(fpaths[0], "r")
for line in f.extractfile(tar_fname):
yield line.split(self._field_delimiter)
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, 'r') as f:
with open(fpath, "r") as f:
for line in f:
yield line.split(self._field_delimiter)
fields = line.strip("\n").split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
@staticmethod
def load_dict(dict_path, reverse=False):
......@@ -266,9 +271,9 @@ class DataReader(object):
with open(dict_path, "r") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip('\n')
word_dict[idx] = line.strip("\n")
else:
word_dict[line.strip('\n')] = idx
word_dict[line.strip("\n")] = idx
return word_dict
def batch_generator(self):
......@@ -285,12 +290,8 @@ class DataReader(object):
if self._sort_type == SortType.POOL:
for i in range(0, len(infos), self._pool_size):
infos[i * self._pool_size:(i + 1) *
self._pool_size] = sorted(
infos[i * self._pool_size:(i + 1) *
self._pool_size],
key=lambda x: x.max_len,
reverse=True)
infos[i:i + self._pool_size] = sorted(
infos[i:i + self._pool_size], key=lambda x: x.max_len)
# concat batch
batches = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册