提交 9c360cc7 编写于 作者: S sneaxiy

test=develop

......@@ -15,6 +15,8 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"
#include <algorithm>
#include <cstdlib>
#include <string>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -58,7 +60,18 @@ DEFINE_string(selected_gpus, "",
namespace paddle {
namespace platform {
int GetCUDADeviceCount() {
static int GetCUDADeviceCountImpl() {
const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
if (cuda_visible_devices != nullptr) {
std::string cuda_visible_devices_str(cuda_visible_devices);
if (std::all_of(cuda_visible_devices_str.begin(),
cuda_visible_devices_str.end(),
[](char ch) { return ch == ' '; })) {
VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be empty. No GPU detected.";
return 0;
}
}
int count;
PADDLE_ENFORCE(
cudaGetDeviceCount(&count),
......@@ -66,6 +79,11 @@ int GetCUDADeviceCount() {
return count;
}
int GetCUDADeviceCount() {
static auto dev_cnt = GetCUDADeviceCountImpl();
return dev_cnt;
}
int GetCUDAComputeCapability(int id) {
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
cudaDeviceProp device_prop;
......
......@@ -626,7 +626,18 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
py::class_<platform::CUDAPlace>(m, "CUDAPlace")
.def(py::init<int>())
.def("__init__",
[](platform::CUDAPlace &self, int dev_id) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(
dev_id >= 0 && dev_id < platform::GetCUDADeviceCount(),
"Invalid CUDAPlace(%d), must inside [0, %d)", dev_id,
platform::GetCUDADeviceCount());
new (&self) platform::CUDAPlace(dev_id);
#else
PADDLE_THROW("Cannot use CUDAPlace in CPU only version");
#endif
})
.def("__str__", string::to_string<const platform::CUDAPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
......@@ -634,7 +645,12 @@ All parameter, weight, gradient are variables in Paddle.
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
.def(py::init<>())
.def("__init__",
[](platform::CUDAPinnedPlace &) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version");
#endif
})
.def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
py::class_<platform::Place>(m, "Place")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册