diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 79133abe9a7480391e19f0aa64070d8797f837af..9be283e510c3f7ea1e78c165a1b499f5c44a928a 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -105,6 +105,9 @@ class BaseDataLoader(object): self._batch_transforms = Compose(batch_transforms, copy.deepcopy(self._fields), transform, num_classes) + self.output_fields = self._batch_transforms.output_fields + else: + self.output_fields = self._fields self.batch_size = batch_size self.shuffle = shuffle @@ -161,7 +164,7 @@ class BaseDataLoader(object): # data structure in paddle.io.DataLoader try: data = next(self.loader) - return {k: v for k, v in zip(self._fields, data)} + return {k: v for k, v in zip(self.output_fields, data)} except StopIteration: self.loader = iter(self.dataloader) six.reraise(*sys.exc_info())