未验证 提交 4671d85a 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix DataLoader return same format between static & dynamic in single mode (#28176)

* fix DataLoader return same format between static & dynamic in single mode. test=develop
上级 7db747d9
...@@ -341,6 +341,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -341,6 +341,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
return self._reader.read_next_var_list() return self._reader.read_next_var_list()
else: else:
if self._return_list: if self._return_list:
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
if len(self._places) == 1:
return self._reader.read_next_list()[0]
else:
return self._reader.read_next_list() return self._reader.read_next_list()
else: else:
return self._reader.read_next() return self._reader.read_next()
......
...@@ -170,5 +170,50 @@ class TestStaticDataLoader(unittest.TestCase): ...@@ -170,5 +170,50 @@ class TestStaticDataLoader(unittest.TestCase):
self.assertLess(diff, 1e-2) self.assertLess(diff, 1e-2)
class TestStaticDataLoaderReturnList(unittest.TestCase):
def test_single_place(self):
scope = fluid.Scope()
image = fluid.data(
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
with fluid.scope_guard(scope):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
feed_list=[image, label],
num_workers=0,
batch_size=BATCH_SIZE,
drop_last=True,
return_list=True)
for d in dataloader:
assert isinstance(d, list)
assert len(d) == 2
assert not isinstance(d[0], list)
assert not isinstance(d[1], list)
def test_multi_place(self):
scope = fluid.Scope()
image = fluid.data(
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
with fluid.scope_guard(scope):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(
dataset,
feed_list=[image, label],
num_workers=0,
batch_size=BATCH_SIZE,
places=[fluid.CPUPlace()] * 2,
drop_last=True,
return_list=True)
for d in dataloader:
assert isinstance(d, list)
assert len(d) == 2
assert isinstance(d[0], list)
assert isinstance(d[1], list)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册