未验证 提交 2f9fa9b7 编写于 作者: W whs 提交者: GitHub

Merge pull request #10167 from wanghaoshuang/fluid_init

Add init interface for customize devices.
......@@ -15,7 +15,6 @@ limitations under the License. */
#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/operator.h"
......@@ -31,6 +30,7 @@ std::once_flag p2p_init_flag;
void InitGflags(std::vector<std::string> argv) {
std::call_once(gflags_init_flag, [&]() {
argv.insert(argv.begin(), "dummy");
int argc = argv.size();
char **arr = new char *[argv.size()];
std::string line;
......@@ -44,20 +44,23 @@ void InitGflags(std::vector<std::string> argv) {
});
}
void InitP2P(int count) {
void InitP2P(std::vector<int> devices) {
#ifdef PADDLE_WITH_CUDA
std::call_once(p2p_init_flag, [&]() {
int count = devices.size();
for (int i = 0; i < count; ++i) {
for (int j = 0; j < count; ++j) {
if (i == j) continue;
if (devices[i] == devices[j]) continue;
int can_acess = -1;
PADDLE_ENFORCE(cudaDeviceCanAccessPeer(&can_acess, i, j),
"Failed to test P2P access.");
PADDLE_ENFORCE(
cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]),
"Failed to test P2P access.");
if (can_acess != 1) {
LOG(WARNING) << "Cannot enable P2P access from " << i << " to " << j;
LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
<< " to " << devices[j];
} else {
cudaSetDevice(i);
cudaDeviceEnablePeerAccess(j, 0);
cudaSetDevice(devices[i]);
cudaDeviceEnablePeerAccess(devices[j], 0);
}
}
}
......@@ -67,11 +70,26 @@ void InitP2P(int count) {
void InitDevices(bool init_p2p) {
/*Init all available devices by default */
std::vector<int> devices;
#ifdef PADDLE_WITH_CUDA
try {
int count = platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
devices.push_back(i);
}
} catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
}
#else
LOG(WARNING)
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif
InitDevices(init_p2p, devices);
}
void InitDevices(bool init_p2p, const std::vector<int> devices) {
std::vector<platform::Place> places;
places.emplace_back(platform::CPUPlace());
int count = 0;
#ifdef PADDLE_WITH_CUDA
try {
count = platform::GetCUDADeviceCount();
......@@ -83,12 +101,17 @@ void InitDevices(bool init_p2p) {
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif
for (int i = 0; i < count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
for (size_t i = 0; i < devices.size(); ++i) {
if (devices[i] >= count || devices[i] < 0) {
LOG(WARNING) << "Invalid devices id.";
continue;
}
places.emplace_back(platform::CUDAPlace(devices[i]));
}
if (init_p2p) {
InitP2P(count);
InitP2P(devices);
}
places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places);
}
......
......@@ -28,5 +28,7 @@ void InitGLOG(const std::string &prog_name);
void InitDevices(bool init_p2p);
void InitDevices(bool init_p2p, const std::vector<int> devices);
} // namespace framework
} // namespace paddle
......@@ -16,17 +16,29 @@ limitations under the License. */
#include <algorithm>
#include <fstream>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/pybind/pybind.h"
DEFINE_string(devices, "", "The devices to be used which is joined by comma.");
DEFINE_bool(init_p2p, false, "Whether to init p2p.");
namespace paddle {
namespace inference {
// Temporarily add this function for exposing framework::InitDevices() when
// linking the inference shared library.
void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
void Init(const std::vector<std::string> argv) {
framework::InitGflags(argv);
// init devices
std::vector<int> devices;
std::string token;
std::istringstream tokenStream(FLAGS_devices);
while (std::getline(tokenStream, token, ',')) {
devices.push_back(std::stoi(token));
}
framework::InitDevices(FLAGS_init_p2p, devices);
}
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
......
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace inference {
void Init(bool init_p2p);
void Init(const std::vector<std::string> argv);
void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册