未验证 提交 41bb70fb 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix structure infos conflict in static return_list mode (#43291)

* fix structure infos conflict in static return_list mode. test=develop
上级 c43b54ad
...@@ -276,12 +276,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): ...@@ -276,12 +276,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
data = self._reader.read_next_list() data = self._reader.read_next_list()
for i in range(len(data)): for i in range(len(data)):
data[i] = data[i]._move_to_list() data[i] = data[i]._move_to_list()
data = [ structs = [
_restore_batch(d, s) for d, s in zip( self._structure_infos.pop(0)
data, self._structure_infos[:len(self._places)]) for _ in range(len(self._places))
] ]
self._structure_infos = self._structure_infos[ data = [_restore_batch(d, s) \
len(self._places):] for d, s in zip(data, structs)]
# static graph organized data on multi-device with list, if # static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data # place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode # from list for devices to be compatible with dygraph mode
...@@ -750,12 +750,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): ...@@ -750,12 +750,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
data = self._reader.read_next_list() data = self._reader.read_next_list()
for i in range(len(data)): for i in range(len(data)):
data[i] = data[i]._move_to_list() data[i] = data[i]._move_to_list()
data = [ structs = [
_restore_batch(d, s) for d, s in zip( self._structure_infos.pop(0)
data, self._structure_infos[:len(self._places)]) for _ in range(len(self._places))
] ]
self._structure_infos = self._structure_infos[ data = [_restore_batch(d, s) \
len(self._places):] for d, s in zip(data, structs)]
# static graph organized data on multi-device with list, if # static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data # place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode # from list for devices to be compatible with dygraph mode
......
...@@ -22,6 +22,7 @@ import unittest ...@@ -22,6 +22,7 @@ import unittest
import multiprocessing import multiprocessing
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader from paddle.io import Dataset, BatchSampler, DataLoader
...@@ -182,7 +183,7 @@ class TestStaticDataLoader(unittest.TestCase): ...@@ -182,7 +183,7 @@ class TestStaticDataLoader(unittest.TestCase):
class TestStaticDataLoaderReturnList(unittest.TestCase): class TestStaticDataLoaderReturnList(unittest.TestCase):
def test_single_place(self): def run_single_place(self, num_workers):
scope = fluid.Scope() scope = fluid.Scope()
image = fluid.data(name='image', image = fluid.data(name='image',
shape=[None, IMAGE_SIZE], shape=[None, IMAGE_SIZE],
...@@ -192,7 +193,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): ...@@ -192,7 +193,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
feed_list=[image, label], feed_list=[image, label],
num_workers=0, num_workers=num_workers,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
drop_last=True, drop_last=True,
return_list=True) return_list=True)
...@@ -203,7 +204,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): ...@@ -203,7 +204,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert not isinstance(d[0], list) assert not isinstance(d[0], list)
assert not isinstance(d[1], list) assert not isinstance(d[1], list)
def test_multi_place(self): def run_multi_place(self, num_workers):
scope = fluid.Scope() scope = fluid.Scope()
image = fluid.data(name='image', image = fluid.data(name='image',
shape=[None, IMAGE_SIZE], shape=[None, IMAGE_SIZE],
...@@ -213,7 +214,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): ...@@ -213,7 +214,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM) dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
feed_list=[image, label], feed_list=[image, label],
num_workers=0, num_workers=num_workers,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
places=[fluid.CPUPlace()] * 2, places=[fluid.CPUPlace()] * 2,
drop_last=True, drop_last=True,
...@@ -225,6 +226,12 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): ...@@ -225,6 +226,12 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert isinstance(d[0], list) assert isinstance(d[0], list)
assert isinstance(d[1], list) assert isinstance(d[1], list)
def test_main(self):
paddle.enable_static()
for num_workers in [0, 2]:
self.run_single_place(num_workers)
self.run_multi_place(num_workers)
class RandomBatchedDataset(Dataset): class RandomBatchedDataset(Dataset):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册