未验证 提交 21ea2976 编写于 作者: C Chen Weihang 提交者: GitHub

fix static dataloader default pin memory set (#26390)

* fix static dataloader default pin memory set

* fix related unittests
上级 8c922931
...@@ -1039,7 +1039,7 @@ class GeneratorLoader(DataLoaderBase): ...@@ -1039,7 +1039,7 @@ class GeneratorLoader(DataLoaderBase):
self._reader = core.create_py_reader( self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes, self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer, self._need_check_feed, self._places, self._use_double_buffer,
self._drop_last, True) self._drop_last, False)
def _init_non_iterable(self): def _init_non_iterable(self):
lod_levels = [] lod_levels = []
......
...@@ -122,14 +122,8 @@ class TestBase(unittest.TestCase): ...@@ -122,14 +122,8 @@ class TestBase(unittest.TestCase):
label = item['label'] label = item['label']
assert image.shape() == [BATCH_SIZE, 784] assert image.shape() == [BATCH_SIZE, 784]
assert label.shape() == [BATCH_SIZE, 1] assert label.shape() == [BATCH_SIZE, 1]
if ps[i]._equals(fluid.CPUPlace()): assert image._place()._equals(ps[i])
assert image._place()._equals(fluid.CPUPlace()) assert label._place()._equals(ps[i])
assert label._place()._equals(fluid.CPUPlace())
else:
assert image._place()._equals(
fluid.CUDAPinnedPlace())
assert label._place()._equals(
fluid.CUDAPinnedPlace())
L, = exe.run(program=prog, L, = exe.run(program=prog,
feed=d, feed=d,
fetch_list=[loss], fetch_list=[loss],
......
...@@ -124,14 +124,8 @@ class TestBase(unittest.TestCase): ...@@ -124,14 +124,8 @@ class TestBase(unittest.TestCase):
label = item['label'] label = item['label']
assert image.shape() == [BATCH_SIZE, 784] assert image.shape() == [BATCH_SIZE, 784]
assert label.shape() == [BATCH_SIZE, 1] assert label.shape() == [BATCH_SIZE, 1]
if ps[i]._equals(fluid.CPUPlace()): assert image._place()._equals(ps[i])
assert image._place()._equals(fluid.CPUPlace()) assert label._place()._equals(ps[i])
assert label._place()._equals(fluid.CPUPlace())
else:
assert image._place()._equals(
fluid.CUDAPinnedPlace())
assert label._place()._equals(
fluid.CUDAPinnedPlace())
L, = exe.run(program=prog, L, = exe.run(program=prog,
feed=d, feed=d,
fetch_list=[loss], fetch_list=[loss],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册