Unverified · Commit 2f9fa9b7 authored by W whs, committed by GitHub

Merge pull request #10167 from wanghaoshuang/fluid_init

Add an init interface for customized devices.
@@ -15,7 +15,6 @@ limitations under the License. */
 #include <algorithm>
 #include <stdexcept>
 #include <string>
-#include <vector>
 #include "paddle/fluid/framework/init.h"
 #include "paddle/fluid/framework/operator.h"
@@ -31,6 +30,7 @@ std::once_flag p2p_init_flag;
 void InitGflags(std::vector<std::string> argv) {
   std::call_once(gflags_init_flag, [&]() {
+    argv.insert(argv.begin(), "dummy");
     int argc = argv.size();
     char **arr = new char *[argv.size()];
     std::string line;
@@ -44,20 +44,23 @@ void InitGflags(std::vector<std::string> argv) {
   });
 }
 
-void InitP2P(int count) {
+void InitP2P(std::vector<int> devices) {
 #ifdef PADDLE_WITH_CUDA
   std::call_once(p2p_init_flag, [&]() {
+    int count = devices.size();
     for (int i = 0; i < count; ++i) {
       for (int j = 0; j < count; ++j) {
-        if (i == j) continue;
+        if (devices[i] == devices[j]) continue;
         int can_acess = -1;
-        PADDLE_ENFORCE(cudaDeviceCanAccessPeer(&can_acess, i, j),
-                       "Failed to test P2P access.");
+        PADDLE_ENFORCE(
+            cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]),
+            "Failed to test P2P access.");
         if (can_acess != 1) {
-          LOG(WARNING) << "Cannot enable P2P access from " << i << " to " << j;
+          LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
+                       << " to " << devices[j];
         } else {
-          cudaSetDevice(i);
-          cudaDeviceEnablePeerAccess(j, 0);
+          cudaSetDevice(devices[i]);
+          cudaDeviceEnablePeerAccess(devices[j], 0);
         }
       }
     }
@@ -67,11 +70,26 @@ void InitP2P(int count) {
 void InitDevices(bool init_p2p) {
   /*Init all available devices by default */
+  std::vector<int> devices;
+#ifdef PADDLE_WITH_CUDA
+  try {
+    int count = platform::GetCUDADeviceCount();
+    for (int i = 0; i < count; ++i) {
+      devices.push_back(i);
+    }
+  } catch (const std::exception &exp) {
+    LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
+  }
+#else
+  LOG(WARNING)
+      << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
+#endif
+  InitDevices(init_p2p, devices);
+}
+
+void InitDevices(bool init_p2p, const std::vector<int> devices) {
   std::vector<platform::Place> places;
-  places.emplace_back(platform::CPUPlace());
   int count = 0;
 #ifdef PADDLE_WITH_CUDA
   try {
     count = platform::GetCUDADeviceCount();
@@ -83,12 +101,17 @@ void InitDevices(bool init_p2p) {
       << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
 #endif
-  for (int i = 0; i < count; ++i) {
-    places.emplace_back(platform::CUDAPlace(i));
+  for (size_t i = 0; i < devices.size(); ++i) {
+    if (devices[i] >= count || devices[i] < 0) {
+      LOG(WARNING) << "Invalid devices id.";
+      continue;
+    }
+    places.emplace_back(platform::CUDAPlace(devices[i]));
   }
   if (init_p2p) {
-    InitP2P(count);
+    InitP2P(devices);
   }
+  places.emplace_back(platform::CPUPlace());
   platform::DeviceContextPool::Init(places);
 }
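With this overload, callers can initialize a specific subset of GPUs rather than always claiming every visible device; invalid ids are skipped with a warning and the CPU place is still appended. A minimal usage sketch (the device ids 0 and 1 are illustrative assumptions, not taken from this change):

```cpp
#include <vector>

#include "paddle/fluid/framework/init.h"

int main() {
  // Initialize only GPU 0 and GPU 1 (the CPU place is appended automatically)
  // and try to enable P2P access between the listed devices.
  paddle::framework::InitDevices(/*init_p2p=*/true, std::vector<int>{0, 1});

  // The original behaviour is still available: initialize every visible GPU.
  // paddle::framework::InitDevices(/*init_p2p=*/true);
  return 0;
}
```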
......
@@ -28,5 +28,7 @@ void InitGLOG(const std::string &prog_name);
 void InitDevices(bool init_p2p);
 
+void InitDevices(bool init_p2p, const std::vector<int> devices);
+
 }  // namespace framework
 }  // namespace paddle
@@ -16,17 +16,29 @@ limitations under the License. */
 #include <algorithm>
 #include <fstream>
+#include <vector>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/pybind/pybind.h"
 
+DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
+DEFINE_bool(init_p2p, false, "Whether to init p2p.");
+
 namespace paddle {
 namespace inference {
 
-// Temporarily add this function for exposing framework::InitDevices() when
-// linking the inference shared library.
-void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
+void Init(const std::vector<std::string> argv) {
+  framework::InitGflags(argv);
+  // init devices
+  std::vector<int> devices;
+  std::string token;
+  std::istringstream tokenStream(FLAGS_devices);
+  while (std::getline(tokenStream, token, ',')) {
+    devices.push_back(std::stoi(token));
+  }
+  framework::InitDevices(FLAGS_init_p2p, devices);
+}
 
 void ReadBinaryFile(const std::string& filename, std::string* contents) {
   std::ifstream fin(filename, std::ios::in | std::ios::binary);
......
@@ -25,7 +25,7 @@ limitations under the License. */
 namespace paddle {
 namespace inference {
 
-void Init(bool init_p2p);
+void Init(const std::vector<std::string> argv);
 
 void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
                       const framework::ProgramDesc& main_program,
......
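For completeness, a hedged sketch of how an inference binary might call the new gflags-based entry point; the header path and the flag values below ("0,1", true) are assumptions for illustration, not part of this diff:

```cpp
#include <string>
#include <vector>

#include "paddle/fluid/inference/io.h"  // assumed header path for inference::Init

int main() {
  // Each string is handed to framework::InitGflags(), which prepends a dummy
  // program name before parsing. FLAGS_devices is then split on ',' and passed
  // to framework::InitDevices() together with FLAGS_init_p2p.
  std::vector<std::string> argv = {"--devices=0,1", "--init_p2p=true"};
  paddle::inference::Init(argv);

  // ... load the inference program and run the executor ...
  return 0;
}
```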