未验证 提交 2e7e69d0 编写于 作者: C Chen Weihang 提交者: GitHub

remove imperative data loader place limit, test=develop (#24641)

上级 fa846da5
......@@ -18,7 +18,7 @@ import six
import numpy as np
import threading
import paddle
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places, _current_expected_place
from .executor import global_scope
from .data_feeder import DataFeeder, BatchedTensorProvider
from .multiprocess_utils import multiprocess_queue_set, CleanupFuncRegistrar, _cleanup_mmap, _cleanup, _set_SIGCHLD_handler
......@@ -671,12 +671,12 @@ class DygraphGeneratorLoader(DataLoaderBase):
if not iterable:
logging.warning(
"Please NOTE: dygraph can support iterable mode only. Change to iterable mode."
"Please NOTE: imperative mode can support iterable mode only. Change to iterable mode."
)
self._iterable = True
if not return_list:
logging.warning(
"Please NOTE: dygraph can support return as list only. Change to return as list."
"Please NOTE: imperative mode can support return as list only. Change to return as list."
)
self._return_list = True
......@@ -941,10 +941,11 @@ class DygraphGeneratorLoader(DataLoaderBase):
def set_batch_generator(self, reader, places=None):
self._batch_reader = reader
assert places is not None, "Places cannot be None when DataLoader is iterable"
if places is None:
places = _current_expected_place()
self._places = _convert_places(places)
assert len(self._places) == 1, \
"Number of places must be 1 in dygraph mode"
"Number of places must be 1 in imperative mode"
return self
......
......@@ -41,6 +41,14 @@ class TestDygraphDataLoader(unittest.TestCase):
self.epoch_num = 1
self.capacity = 5
def iter_loader_data(self, loader):
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
def test_single_process_loader(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(
......@@ -49,12 +57,7 @@ class TestDygraphDataLoader(unittest.TestCase):
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
self.iter_loader_data(loader)
def test_multi_process_loader(self):
with fluid.dygraph.guard():
......@@ -64,12 +67,15 @@ class TestDygraphDataLoader(unittest.TestCase):
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size,
places=fluid.CPUPlace())
for _ in range(self.epoch_num):
for image, label in loader():
relu = fluid.layers.relu(image)
self.assertEqual(image.shape, [self.batch_size, 784])
self.assertEqual(label.shape, [self.batch_size, 1])
self.assertEqual(relu.shape, [self.batch_size, 784])
self.iter_loader_data(loader)
def test_generator_no_places(self):
with fluid.dygraph.guard():
loader = fluid.io.DataLoader.from_generator(capacity=self.capacity)
loader.set_sample_generator(
sample_generator_creator(self.batch_size, self.batch_num),
batch_size=self.batch_size)
self.iter_loader_data(loader)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册