diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py index 604d1ddc9f4ef347e320b3fe9ac9672bcc4c9447..caede1f77e82796a4750723d139034baae6b66ca 100644 --- a/fluid/neural_machine_translation/transformer/reader.py +++ b/fluid/neural_machine_translation/transformer/reader.py @@ -264,7 +264,8 @@ class DataReader(object): def batch_generator(self): # global sort or global shuffle if self._sort_type == SortType.GLOBAL: - infos = sorted(self._sample_infos, key=lambda x: x.max_len) + infos = sorted( + self._sample_infos, key=lambda x: x.max_len, reverse=True) else: if self._shuffle: infos = self._sample_infos @@ -278,7 +279,8 @@ class DataReader(object): self._pool_size] = sorted( infos[i * self._pool_size:(i + 1) * self._pool_size], - key=lambda x: x.max_len) + key=lambda x: x.max_len, + reverse=True) # concat batch batches = []