diff --git a/ppdet/data/loader.py b/ppdet/data/loader.py index a2c6fdd0e284769abe8ec038f4972f4789e0f722..29bf8a5b090dd57b62519dc8ae1d97f97902505f 100644 --- a/ppdet/data/loader.py +++ b/ppdet/data/loader.py @@ -66,55 +66,6 @@ class Compose(object): return data -class Prefetcher(threading.Thread): - def __init__(self, iterator, prefetch_num=1): - threading.Thread.__init__(self) - self.queue = Queue.Queue(prefetch_num) - self.iterator = iterator - self.daemon = True - self.start() - - def run(self): - for item in self.iterator: - self.queue.put(item) - self.queue.put(None) - - def next(self): - next_item = self.queue.get() - if next_item is None: - raise StopIteration - return next_item - - # Python 3 compatibility - def __next__(self): - return self.next() - - def __iter__(self): - return self - - -class DataLoaderPrefetch(DataLoader): - def __init__(self, - dataset, - batch_sampler, - collate_fn, - num_workers, - places, - return_list, - prefetch_num=1): - super(DataLoaderPrefetch, self).__init__( - dataset=dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - num_workers=num_workers, - places=places, - return_list=return_list) - self.prefetch_num = prefetch_num - - def __iter__(self): - return Prefetcher(super().__iter__(), self.prefetch_num) - - class BaseDataLoader(object): __share__ = ['num_classes'] __inject__ = ['dataset'] @@ -159,26 +110,16 @@ class BaseDataLoader(object): worker_num, device, return_list=False, - use_prefetch=False, - prefetch_num=None): - if use_prefetch: - loader = DataLoaderPrefetch( - dataset=self._dataset, - batch_sampler=self._batch_sampler, - collate_fn=self._batch_transforms, - num_workers=worker_num, - places=device, - return_list=return_list, - prefetch_num=prefetch_num - if prefetch_num is not None else self.batch_size) - else: - loader = DataLoader( - dataset=self._dataset, - batch_sampler=self._batch_sampler, - collate_fn=self._batch_transforms, - num_workers=worker_num, - places=device, - return_list=return_list) + use_prefetch=False): + + loader = DataLoader( + dataset=self._dataset, + batch_sampler=self._batch_sampler, + collate_fn=self._batch_transforms, + num_workers=worker_num, + use_buffer_reader=use_prefetch, + places=device, + return_list=return_list) return loader, len(self._batch_sampler)