提交 d48788a7 编写于 作者: S sunxl1988

test=dygraph del prefetch

上级 e9672b57
......@@ -66,55 +66,6 @@ class Compose(object):
return data
class Prefetcher(threading.Thread):
def __init__(self, iterator, prefetch_num=1):
threading.Thread.__init__(self)
self.queue = Queue.Queue(prefetch_num)
self.iterator = iterator
self.daemon = True
self.start()
def run(self):
for item in self.iterator:
self.queue.put(item)
self.queue.put(None)
def next(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
# Python 3 compatibility
def __next__(self):
return self.next()
def __iter__(self):
return self
class DataLoaderPrefetch(DataLoader):
def __init__(self,
dataset,
batch_sampler,
collate_fn,
num_workers,
places,
return_list,
prefetch_num=1):
super(DataLoaderPrefetch, self).__init__(
dataset=dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=num_workers,
places=places,
return_list=return_list)
self.prefetch_num = prefetch_num
def __iter__(self):
return Prefetcher(super().__iter__(), self.prefetch_num)
class BaseDataLoader(object):
__share__ = ['num_classes']
__inject__ = ['dataset']
......@@ -159,24 +110,14 @@ class BaseDataLoader(object):
worker_num,
device,
return_list=False,
use_prefetch=False,
prefetch_num=None):
if use_prefetch:
loader = DataLoaderPrefetch(
dataset=self._dataset,
batch_sampler=self._batch_sampler,
collate_fn=self._batch_transforms,
num_workers=worker_num,
places=device,
return_list=return_list,
prefetch_num=prefetch_num
if prefetch_num is not None else self.batch_size)
else:
use_prefetch=False):
loader = DataLoader(
dataset=self._dataset,
batch_sampler=self._batch_sampler,
collate_fn=self._batch_transforms,
num_workers=worker_num,
use_buffer_reader=use_prefetch,
places=device,
return_list=return_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册