提交 94c7e710 编写于 作者: Y Yu Yang

Support SortType.Pool

上级 dcdb6a00
...@@ -286,15 +286,24 @@ class DataReader(object): ...@@ -286,15 +286,24 @@ class DataReader(object):
for info in infos: for info in infos:
batch = batch_creator.append(info) batch = batch_creator.append(info)
if batch is not None: if batch is not None:
batches.append([info.i for info in batch]) batches.append(batch)
if not self._clip_last_batch and len(batch_creator.batch) != 0: if not self._clip_last_batch and len(batch_creator.batch) != 0:
batches.append([info.i for info in batch_creator.batch]) batches.append(batch_creator.batch)
if self._shuffle: if self._shuffle:
self._random.shuffle(batches) self._random.shuffle(batches)
for batch_ids in batches: for batch in batches:
if self._sort_type == SortType.POOL:
batch_ids = [
info.i
for info in sorted(
batch, key=lambda info: info.max_len)
]
else:
batch_ids = [info.i for info in batch]
if self._only_src: if self._only_src:
yield [[self._src_seq_ids[idx]] for idx in batch_ids] yield [[self._src_seq_ids[idx]] for idx in batch_ids]
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册