提交 3dae967f 编写于 作者: D dengkaipeng

fix dataloader not restart

上级 a04b2a74
......@@ -136,7 +136,7 @@ class BaseDataLoader(object):
else:
self._batch_sampler = batch_sampler
self.loader = DataLoader(
self.dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=self._batch_sampler,
collate_fn=self._batch_transforms,
......@@ -145,7 +145,7 @@ class BaseDataLoader(object):
return_list=return_list,
use_buffer_reader=use_prefetch,
use_shared_memory=False)
self.loader = iter(self.loader)
self.loader = iter(self.dataloader)
return self
......@@ -163,6 +163,7 @@ class BaseDataLoader(object):
data = next(self.loader)
return {k: v for k, v in zip(self._fields, data)}
except StopIteration:
self.loader = iter(self.dataloader)
six.reraise(*sys.exc_info())
def next(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册