提交 3dae967f 编写于 作者: D dengkaipeng

fix dataloader not restart

上级 a04b2a74
...@@ -136,7 +136,7 @@ class BaseDataLoader(object): ...@@ -136,7 +136,7 @@ class BaseDataLoader(object):
else: else:
self._batch_sampler = batch_sampler self._batch_sampler = batch_sampler
self.loader = DataLoader( self.dataloader = DataLoader(
dataset=self.dataset, dataset=self.dataset,
batch_sampler=self._batch_sampler, batch_sampler=self._batch_sampler,
collate_fn=self._batch_transforms, collate_fn=self._batch_transforms,
...@@ -145,7 +145,7 @@ class BaseDataLoader(object): ...@@ -145,7 +145,7 @@ class BaseDataLoader(object):
return_list=return_list, return_list=return_list,
use_buffer_reader=use_prefetch, use_buffer_reader=use_prefetch,
use_shared_memory=False) use_shared_memory=False)
self.loader = iter(self.loader) self.loader = iter(self.dataloader)
return self return self
...@@ -163,6 +163,7 @@ class BaseDataLoader(object): ...@@ -163,6 +163,7 @@ class BaseDataLoader(object):
data = next(self.loader) data = next(self.loader)
return {k: v for k, v in zip(self._fields, data)} return {k: v for k, v in zip(self._fields, data)}
except StopIteration: except StopIteration:
self.loader = iter(self.dataloader)
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
def next(self): def next(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册