未验证 提交 6f7eb0d5 编写于 作者: D dzhwinter 提交者: GitHub

"fix gpu init" (#7528)

* "fix gpu init"

* "set env variable default value for share gpu"

* "fix ci"

* "removed CUDA_VISIBLE_DEVICES default"

* "removed"
上级 455639b2
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <string.h> // for strdup #include <string.h> // for strdup
#include <algorithm> #include <algorithm>
#include <stdexcept>
#include <string> #include <string>
#include "paddle/framework/init.h" #include "paddle/framework/init.h"
...@@ -46,17 +47,23 @@ void InitDevices() { ...@@ -46,17 +47,23 @@ void InitDevices() {
std::vector<platform::Place> places; std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace()); places.emplace_back(platform::CPUPlace());
int count = 0;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
int count = platform::GetCUDADeviceCount(); try {
for (int i = 0; i < count; ++i) { count = platform::GetCUDADeviceCount();
places.emplace_back(platform::CUDAPlace(i)); } catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
} }
#else #else
LOG(WARNING) LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option"; << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif #endif
for (int i = 0; i < count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
}
platform::DeviceContextPool::Init(places); platform::DeviceContextPool::Init(places);
} }
......
...@@ -20,7 +20,21 @@ TEST(InitDevices, CPU) { ...@@ -20,7 +20,21 @@ TEST(InitDevices, CPU) {
using paddle::framework::InitDevices; using paddle::framework::InitDevices;
using paddle::platform::DeviceContextPool; using paddle::platform::DeviceContextPool;
#ifndef PADDLE_WITH_CUDA
InitDevices(); InitDevices();
DeviceContextPool& pool = DeviceContextPool::Instance(); DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_GE(pool.size(), 1U); ASSERT_EQ(pool.size(), 1U);
#endif
}
TEST(InitDevices, CUDA) {
using paddle::framework::InitDevices;
using paddle::platform::DeviceContextPool;
#ifdef PADDLE_WITH_CUDA
int count = paddle::platform::GetCUDADeviceCount();
InitDevices();
DeviceContextPool& pool = DeviceContextPool::Instance();
ASSERT_EQ(pool.size(), 1U + static_cast<unsigned>(count));
#endif
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册