未验证 提交 531242be 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix fields lost (#1914)

上级 3dae967f
......@@ -105,6 +105,9 @@ class BaseDataLoader(object):
self._batch_transforms = Compose(batch_transforms,
copy.deepcopy(self._fields),
transform, num_classes)
self.output_fields = self._batch_transforms.output_fields
else:
self.output_fields = self._fields
self.batch_size = batch_size
self.shuffle = shuffle
......@@ -161,7 +164,7 @@ class BaseDataLoader(object):
# data structure in paddle.io.DataLoader
try:
data = next(self.loader)
return {k: v for k, v in zip(self._fields, data)}
return {k: v for k, v in zip(self.output_fields, data)}
except StopIteration:
self.loader = iter(self.dataloader)
six.reraise(*sys.exc_info())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册