From 21ea2976ddbad26d345d3171425de942f336a093 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 21 Aug 2020 09:57:22 +0800 Subject: [PATCH] fix static dataloader default pin memory set (#26390) * fix static dataloader default pin memory set * fix related unittests --- python/paddle/fluid/reader.py | 2 +- .../fluid/tests/unittests/test_decoupled_py_reader.py | 10 ++-------- .../fluid/tests/unittests/test_generator_dataloader.py | 10 ++-------- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index 7e633756fc..76c95be75d 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -1039,7 +1039,7 @@ class GeneratorLoader(DataLoaderBase): self._reader = core.create_py_reader( self.queue, self._var_names, self._shapes, self._dtypes, self._need_check_feed, self._places, self._use_double_buffer, - self._drop_last, True) + self._drop_last, False) def _init_non_iterable(self): lod_levels = [] diff --git a/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py index f8cb6170be..a16f21c0f9 100644 --- a/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py +++ b/python/paddle/fluid/tests/unittests/test_decoupled_py_reader.py @@ -122,14 +122,8 @@ class TestBase(unittest.TestCase): label = item['label'] assert image.shape() == [BATCH_SIZE, 784] assert label.shape() == [BATCH_SIZE, 1] - if ps[i]._equals(fluid.CPUPlace()): - assert image._place()._equals(fluid.CPUPlace()) - assert label._place()._equals(fluid.CPUPlace()) - else: - assert image._place()._equals( - fluid.CUDAPinnedPlace()) - assert label._place()._equals( - fluid.CUDAPinnedPlace()) + assert image._place()._equals(ps[i]) + assert label._place()._equals(ps[i]) L, = exe.run(program=prog, feed=d, fetch_list=[loss], diff --git a/python/paddle/fluid/tests/unittests/test_generator_dataloader.py b/python/paddle/fluid/tests/unittests/test_generator_dataloader.py index 6660bfb0c7..4f0beb8c0d 100644 --- a/python/paddle/fluid/tests/unittests/test_generator_dataloader.py +++ b/python/paddle/fluid/tests/unittests/test_generator_dataloader.py @@ -124,14 +124,8 @@ class TestBase(unittest.TestCase): label = item['label'] assert image.shape() == [BATCH_SIZE, 784] assert label.shape() == [BATCH_SIZE, 1] - if ps[i]._equals(fluid.CPUPlace()): - assert image._place()._equals(fluid.CPUPlace()) - assert label._place()._equals(fluid.CPUPlace()) - else: - assert image._place()._equals( - fluid.CUDAPinnedPlace()) - assert label._place()._equals( - fluid.CUDAPinnedPlace()) + assert image._place()._equals(ps[i]) + assert label._place()._equals(ps[i]) L, = exe.run(program=prog, feed=d, fetch_list=[loss], -- GitLab