diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 94819a4f229ebfabafdb6dd8158d0dc53467097e..79133abe9a7480391e19f0aa64070d8797f837af 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -136,7 +136,7 @@ class BaseDataLoader(object): else: self._batch_sampler = batch_sampler - self.loader = DataLoader( + self.dataloader = DataLoader( dataset=self.dataset, batch_sampler=self._batch_sampler, collate_fn=self._batch_transforms, @@ -145,7 +145,7 @@ class BaseDataLoader(object): return_list=return_list, use_buffer_reader=use_prefetch, use_shared_memory=False) - self.loader = iter(self.loader) + self.loader = iter(self.dataloader) return self @@ -163,6 +163,7 @@ class BaseDataLoader(object): data = next(self.loader) return {k: v for k, v in zip(self._fields, data)} except StopIteration: + self.loader = iter(self.dataloader) six.reraise(*sys.exc_info()) def next(self):