提交 2d7b878e 作者: Yu Yang

Refine reader

上级 5cabce84
...@@ -264,16 +264,22 @@ class DataReader(object): ...@@ -264,16 +264,22 @@ class DataReader(object):
def batch_generator(self): def batch_generator(self):
# global sort or global shuffle # global sort or global shuffle
beg = time.time()
if self._sort_type == SortType.GLOBAL: if self._sort_type == SortType.GLOBAL:
infos = sorted( infos = sorted(self._sample_infos, key=lambda x: x.max_len)
self._sample_infos,
key=lambda x: max(x[1], x[2]) if not self._only_src else x[1])
elif self._shuffle:
infos = self._sample_infos
self._random.shuffle(infos)
else: else:
infos = self._sample_infos if self._shuffle:
infos = self._sample_infos
self._random.shuffle(infos)
else:
infos = self._sample_infos
if self._sort_type == SortType.POOL:
for i in range(0, len(infos), self._pool_size):
infos[i * self._pool_size:(i + 1) *
self._pool_size] = sorted(
infos[i * self._pool_size:(i + 1) *
self._pool_size],
key=lambda x: x.max_len)
# concat batch # concat batch
batches = [] batches = []
...@@ -295,14 +301,7 @@ class DataReader(object): ...@@ -295,14 +301,7 @@ class DataReader(object):
self._random.shuffle(batches) self._random.shuffle(batches)
for batch in batches: for batch in batches:
if self._sort_type == SortType.POOL: batch_ids = [info.i for info in batch]
batch_ids = [
info.i
for info in sorted(
batch, key=lambda info: info.max_len)
]
else:
batch_ids = [info.i for info in batch]
if self._only_src: if self._only_src:
yield [[self._src_seq_ids[idx]] for idx in batch_ids] yield [[self._src_seq_ids[idx]] for idx in batch_ids]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册