未验证 提交 cf99c0d5 编写于 作者: L Linjie Chen 提交者: GitHub

Add cuda.device_count api (#34811)

* Add cuda device count api

* update coda format

* fix unittest error

* update code format

* update comment
上级 56c5e210
...@@ -22,6 +22,7 @@ __all__ = [ ...@@ -22,6 +22,7 @@ __all__ = [
'Event', 'Event',
'current_stream', 'current_stream',
'synchronize', 'synchronize',
'device_count',
] ]
...@@ -94,3 +95,25 @@ def synchronize(device=None): ...@@ -94,3 +95,25 @@ def synchronize(device=None):
raise ValueError("device type must be int or paddle.CUDAPlace") raise ValueError("device type must be int or paddle.CUDAPlace")
return core._device_synchronize(device_id) return core._device_synchronize(device_id)
def device_count():
'''
Return the number of GPUs available.
Returns:
int: the number of GPUs available.
Examples:
.. code-block:: python
import paddle
paddle.device.cuda.device_count()
'''
num_gpus = core.get_cuda_device_count() if hasattr(
core, 'get_cuda_device_count') else 0
return num_gpus
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
class TestDeviceCount(unittest.TestCase):
def test_device_count(self):
s = paddle.device.cuda.device_count()
self.assertIsNotNone(s)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册