diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py index 6170c57f4f212671d4949915a594db1d5ca52b07..9a8e7e365c631bff881040c86aa26fcfb5c9724c 100644 --- a/fluid/neural_machine_translation/transformer/reader.py +++ b/fluid/neural_machine_translation/transformer/reader.py @@ -286,15 +286,24 @@ class DataReader(object): for info in infos: batch = batch_creator.append(info) if batch is not None: - batches.append([info.i for info in batch]) + batches.append(batch) if not self._clip_last_batch and len(batch_creator.batch) != 0: - batches.append([info.i for info in batch_creator.batch]) + batches.append(batch_creator.batch) if self._shuffle: self._random.shuffle(batches) - for batch_ids in batches: + for batch in batches: + if self._sort_type == SortType.POOL: + batch_ids = [ + info.i + for info in sorted( + batch, key=lambda info: info.max_len) + ] + else: + batch_ids = [info.i for info in batch] + if self._only_src: yield [[self._src_seq_ids[idx]] for idx in batch_ids] else: