diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc index 4ef82a541efaa35bcf831d5122570154f2fa2423..3f6ea121b3994979d89a7d5a8c20c59240a0c111 100644 --- a/paddle/framework/init.cc +++ b/paddle/framework/init.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include // for strdup #include +#include #include #include "paddle/framework/init.h" @@ -46,17 +47,23 @@ void InitDevices() { std::vector places; places.emplace_back(platform::CPUPlace()); + int count = 0; #ifdef PADDLE_WITH_CUDA - int count = platform::GetCUDADeviceCount(); - for (int i = 0; i < count; ++i) { - places.emplace_back(platform::CUDAPlace(i)); + try { + count = platform::GetCUDADeviceCount(); + } catch (const std::exception &exp) { + LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; } #else LOG(WARNING) - << "'GPU' is not supported, Please re-compile with WITH_GPU option"; + << "'CUDA' is not supported, Please re-compile with WITH_GPU option"; #endif + for (int i = 0; i < count; ++i) { + places.emplace_back(platform::CUDAPlace(i)); + } + platform::DeviceContextPool::Init(places); } diff --git a/paddle/framework/init_test.cc b/paddle/framework/init_test.cc index f837a965d3be7d40c20803ae4462b3bfd91bffd0..01e076dd8ea24831e3ed7c8a7f8fae6818a89335 100644 --- a/paddle/framework/init_test.cc +++ b/paddle/framework/init_test.cc @@ -20,7 +20,21 @@ TEST(InitDevices, CPU) { using paddle::framework::InitDevices; using paddle::platform::DeviceContextPool; +#ifndef PADDLE_WITH_CUDA InitDevices(); DeviceContextPool& pool = DeviceContextPool::Instance(); - ASSERT_GE(pool.size(), 1U); + ASSERT_EQ(pool.size(), 1U); +#endif +} + +TEST(InitDevices, CUDA) { + using paddle::framework::InitDevices; + using paddle::platform::DeviceContextPool; + +#ifdef PADDLE_WITH_CUDA + int count = paddle::platform::GetCUDADeviceCount(); + InitDevices(); + DeviceContextPool& pool = DeviceContextPool::Instance(); + ASSERT_EQ(pool.size(), 1U + static_cast(count)); +#endif }