未验证 提交 4cbed9e5 编写于 作者: Y Yanxing Shi 提交者: GitHub

Add paddle.device.cuda.get_device_properties (#35661)

* Initial Commit

* add unittest and add error information

* modify doc

* fix some error

* fix some word

* fix bug cudaDeviceProp* and modify error explanation

* fix cudaDeviceProp* error and unittest samples

* fix hip error and PADDLE_WITH_HIP

* update style

* fix error is_compiled_with_cuda

* fix paddle.device.cuda.get_device_properties

* fix error for multi thread safe

* update style

* merge conflict

* modify after mentor review

* update style

* delete word

* fix unittest error for windows

* support string input and modify some code

* modify doc to support string input

* fix error for express information

* fix error for express information

* fix unittest for windows

* fix device.startswith('gpu:')

* format error and doc

* fix after review

* format code

* fix error for doc compile

* fix error for doc compile

* fix error for doc compile

* fix error for doc compile

* fix error for doc compile

* fix py2 error

* fix wrong words and doc

* fix _gpuDeviceProperties
上级 6f18b041
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
#include <cstdlib> #include <cstdlib>
#include <mutex>
#include <vector>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/cuda_device_guard.h"
...@@ -39,6 +41,10 @@ DECLARE_uint64(gpu_memory_limit_mb); ...@@ -39,6 +41,10 @@ DECLARE_uint64(gpu_memory_limit_mb);
constexpr static float fraction_reserve_gpu_memory = 0.05f; constexpr static float fraction_reserve_gpu_memory = 0.05f;
static std::once_flag g_device_props_size_init_flag;
static std::vector<std::unique_ptr<std::once_flag>> g_device_props_init_flags;
static std::vector<paddle::gpuDeviceProp> g_device_props;
USE_GPU_MEM_STAT; USE_GPU_MEM_STAT;
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -297,6 +303,44 @@ std::vector<int> GetSelectedDevices() { ...@@ -297,6 +303,44 @@ std::vector<int> GetSelectedDevices() {
return devices; return devices;
} }
// Returns the (cached) properties of GPU `id`; `id == -1` selects the
// currently active device. The cache is populated lazily and at most once
// per device, guarded by per-device std::once_flag objects so concurrent
// callers are safe.
// NOTE(review): relies on std::make_unique (<memory>) being available via a
// transitive include — confirm <memory> is pulled in by gpu_info.h.
const gpuDeviceProp &GetDeviceProperties(int id) {
  // First caller sizes both caches; subsequent callers wait, then skip.
  std::call_once(g_device_props_size_init_flag, [] {
    const int gpu_num = platform::GetCUDADeviceCount();
    g_device_props_init_flags.resize(gpu_num);
    g_device_props.resize(gpu_num);
    for (int dev = 0; dev < gpu_num; ++dev) {
      g_device_props_init_flags[dev] = std::make_unique<std::once_flag>();
    }
  });

  // -1 is the sentinel for "use the current device".
  if (id == -1) {
    id = platform::GetCurrentDeviceId();
  }

  const int device_num = static_cast<int>(g_device_props.size());
  if (id < 0 || id >= device_num) {
    PADDLE_THROW(platform::errors::OutOfRange(
        "The device id %d is out of range [0, %d), where %d is the number of "
        "devices on this machine. Because the device id should be greater than "
        "or equal to zero and smaller than the number of gpus. Please input "
        "appropriate device again!",
        id, device_num, device_num));
  }

  // Query the driver at most once per device; later calls reuse the cache.
  std::call_once(*(g_device_props_init_flags[id]), [&] {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaGetDeviceProperties(&g_device_props[id], id));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(
        hipGetDeviceProperties(&g_device_props[id], id));
#endif
  });

  return g_device_props[id];
}
void SetDeviceId(int id) { void SetDeviceId(int id) {
// TODO(qijun): find a better way to cache the cuda device count // TODO(qijun): find a better way to cache the cuda device count
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
......
...@@ -67,6 +67,9 @@ dim3 GetGpuMaxGridDimSize(int); ...@@ -67,6 +67,9 @@ dim3 GetGpuMaxGridDimSize(int);
//! Get a list of device ids from environment variable or use all. //! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices(); std::vector<int> GetSelectedDevices();
//! Get the properties of the ith GPU device.
const gpuDeviceProp &GetDeviceProperties(int id);
//! Set the GPU device id for next execution. //! Set the GPU device id for next execution.
void SetDeviceId(int device_id); void SetDeviceId(int device_id);
......
...@@ -27,11 +27,13 @@ namespace paddle { ...@@ -27,11 +27,13 @@ namespace paddle {
using gpuStream_t = hipStream_t; using gpuStream_t = hipStream_t;
using gpuError_t = hipError_t; using gpuError_t = hipError_t;
using gpuEvent_t = hipEvent_t; using gpuEvent_t = hipEvent_t;
using gpuDeviceProp = hipDeviceProp_t;
#else #else
#define gpuSuccess cudaSuccess #define gpuSuccess cudaSuccess
using gpuStream_t = cudaStream_t; using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t; using gpuError_t = cudaError_t;
using gpuEvent_t = cudaEvent_t; using gpuEvent_t = cudaEvent_t;
using gpuDeviceProp = cudaDeviceProp;
#endif #endif
} // namespace paddle } // namespace paddle
...@@ -2285,6 +2285,31 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2285,6 +2285,31 @@ All parameter, weight, gradient are variables in Paddle.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m.def("get_cuda_device_count", platform::GetCUDADeviceCount); m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
m.def("cuda_empty_cache", platform::EmptyCache); m.def("cuda_empty_cache", platform::EmptyCache);
// Expose platform::GetDeviceProperties to Python as
// core.get_device_properties(id). return_value_policy::copy hands Python an
// owned copy of the struct, so the C++-side cached reference can never
// dangle from the Python side.
m.def("get_device_properties",
      [](int id) -> const gpuDeviceProp & {
        return platform::GetDeviceProperties(id);
      },
      py::return_value_policy::copy);
// Read-only Python view over gpuDeviceProp (cudaDeviceProp on CUDA builds,
// hipDeviceProp_t on HIP builds). Field names are snake_cased for Python.
py::class_<gpuDeviceProp>(m, "_gpuDeviceProperties")
    .def_readonly("name", &gpuDeviceProp::name)
    .def_readonly("major", &gpuDeviceProp::major)
    .def_readonly("minor", &gpuDeviceProp::minor)
    .def_readonly("is_multi_gpu_board", &gpuDeviceProp::isMultiGpuBoard)
    .def_readonly("is_integrated", &gpuDeviceProp::integrated)
    .def_readonly("multi_processor_count",
                  &gpuDeviceProp::multiProcessorCount)
    .def_readonly("total_memory", &gpuDeviceProp::totalGlobalMem)
    // Human-readable summary; total memory is reported in MB (bytes / 2^20).
    .def("__repr__", [](const gpuDeviceProp &gpu_device_prop) {
      std::ostringstream stream;
      stream << "_gpuDeviceProperties(name='" << gpu_device_prop.name
             << "', major=" << gpu_device_prop.major
             << ", minor=" << gpu_device_prop.minor << ", total_memory="
             << gpu_device_prop.totalGlobalMem / (1024 * 1024)
             << "MB, multi_processor_count="
             << gpu_device_prop.multiProcessorCount << ")";
      return stream.str();
    });
#if !defined(PADDLE_WITH_HIP) && !defined(_WIN32) #if !defined(PADDLE_WITH_HIP) && !defined(_WIN32)
m.def("nvprof_init", platform::CudaProfilerInit); m.def("nvprof_init", platform::CudaProfilerInit);
......
...@@ -27,6 +27,7 @@ __all__ = [ ...@@ -27,6 +27,7 @@ __all__ = [
'device_count', 'device_count',
'empty_cache', 'empty_cache',
'stream_guard', 'stream_guard',
'get_device_properties',
] ]
...@@ -204,3 +205,69 @@ def stream_guard(stream): ...@@ -204,3 +205,69 @@ def stream_guard(stream):
yield yield
finally: finally:
stream = _set_current_stream(pre_stream) stream = _set_current_stream(pre_stream)
def get_device_properties(device=None):
    '''
    Return the properties of the given CUDA device.

    Args:
        device(paddle.CUDAPlace or int or str, optional): The device to query.
            Accepts a ``paddle.CUDAPlace``, an integer device id, or a string
            of the form ``'gpu:x'``. When ``None``, the current device is
            queried. Default: None.

    Returns:
        _gpuDeviceProperties: the properties of the device, including the
        device name, major/minor compute capability, total global memory and
        the number of multiprocessors.

    Raises:
        ValueError: if PaddlePaddle was built without CUDA support, or if
            ``device`` is of an unsupported type or a malformed string.

    Examples:
        .. code-block:: python

            # required: gpu

            import paddle
            paddle.device.cuda.get_device_properties()
            # _gpuDeviceProperties(name='A100-SXM4-40GB', major=8, minor=0, total_memory=40536MB, multi_processor_count=108)

            paddle.device.cuda.get_device_properties(0)
            # _gpuDeviceProperties(name='A100-SXM4-40GB', major=8, minor=0, total_memory=40536MB, multi_processor_count=108)

            paddle.device.cuda.get_device_properties('gpu:0')
            # _gpuDeviceProperties(name='A100-SXM4-40GB', major=8, minor=0, total_memory=40536MB, multi_processor_count=108)

            paddle.device.cuda.get_device_properties(paddle.CUDAPlace(0))
            # _gpuDeviceProperties(name='A100-SXM4-40GB', major=8, minor=0, total_memory=40536MB, multi_processor_count=108)
    '''
    if not core.is_compiled_with_cuda():
        raise ValueError(
            "The API paddle.device.cuda.get_device_properties is not supported in "
            "CPU-only PaddlePaddle. Please reinstall PaddlePaddle with GPU support "
            "to call this API.")

    # Normalize every accepted form of `device` to an integer id; -1 is the
    # sentinel the C++ side resolves to the current device.
    if device is None:
        device_id = -1
    elif isinstance(device, int):
        device_id = device
    elif isinstance(device, core.CUDAPlace):
        device_id = device.get_device_id()
    elif isinstance(device, str):
        if not device.startswith('gpu:'):
            raise ValueError(
                "The current string {} is not expected. Because paddle.device."
                "cuda.get_device_properties only support string which is like 'gpu:x'. "
                "Please input appropriate string again!".format(device))
        device_id = int(device[4:])
    else:
        raise ValueError(
            "The device type {} is not expected. Because paddle.device.cuda."
            "get_device_properties only support int, str or paddle.CUDAPlace. "
            "Please input appropriate device again!".format(device))

    return core.get_device_properties(device_id)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
from paddle.fluid import core
from paddle.device.cuda import device_count, get_device_properties
class TestGetDeviceProperties(unittest.TestCase):
    # Smoke tests: every accepted device specifier (default, str, int,
    # CUDAPlace) must yield a non-None properties object on a CUDA build.

    def test_get_device_properties_default(self):
        # No-op on CPU-only builds, matching the original guard.
        if not core.is_compiled_with_cuda():
            return
        self.assertIsNotNone(get_device_properties())

    def test_get_device_properties_str(self):
        if not core.is_compiled_with_cuda():
            return
        self.assertIsNotNone(get_device_properties('gpu:0'))

    def test_get_device_properties_int(self):
        if not core.is_compiled_with_cuda():
            return
        for dev_id in range(device_count()):
            self.assertIsNotNone(get_device_properties(dev_id))

    def test_get_device_properties_CUDAPlace(self):
        if not core.is_compiled_with_cuda():
            return
        self.assertIsNotNone(get_device_properties(core.CUDAPlace(0)))
class TestGetDevicePropertiesError(unittest.TestCase):
    # Invalid arguments must raise the documented exception types.

    def test_error_api(self):
        if not core.is_compiled_with_cuda():
            return

        # A device id past the last device raises IndexError (OutOfRange).
        with self.assertRaises(IndexError):
            get_device_properties(device_count() + 1)

        # A string without the 'gpu:' prefix raises ValueError.
        with self.assertRaises(ValueError):
            get_device_properties('gpu1')

        # An unsupported argument type (float) raises ValueError.
        with self.assertRaises(ValueError):
            get_device_properties(float(device_count()))


if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册