未验证 提交 21ea2976 编写于 作者: C Chen Weihang 提交者: GitHub

fix static dataloader default pin memory set (#26390)

* fix static dataloader default pin memory set

* fix related unittests
上级 8c922931
......@@ -1039,7 +1039,7 @@ class GeneratorLoader(DataLoaderBase):
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer,
self._drop_last, True)
self._drop_last, False)
def _init_non_iterable(self):
lod_levels = []
......
......@@ -122,14 +122,8 @@ class TestBase(unittest.TestCase):
label = item['label']
assert image.shape() == [BATCH_SIZE, 784]
assert label.shape() == [BATCH_SIZE, 1]
if ps[i]._equals(fluid.CPUPlace()):
assert image._place()._equals(fluid.CPUPlace())
assert label._place()._equals(fluid.CPUPlace())
else:
assert image._place()._equals(
fluid.CUDAPinnedPlace())
assert label._place()._equals(
fluid.CUDAPinnedPlace())
assert image._place()._equals(ps[i])
assert label._place()._equals(ps[i])
L, = exe.run(program=prog,
feed=d,
fetch_list=[loss],
......
......@@ -124,14 +124,8 @@ class TestBase(unittest.TestCase):
label = item['label']
assert image.shape() == [BATCH_SIZE, 784]
assert label.shape() == [BATCH_SIZE, 1]
if ps[i]._equals(fluid.CPUPlace()):
assert image._place()._equals(fluid.CPUPlace())
assert label._place()._equals(fluid.CPUPlace())
else:
assert image._place()._equals(
fluid.CUDAPinnedPlace())
assert label._place()._equals(
fluid.CUDAPinnedPlace())
assert image._place()._equals(ps[i])
assert label._place()._equals(ps[i])
L, = exe.run(program=prog,
feed=d,
fetch_list=[loss],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册