diff --git a/data_utils/data.py b/data_utils/data.py index 2a6e99b75a3a09d54500de921d5149c4798d3905..46298bf75101d51cec0bc5257ed6178736fc4e42 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -11,7 +11,6 @@ import multiprocessing import numpy as np import paddle.v2 as paddle from threading import local -import atexit from data_utils.utility import read_manifest from data_utils.utility import xmap_readers_mp from data_utils.augmentor.augmentation import AugmentationPipeline @@ -194,15 +193,18 @@ class DataGenerator(object): raise ValueError("Unknown shuffle method %s." % shuffle_method) # prepare batches - instance_reader = self._instance_reader_creator(manifest) + instance_reader, cleanup = self._instance_reader_creator(manifest) batch = [] - for instance in instance_reader(): - batch.append(instance) - if len(batch) == batch_size: + try: + for instance in instance_reader(): + batch.append(instance) + if len(batch) == batch_size: + yield self._padding_batch(batch, padding_to, flatten) + batch = [] + if len(batch) >= min_batch_size: yield self._padding_batch(batch, padding_to, flatten) - batch = [] - if len(batch) >= min_batch_size: - yield self._padding_batch(batch, padding_to, flatten) + finally: + cleanup() self._epoch += 1 return batch_reader @@ -280,10 +282,7 @@ class DataGenerator(object): lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]), reader, self._num_threads, 4096) - # register callback to main process - atexit.register(cleanup_callback) - - return reader + return reader, cleanup_callback def _padding_batch(self, batch, padding_to=-1, flatten=False): """