diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 969cac4a6631ed627fd384f46a3f8610c11ff15c..6b716ff82f10a4d7a47489a15568b6dfbf0f87fc 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -18,7 +18,7 @@ import six import numpy as np import threading import paddle -from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places +from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places, _current_expected_place from .executor import global_scope from .data_feeder import DataFeeder, BatchedTensorProvider from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler @@ -671,12 +671,12 @@ class DygraphGeneratorLoader(DataLoaderBase): if not iterable: logging.warning( - "Please NOTE: dygraph can support iterable mode only. Change to iterable mode." + "Please NOTE: imperative mode can support iterable mode only. Change to iterable mode." ) self._iterable = True if not return_list: logging.warning( - "Please NOTE: dygraph can support return as list only. Change to return as list." + "Please NOTE: imperative mode can support return as list only. Change to return as list." ) self._return_list = True @@ -941,10 +941,11 @@ class DygraphGeneratorLoader(DataLoaderBase): def set_batch_generator(self, reader, places=None): self._batch_reader = reader - assert places is not None, "Places cannot be None when DataLoader is iterable" + if places is None: + places = _current_expected_place() self._places = _convert_places(places) assert len(self._places) == 1, \ - "Number of places must be 1 in dygraph mode" + "Number of places must be 1 in imperative mode" return self diff --git a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py index a32f57c79399834eab4399f644f77b49fb506a8b..71b208e2cdd114ba527746d085cb066204c23777 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_data_loader_base.py @@ -41,6 +41,14 @@ class TestDygraphDataLoader(unittest.TestCase): self.epoch_num = 1 self.capacity = 5 + def iter_loader_data(self, loader): + for _ in range(self.epoch_num): + for image, label in loader(): + relu = fluid.layers.relu(image) + self.assertEqual(image.shape, [self.batch_size, 784]) + self.assertEqual(label.shape, [self.batch_size, 1]) + self.assertEqual(relu.shape, [self.batch_size, 784]) + def test_single_process_loader(self): with fluid.dygraph.guard(): loader = fluid.io.DataLoader.from_generator( @@ -49,12 +57,7 @@ class TestDygraphDataLoader(unittest.TestCase): sample_generator_creator(self.batch_size, self.batch_num), batch_size=self.batch_size, places=fluid.CPUPlace()) - for _ in range(self.epoch_num): - for image, label in loader(): - relu = fluid.layers.relu(image) - self.assertEqual(image.shape, [self.batch_size, 784]) - self.assertEqual(label.shape, [self.batch_size, 1]) - self.assertEqual(relu.shape, [self.batch_size, 784]) + self.iter_loader_data(loader) def test_multi_process_loader(self): with fluid.dygraph.guard(): @@ -64,12 +67,15 @@ class TestDygraphDataLoader(unittest.TestCase): sample_generator_creator(self.batch_size, self.batch_num), batch_size=self.batch_size, places=fluid.CPUPlace()) - for _ in range(self.epoch_num): - for image, label in loader(): - relu = fluid.layers.relu(image) - self.assertEqual(image.shape, [self.batch_size, 784]) - self.assertEqual(label.shape, [self.batch_size, 1]) - self.assertEqual(relu.shape, [self.batch_size, 784]) + self.iter_loader_data(loader) + + def test_generator_no_places(self): + with fluid.dygraph.guard(): + loader = fluid.io.DataLoader.from_generator(capacity=self.capacity) + loader.set_sample_generator( + sample_generator_creator(self.batch_size, self.batch_num), + batch_size=self.batch_size) + self.iter_loader_data(loader) if __name__ == '__main__':