提交 1bdea0a8 编写于 作者: W wanghaoshuang

Add init interface for customize devices.

上级 2486d563
...@@ -15,19 +15,40 @@ limitations under the License. */ ...@@ -15,19 +15,40 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/init.h" #include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h" #include "paddle/fluid/string/piece.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
DEFINE_string(devices, "", "The devices to be used.");
DEFINE_bool(init_p2p, true, "Whether to init p2p.");
std::once_flag gflags_init_flag; std::once_flag gflags_init_flag;
std::once_flag p2p_init_flag; std::once_flag p2p_init_flag;
using paddle::platform::DeviceContextPool;
void Init(int argc, char **argv) {
std::call_once(gflags_init_flag,
[&]() { google::ParseCommandLineFlags(&argc, &argv, true); });
// init devices
std::vector<int> devices;
std::string token;
std::istringstream tokenStream(FLAGS_devices);
while (std::getline(tokenStream, token, ',')) {
devices.push_back(std::stoi(token));
}
InitDevices(FLAGS_init_p2p, devices);
}
void InitGflags(std::vector<std::string> &argv) { void InitGflags(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, [&]() { std::call_once(gflags_init_flag, [&]() {
int argc = argv.size(); int argc = argv.size();
...@@ -64,6 +85,30 @@ void InitP2P(int count) { ...@@ -64,6 +85,30 @@ void InitP2P(int count) {
#endif #endif
} }
void InitP2P(std::vector<int> devices) {
#ifdef PADDLE_WITH_CUDA
std::call_once(p2p_init_flag, [&]() {
int count = devices.size();
for (int i = 0; i < count; ++i) {
for (int j = 0; j < count; ++j) {
if (devices[i] == devices[j]) continue;
int can_acess = -1;
PADDLE_ENFORCE(
cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]),
"Failed to test P2P access.");
if (can_acess != 1) {
LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
<< " to " << devices[j];
} else {
cudaSetDevice(devices[i]);
cudaDeviceEnablePeerAccess(devices[j], 0);
}
}
}
});
#endif
}
void InitDevices(bool init_p2p) { void InitDevices(bool init_p2p) {
/*Init all avaiable devices by default */ /*Init all avaiable devices by default */
...@@ -91,6 +136,34 @@ void InitDevices(bool init_p2p) { ...@@ -91,6 +136,34 @@ void InitDevices(bool init_p2p) {
platform::DeviceContextPool::Init(places); platform::DeviceContextPool::Init(places);
} }
void InitDevices(bool init_p2p, const std::vector<int> devices) {
std::vector<platform::Place> places;
int count = 0;
#ifdef PADDLE_WITH_CUDA
try {
count = platform::GetCUDADeviceCount();
} catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
}
#else
LOG(WARNING)
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif
for (size_t i = 0; i < devices.size(); ++i) {
if (devices[i] >= count) {
LOG(WARNING) << "Invalid devices id.";
continue;
}
places.emplace_back(platform::CUDAPlace(devices[i]));
}
if (init_p2p) {
InitP2P(devices);
}
places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places);
}
void InitGLOG(const std::string &prog_name) { void InitGLOG(const std::string &prog_name) {
// glog will not hold the ARGV[0] inside. // glog will not hold the ARGV[0] inside.
// Use strdup to alloc a new string. // Use strdup to alloc a new string.
......
...@@ -20,11 +20,15 @@ limitations under the License. */ ...@@ -20,11 +20,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void Init(int argc, char **argv);
void InitGflags(std::vector<std::string> &argv); void InitGflags(std::vector<std::string> &argv);
void InitGLOG(const std::string &prog_name); void InitGLOG(const std::string &prog_name);
void InitDevices(bool init_p2p); void InitDevices(bool init_p2p);
void InitDevices(bool init_p2p, const std::vector<int> devices);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -28,6 +28,8 @@ namespace inference { ...@@ -28,6 +28,8 @@ namespace inference {
// linking the inference shared library. // linking the inference shared library.
void Init(bool init_p2p) { framework::InitDevices(init_p2p); } void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
void Init(int argc, char** argv) { framework::Init(argc, argv); }
void ReadBinaryFile(const std::string& filename, std::string* contents) { void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary); std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename); PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename);
......
...@@ -27,6 +27,8 @@ namespace inference { ...@@ -27,6 +27,8 @@ namespace inference {
void Init(bool init_p2p); void Init(bool init_p2p);
void Init(int argc, char** argv);
void LoadPersistables(framework::Executor* executor, framework::Scope* scope, void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program, const framework::ProgramDesc& main_program,
const std::string& dirname, const std::string& dirname,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册