未验证 提交 41bb70fb 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix structure infos conflict in static return_list mode (#43291)

* fix structure infos conflict in static return_list mode. test=develop
上级 c43b54ad
......@@ -276,12 +276,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
data = self._reader.read_next_list()
for i in range(len(data)):
data[i] = data[i]._move_to_list()
data = [
_restore_batch(d, s) for d, s in zip(
data, self._structure_infos[:len(self._places)])
structs = [
self._structure_infos.pop(0)
for _ in range(len(self._places))
]
self._structure_infos = self._structure_infos[
len(self._places):]
data = [_restore_batch(d, s) \
for d, s in zip(data, structs)]
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
......@@ -750,12 +750,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
data = self._reader.read_next_list()
for i in range(len(data)):
data[i] = data[i]._move_to_list()
data = [
_restore_batch(d, s) for d, s in zip(
data, self._structure_infos[:len(self._places)])
structs = [
self._structure_infos.pop(0)
for _ in range(len(self._places))
]
self._structure_infos = self._structure_infos[
len(self._places):]
data = [_restore_batch(d, s) \
for d, s in zip(data, structs)]
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
......
......@@ -22,6 +22,7 @@ import unittest
import multiprocessing
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader
......@@ -182,7 +183,7 @@ class TestStaticDataLoader(unittest.TestCase):
class TestStaticDataLoaderReturnList(unittest.TestCase):
def test_single_place(self):
def run_single_place(self, num_workers):
scope = fluid.Scope()
image = fluid.data(name='image',
shape=[None, IMAGE_SIZE],
......@@ -192,7 +193,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(dataset,
feed_list=[image, label],
num_workers=0,
num_workers=num_workers,
batch_size=BATCH_SIZE,
drop_last=True,
return_list=True)
......@@ -203,7 +204,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert not isinstance(d[0], list)
assert not isinstance(d[1], list)
def test_multi_place(self):
def run_multi_place(self, num_workers):
scope = fluid.Scope()
image = fluid.data(name='image',
shape=[None, IMAGE_SIZE],
......@@ -213,7 +214,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
dataset = RandomDataset(SAMPLE_NUM, CLASS_NUM)
dataloader = DataLoader(dataset,
feed_list=[image, label],
num_workers=0,
num_workers=num_workers,
batch_size=BATCH_SIZE,
places=[fluid.CPUPlace()] * 2,
drop_last=True,
......@@ -225,6 +226,12 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert isinstance(d[0], list)
assert isinstance(d[1], list)
def test_main(self):
paddle.enable_static()
for num_workers in [0, 2]:
self.run_single_place(num_workers)
self.run_multi_place(num_workers)
class RandomBatchedDataset(Dataset):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册