提交 67fcb0c9 编写于 作者: Y Youwei Song 提交者: Jiabin Yang

fix api fluid.cuda_pinned_places (#20424)

* fix api fluid.cuda_pinned_places, test=develop

* fix api fluid.cuda_pinned_places, test=develop
上级 85e1f215
......@@ -373,8 +373,8 @@ def cuda_pinned_places(device_count=None):
assert core.is_compiled_with_cuda(), \
"Not compiled with CUDA"
if device_count is None:
device_count = _cpu_num()
return [core.cuda_pinned_places()] * device_count
device_count = len(_cuda_ids())
return [core.CUDAPinnedPlace()] * device_count
class NameScope(object):
......
......@@ -275,6 +275,10 @@ class TestTensor(unittest.TestCase):
self.assertTrue(
isinstance(
tensor._mutable_data(place, dtype), numbers.Integral))
places = fluid.cuda_pinned_places()
self.assertTrue(
isinstance(
tensor._mutable_data(places[0], dtype), numbers.Integral))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册