diff --git a/python/paddle/device/cuda/__init__.py b/python/paddle/device/cuda/__init__.py index 4a65f53fe58d02fae49fc006ada72eb76def55d9..970fb35bfaeb1a8155d56ec6d23ed779b64f21d8 100644 --- a/python/paddle/device/cuda/__init__.py +++ b/python/paddle/device/cuda/__init__.py @@ -28,6 +28,8 @@ __all__ = [ 'empty_cache', 'stream_guard', 'get_device_properties', + 'get_device_name', + 'get_device_capability', ] @@ -271,3 +273,61 @@ def get_device_properties(device=None): device_id = -1 return core.get_device_properties(device_id) + + +def get_device_name(device=None): + ''' + Return the name of the device which is got from CUDA function `cudaDeviceProp `_. + + Parameters: + device(paddle.CUDAPlace|int, optional): The device or the ID of the device. If device is None (default), the device is the current device. + + Returns: + str: The name of the device. + + Examples: + + .. code-block:: python + + # required: gpu + + import paddle + + paddle.device.cuda.get_device_name() + + paddle.device.cuda.get_device_name(0) + + paddle.device.cuda.get_device_name(paddle.CUDAPlace(0)) + + ''' + + return get_device_properties(device).name + + +def get_device_capability(device=None): + ''' + Return the major and minor revision numbers defining the device's compute capability which are got from CUDA function `cudaDeviceProp `_. + + Parameters: + device(paddle.CUDAPlace|int, optional): The device or the ID of the device. If device is None (default), the device is the current device. + + Returns: + tuple(int,int): the major and minor revision numbers defining the device's compute capability. + + Examples: + + .. code-block:: python + + # required: gpu + + import paddle + + paddle.device.cuda.get_device_capability() + + paddle.device.cuda.get_device_capability(0) + + paddle.device.cuda.get_device_capability(paddle.CUDAPlace(0)) + + ''' + prop = get_device_properties(device) + return prop.major, prop.minor diff --git a/python/paddle/fluid/tests/unittests/test_cuda_device_name_capability.py b/python/paddle/fluid/tests/unittests/test_cuda_device_name_capability.py new file mode 100644 index 0000000000000000000000000000000000000000..88f71f28412e34a14c3f9ae627db5711e771d86f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cuda_device_name_capability.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest + + +class TestDeviceName(unittest.TestCase): + def test_device_name_default(self): + if paddle.is_compiled_with_cuda(): + name = paddle.device.cuda.get_device_name() + self.assertIsNotNone(name) + + def test_device_name_int(self): + if paddle.is_compiled_with_cuda(): + name = paddle.device.cuda.get_device_name(0) + self.assertIsNotNone(name) + + def test_device_name_CUDAPlace(self): + if paddle.is_compiled_with_cuda(): + name = paddle.device.cuda.get_device_name(paddle.CUDAPlace(0)) + self.assertIsNotNone(name) + + +class TestDeviceCapability(unittest.TestCase): + def test_device_capability_default(self): + if paddle.is_compiled_with_cuda(): + capability = paddle.device.cuda.get_device_capability() + self.assertIsNotNone(capability) + + def test_device_capability_int(self): + if paddle.is_compiled_with_cuda(): + capability = paddle.device.cuda.get_device_capability(0) + self.assertIsNotNone(capability) + + def test_device_capability_CUDAPlace(self): + if paddle.is_compiled_with_cuda(): + capability = paddle.device.cuda.get_device_capability( + paddle.CUDAPlace(0)) + self.assertIsNotNone(capability) + + +if __name__ == "__main__": + unittest.main()