From 1bdea0a8d2fffe282c741712ade39d3604472fb9 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 24 Apr 2018 13:41:39 +0800 Subject: [PATCH] Add init interface for customize devices. --- paddle/fluid/framework/init.cc | 73 ++++++++++++++++++++++++++++++++++ paddle/fluid/framework/init.h | 4 ++ paddle/fluid/inference/io.cc | 2 + paddle/fluid/inference/io.h | 2 + 4 files changed, 81 insertions(+) diff --git a/paddle/fluid/framework/init.cc b/paddle/fluid/framework/init.cc index 75c557fa4..3ce37041c 100644 --- a/paddle/fluid/framework/init.cc +++ b/paddle/fluid/framework/init.cc @@ -15,19 +15,40 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/framework/init.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/piece.h" namespace paddle { namespace framework { +DEFINE_string(devices, "", "The devices to be used."); +DEFINE_bool(init_p2p, true, "Whether to init p2p."); + std::once_flag gflags_init_flag; std::once_flag p2p_init_flag; +using paddle::platform::DeviceContextPool; + +void Init(int argc, char **argv) { + std::call_once(gflags_init_flag, + [&]() { google::ParseCommandLineFlags(&argc, &argv, true); }); + + // init devices + std::vector devices; + std::string token; + std::istringstream tokenStream(FLAGS_devices); + while (std::getline(tokenStream, token, ',')) { + devices.push_back(std::stoi(token)); + } + InitDevices(FLAGS_init_p2p, devices); +} + void InitGflags(std::vector &argv) { std::call_once(gflags_init_flag, [&]() { int argc = argv.size(); @@ -64,6 +85,30 @@ void InitP2P(int count) { #endif } +void InitP2P(std::vector devices) { +#ifdef PADDLE_WITH_CUDA + std::call_once(p2p_init_flag, [&]() { + int count = devices.size(); + for (int i = 0; i < count; ++i) { + for (int j = 0; j < count; ++j) { + if (devices[i] == devices[j]) continue; + int can_acess = -1; + PADDLE_ENFORCE( + cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]), + "Failed to test P2P access."); + if (can_acess != 1) { + LOG(WARNING) << "Cannot enable P2P access from " << devices[i] + << " to " << devices[j]; + } else { + cudaSetDevice(devices[i]); + cudaDeviceEnablePeerAccess(devices[j], 0); + } + } + } + }); +#endif +} + void InitDevices(bool init_p2p) { /*Init all avaiable devices by default */ @@ -91,6 +136,34 @@ void InitDevices(bool init_p2p) { platform::DeviceContextPool::Init(places); } +void InitDevices(bool init_p2p, const std::vector devices) { + std::vector places; + int count = 0; +#ifdef PADDLE_WITH_CUDA + try { + count = platform::GetCUDADeviceCount(); + } catch (const std::exception &exp) { + LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; + } +#else + LOG(WARNING) + << "'CUDA' is not supported, Please re-compile with WITH_GPU option"; +#endif + + for (size_t i = 0; i < devices.size(); ++i) { + if (devices[i] >= count) { + LOG(WARNING) << "Invalid devices id."; + continue; + } + places.emplace_back(platform::CUDAPlace(devices[i])); + } + if (init_p2p) { + InitP2P(devices); + } + places.emplace_back(platform::CPUPlace()); + platform::DeviceContextPool::Init(places); +} + void InitGLOG(const std::string &prog_name) { // glog will not hold the ARGV[0] inside. // Use strdup to alloc a new string. diff --git a/paddle/fluid/framework/init.h b/paddle/fluid/framework/init.h index fae98a60b..38604d232 100644 --- a/paddle/fluid/framework/init.h +++ b/paddle/fluid/framework/init.h @@ -20,11 +20,15 @@ limitations under the License. */ namespace paddle { namespace framework { +void Init(int argc, char **argv); + void InitGflags(std::vector &argv); void InitGLOG(const std::string &prog_name); void InitDevices(bool init_p2p); +void InitDevices(bool init_p2p, const std::vector devices); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 78d2f1674..74068d9db 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -28,6 +28,8 @@ namespace inference { // linking the inference shared library. void Init(bool init_p2p) { framework::InitDevices(init_p2p); } +void Init(int argc, char** argv) { framework::Init(argc, argv); } + void ReadBinaryFile(const std::string& filename, std::string* contents) { std::ifstream fin(filename, std::ios::in | std::ios::binary); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s", filename); diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h index ba3e45099..988b8aebb 100644 --- a/paddle/fluid/inference/io.h +++ b/paddle/fluid/inference/io.h @@ -27,6 +27,8 @@ namespace inference { void Init(bool init_p2p); +void Init(int argc, char** argv); + void LoadPersistables(framework::Executor* executor, framework::Scope* scope, const framework::ProgramDesc& main_program, const std::string& dirname, -- GitLab