提交 67fcb0c9 编写于 作者: Y Youwei Song 提交者: Jiabin Yang

fix api fluid.cuda_pinned_places (#20424)

* fix api fluid.cuda_pinned_places, test=develop

* fix api fluid.cuda_pinned_places, test=develop
上级 85e1f215
...@@ -373,8 +373,8 @@ def cuda_pinned_places(device_count=None): ...@@ -373,8 +373,8 @@ def cuda_pinned_places(device_count=None):
assert core.is_compiled_with_cuda(), \ assert core.is_compiled_with_cuda(), \
"Not compiled with CUDA" "Not compiled with CUDA"
if device_count is None: if device_count is None:
device_count = _cpu_num() device_count = len(_cuda_ids())
return [core.cuda_pinned_places()] * device_count return [core.CUDAPinnedPlace()] * device_count
class NameScope(object): class NameScope(object):
......
...@@ -275,6 +275,10 @@ class TestTensor(unittest.TestCase): ...@@ -275,6 +275,10 @@ class TestTensor(unittest.TestCase):
self.assertTrue( self.assertTrue(
isinstance( isinstance(
tensor._mutable_data(place, dtype), numbers.Integral)) tensor._mutable_data(place, dtype), numbers.Integral))
places = fluid.cuda_pinned_places()
self.assertTrue(
isinstance(
tensor._mutable_data(places[0], dtype), numbers.Integral))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册