提交 d48788a7 编写于 作者: S sunxl1988

test=dygraph del prefetch

上级 e9672b57
...@@ -66,55 +66,6 @@ class Compose(object): ...@@ -66,55 +66,6 @@ class Compose(object):
return data return data
class Prefetcher(threading.Thread):
def __init__(self, iterator, prefetch_num=1):
threading.Thread.__init__(self)
self.queue = Queue.Queue(prefetch_num)
self.iterator = iterator
self.daemon = True
self.start()
def run(self):
for item in self.iterator:
self.queue.put(item)
self.queue.put(None)
def next(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
# Python 3 compatibility
def __next__(self):
return self.next()
def __iter__(self):
return self
class DataLoaderPrefetch(DataLoader):
def __init__(self,
dataset,
batch_sampler,
collate_fn,
num_workers,
places,
return_list,
prefetch_num=1):
super(DataLoaderPrefetch, self).__init__(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=num_workers,
places=places,
return_list=return_list)
self.prefetch_num = prefetch_num
def __iter__(self):
return Prefetcher(super().__iter__(), self.prefetch_num)
class BaseDataLoader(object): class BaseDataLoader(object):
__share__ = ['num_classes'] __share__ = ['num_classes']
__inject__ = ['dataset'] __inject__ = ['dataset']
...@@ -159,26 +110,16 @@ class BaseDataLoader(object): ...@@ -159,26 +110,16 @@ class BaseDataLoader(object):
worker_num, worker_num,
device, device,
return_list=False, return_list=False,
use_prefetch=False, use_prefetch=False):
prefetch_num=None):
if use_prefetch: loader = DataLoader(
loader = DataLoaderPrefetch( dataset=self._dataset,
dataset=self._dataset, batch_sampler=self._batch_sampler,
batch_sampler=self._batch_sampler, collate_fn=self._batch_transforms,
collate_fn=self._batch_transforms, num_workers=worker_num,
num_workers=worker_num, use_buffer_reader=use_prefetch,
places=device, places=device,
return_list=return_list, return_list=return_list)
prefetch_num=prefetch_num
if prefetch_num is not None else self.batch_size)
else:
loader = DataLoader(
dataset=self._dataset,
batch_sampler=self._batch_sampler,
collate_fn=self._batch_transforms,
num_workers=worker_num,
places=device,
return_list=return_list)
return loader, len(self._batch_sampler) return loader, len(self._batch_sampler)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册