From e700ffdc8fd59188b1ece905f0a26393dbfa13e2 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Fri, 24 Jun 2022 08:52:48 +0800 Subject: [PATCH] [cherry pick] fix structure infos conflict in static return_list mode (#43691) * fix structure infos conflict in static return_list mode. test=develop * fix format. test=develop * fix format. test=develop --- .../fluid/dataloader/dataloader_iter.py | 22 +++++++++---------- .../test_multiprocess_dataloader_static.py | 15 +++++++++---- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 3deff6e2d40..b57f43aaeb8 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -276,13 +276,12 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): data = self._reader.read_next_list() for i in range(len(data)): data[i] = data[i]._move_to_list() - data = [ - _restore_batch(d, s) - for d, s in zip(data, self._structure_infos[:len( - self._places)]) + structs = [ + self._structure_infos.pop(0) + for _ in range(len(self._places)) ] - self._structure_infos = self._structure_infos[len( - self._places):] + data = [_restore_batch(d, s) \ + for d, s in zip(data, structs)] # static graph organized data on multi-device with list, if # place number is 1, there is only 1 device, extra the data # from list for devices to be compatible with dygraph mode @@ -751,13 +750,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): data = self._reader.read_next_list() for i in range(len(data)): data[i] = data[i]._move_to_list() - data = [ - _restore_batch(d, s) - for d, s in zip(data, self._structure_infos[:len( - self._places)]) + structs = [ + self._structure_infos.pop(0) + for _ in range(len(self._places)) ] - self._structure_infos = self._structure_infos[len( - self._places):] + data = [_restore_batch(d, s) \ + for d, s in zip(data, structs)] # static graph organized data on multi-device with list, if # place number is 1, there is only 1 device, extra the data # from list for devices to be compatible with dygraph mode diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index 9f73ee041e0..ff280375499 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -22,6 +22,7 @@ import unittest import multiprocessing import numpy as np +import paddle import paddle.fluid as fluid from paddle.io import Dataset, BatchSampler, DataLoader @@ -181,7 +182,7 @@ class TestStaticDataLoader(unittest.TestCase): class TestStaticDataLoaderReturnList(unittest.TestCase): - def test_single_place(self): + def run_single_place(self, num_workers): scope = fluid.Scope() image = fluid.data( name='image', shape=[None, IMAGE_SIZE], dtype='float32') @@ -191,7 +192,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): dataloader = DataLoader( dataset, feed_list=[image, label], - num_workers=0, + num_workers=num_workers, batch_size=BATCH_SIZE, drop_last=True, return_list=True) @@ -202,7 +203,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): assert not isinstance(d[0], list) assert not isinstance(d[1], list) - def test_multi_place(self): + def run_multi_place(self, num_workers): scope = fluid.Scope() image = fluid.data( name='image', shape=[None, IMAGE_SIZE], dtype='float32') @@ -212,7 +213,7 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): dataloader = DataLoader( dataset, feed_list=[image, label], - num_workers=0, + num_workers=num_workers, batch_size=BATCH_SIZE, places=[fluid.CPUPlace()] * 2, drop_last=True, @@ -224,6 +225,12 @@ class TestStaticDataLoaderReturnList(unittest.TestCase): assert isinstance(d[0], list) assert isinstance(d[1], list) + def test_main(self): + paddle.enable_static() + for num_workers in [0, 2]: + self.run_single_place(num_workers) + self.run_multi_place(num_workers) + class RandomBatchedDataset(Dataset): def __init__(self, sample_num, class_num): -- GitLab