未验证 提交 e700ffdc 编写于 作者: K Kaipeng Deng 提交者: GitHub

[cherry pick] fix structure infos conflict in static return_list mode (#43691)

* fix structure infos conflict in static return_list mode. test=develop

* fix format. test=develop

* fix format. test=develop
上级 096eb801
......@@ -276,13 +276,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
data = self._reader.read_next_list()
for i in range(len(data)):
data[i] = data[i]._move_to_list()
data = [
_restore_batch(d, s)
for d, s in zip(data, self._structure_infos[:len(
self._places)])
structs = [
self._structure_infos.pop(0)
for _ in range(len(self._places))
]
self._structure_infos = self._structure_infos[len(
self._places):]
data = [_restore_batch(d, s) \
for d, s in zip(data, structs)]
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
......@@ -751,13 +750,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
data = self._reader.read_next_list()
for i in range(len(data)):
data[i] = data[i]._move_to_list()
data = [
_restore_batch(d, s)
for d, s in zip(data, self._structure_infos[:len(
self._places)])
structs = [
self._structure_infos.pop(0)
for _ in range(len(self._places))
]
self._structure_infos = self._structure_infos[len(
self._places):]
data = [_restore_batch(d, s) \
for d, s in zip(data, structs)]
# static graph organized data on multi-device with list, if
# place number is 1, there is only 1 device, extra the data
# from list for devices to be compatible with dygraph mode
......
......@@ -22,6 +22,7 @@ import unittest
import multiprocessing
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.io import Dataset, BatchSampler, DataLoader
......@@ -181,7 +182,7 @@ class TestStaticDataLoader(unittest.TestCase):
class TestStaticDataLoaderReturnList(unittest.TestCase):
def test_single_place(self):
def run_single_place(self, num_workers):
scope = fluid.Scope()
image = fluid.data(
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
......@@ -191,7 +192,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
dataloader = DataLoader(
dataset,
feed_list=[image, label],
num_workers=0,
num_workers=num_workers,
batch_size=BATCH_SIZE,
drop_last=True,
return_list=True)
......@@ -202,7 +203,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert not isinstance(d[0], list)
assert not isinstance(d[1], list)
def test_multi_place(self):
def run_multi_place(self, num_workers):
scope = fluid.Scope()
image = fluid.data(
name='image', shape=[None, IMAGE_SIZE], dtype='float32')
......@@ -212,7 +213,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
dataloader = DataLoader(
dataset,
feed_list=[image, label],
num_workers=0,
num_workers=num_workers,
batch_size=BATCH_SIZE,
places=[fluid.CPUPlace()] * 2,
drop_last=True,
......@@ -224,6 +225,12 @@ class TestStaticDataLoaderReturnList(unittest.TestCase):
assert isinstance(d[0], list)
assert isinstance(d[1], list)
def test_main(self):
paddle.enable_static()
for num_workers in [0, 2]:
self.run_single_place(num_workers)
self.run_multi_place(num_workers)
class RandomBatchedDataset(Dataset):
def __init__(self, sample_num, class_num):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册