未验证 提交 7e4ed848 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice]Improved custom device initialization (#39634)

上级 d6d0820e
......@@ -354,25 +354,15 @@ void RegisterKernelWithMetaInfoMap(
}
}
void LoadCustomKernelLib(const std::string& dso_lib_path) {
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) {
#ifdef _LINUX
void* dso_handle = nullptr;
int dynload_flags = RTLD_NOW | RTLD_LOCAL;
dso_handle = dlopen(dso_lib_path.c_str(), dynload_flags);
// MUST valid dso_lib_path
PADDLE_ENFORCE_NOT_NULL(
dso_handle,
platform::errors::InvalidArgument(
"Fail to open library: %s with error: %s", dso_lib_path, dlerror()));
typedef OpKernelInfoMap& get_op_kernel_info_map_t();
auto* func = reinterpret_cast<get_op_kernel_info_map_t*>(
dlsym(dso_handle, "PD_GetOpKernelInfoMap"));
if (func == nullptr) {
LOG(INFO) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetOpKernelInfoMap symbol in this lib.";
LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetOpKernelInfoMap symbol in this lib.";
return;
}
auto& op_kernel_info_map = func();
......@@ -384,42 +374,5 @@ void LoadCustomKernelLib(const std::string& dso_lib_path) {
return;
}
// Collect the full paths of all entries under libs_path whose name matches
// "*.so". The directory must be openable; this is enforced before scanning.
std::vector<std::string> ListAllLib(const std::string& libs_path) {
  DIR* dir = opendir(libs_path.c_str());
  // libs_path has to refer to a directory that can be opened
  PADDLE_ENFORCE_NOT_NULL(dir, platform::errors::InvalidArgument(
                                   "Fail to open path: %s", libs_path));

  const std::regex so_pattern(".*\\.so");
  std::vector<std::string> found;
  for (dirent* entry = readdir(dir); entry != nullptr; entry = readdir(dir)) {
    const std::string filename(entry->d_name);
    // Guard clause: anything that is not "<name>.so" is only traced.
    if (!std::regex_match(filename, so_pattern)) {
      VLOG(3) << "Skipped file [" << filename << "] without .so postfix";
      continue;
    }
    found.emplace_back(libs_path + '/' + filename);
    LOG(INFO) << "Found lib [" << filename << "]";
  }

  closedir(dir);
  return found;
}
// Discover the libraries under libs_path via ListAllLib and pass each one to
// LoadCustomKernelLib, logging the start and end of the scan.
void LoadCustomKernel(const std::string& libs_path) {
  VLOG(3) << "Try loading custom libs from: [" << libs_path << "]";

  for (const std::string& so_path : ListAllLib(libs_path)) {
    LoadCustomKernelLib(so_path);
  }

  LOG(INFO) << "Finished in LoadCustomKernel with libs_path: [" << libs_path
            << "]";
}
} // namespace framework
} // namespace paddle
......@@ -19,10 +19,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Load custom kernel lib from given path
void LoadCustomKernel(const std::string& libs_path);
void LoadCustomKernelLib(const std::string& dso_lib_path);
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle);
// Load custom kernel api: register kernel after user compiled
void LoadOpKernelInfoAndRegister(const std::string& dso_name);
......
......@@ -621,28 +621,26 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params);
// Registers the custom runtime described by `runtime_params`.
// When ValidCustomCustomRuntimeParams accepts the params, a CustomDevice is
// built from the device type, `device_interface` and `dso_handle`, and handed
// to DeviceManager::Register. If the params are rejected, or Register returns
// false, a warning is logged and the library is skipped.
// NOTE(review): this span shows two revisions of the function interleaved
// line by line (a bool-returning one, and a void one that additionally takes
// `dso_lib_path` for its log messages) - confirm which lines belong to the
// revision being edited before changing the code.
bool LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface,
void* dso_handle) {
const std::string& dso_lib_path, void* dso_handle) {
if (ValidCustomCustomRuntimeParams(&runtime_params)) {
auto device =
std::make_unique<CustomDevice>(runtime_params.device_type, 255, true,
std::move(device_interface), dso_handle);
if (false == DeviceManager::Register(std::move(device))) {
LOG(WARNING) << "Skip this library. Register failed!!! there may be a "
LOG(WARNING) << "Skipped lib [" << dso_lib_path
<< "]. Register failed!!! there may be a "
"Custom Runtime with the same name.";
return false;
}
} else {
LOG(WARNING)
<< "Skip this library. Wrong parameters!!! please check the version "
"compatibility between PaddlePaddle and Custom Runtime.";
return false;
LOG(WARNING) << "Skipped lib [" << dso_lib_path
<< "]. Wrong parameters!!! please check the version "
"compatibility between PaddlePaddle and Custom Runtime.";
}
return true;
}
bool LoadCustomRuntimeLib(void* dso_handle) {
void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) {
CustomRuntimeParams runtime_params;
std::memset(&runtime_params, 0, sizeof(CustomRuntimeParams));
runtime_params.size = sizeof(CustomRuntimeParams);
......@@ -653,19 +651,23 @@ bool LoadCustomRuntimeLib(void* dso_handle) {
RegisterDevicePluginFn init_plugin_fn =
reinterpret_cast<RegisterDevicePluginFn>(dlsym(dso_handle, "InitPlugin"));
if (!init_plugin_fn) {
LOG(WARNING) << "Skip this library. InitPlugin symbol not found.";
return false;
if (init_plugin_fn == nullptr) {
LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "InitPlugin symbol in this lib.";
return;
}
init_plugin_fn(&runtime_params);
if (runtime_params.device_type == nullptr) {
LOG(WARNING)
<< "Skip this library. InitPlugin failed!!! please check the version "
"compatibility between PaddlePaddle and Custom Runtime.";
return false;
}
return LoadCustomRuntimeLib(runtime_params, std::move(device_interface),
dso_handle);
LOG(WARNING) << "Skipped lib [" << dso_lib_path
<< "]: InitPlugin failed, please check the version "
"compatibility between PaddlePaddle and Custom Runtime.";
return;
}
LoadCustomRuntimeLib(runtime_params, std::move(device_interface),
dso_lib_path, dso_handle);
LOG(INFO) << "Successed in loading custom runtime in lib: " << dso_lib_path;
}
} // namespace platform
......
......@@ -30,8 +30,8 @@ void RegisterDevice() {
runtime_params.interface->size = sizeof(C_DeviceInterface);
InitFakeCPUDevice(&runtime_params);
EXPECT_TRUE(paddle::platform::LoadCustomRuntimeLib(
runtime_params, std::move(device_interface), nullptr));
paddle::platform::LoadCustomRuntimeLib(
runtime_params, std::move(device_interface), "", nullptr);
}
void InitDevice() {
......
......@@ -389,15 +389,14 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
dir = opendir(library_dir.c_str());
if (dir == nullptr) {
VLOG(4) << "open CustomDevice library_dir: " << library_dir << " failed";
VLOG(4) << "Failed to open path: " << library_dir;
} else {
while ((ptr = readdir(dir)) != nullptr) {
std::string filename(ptr->d_name);
if (std::regex_match(filename.begin(), filename.end(), results,
express)) {
libraries.push_back(library_dir + '/' + filename);
VLOG(4) << "found CustomDevice library: " << libraries.back()
<< std::endl;
VLOG(4) << "Found lib: " << libraries.back();
}
}
closedir(dir);
......@@ -406,15 +405,6 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
return libraries;
}
bool LoadCustomDevice(const std::string& library_dir) {
std::vector<std::string> libs = ListAllLibraries(library_dir);
for (const auto& lib_path : libs) {
auto dso_handle = dlopen(lib_path.c_str(), RTLD_NOW);
LoadCustomRuntimeLib(dso_handle);
}
return true;
}
} // namespace platform
} // namespace paddle
#endif
......@@ -162,13 +162,13 @@ class DeviceManager {
device_map_;
};
bool LoadCustomRuntimeLib(void* dso_handle);
std::vector<std::string> ListAllLibraries(const std::string& library_dir);
bool LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface,
void* dso_handle);
void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle);
bool LoadCustomDevice(const std::string& library_path);
void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface,
const std::string& dso_lib_path, void* dso_handle);
class Registrar {
public:
......
......@@ -141,6 +141,25 @@ void InitCupti() {
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Loads all custom device libraries located under `library_dir`.
// Every library returned by platform::ListAllLibraries is opened once with
// dlopen(RTLD_NOW); the handle is enforced to be non-null (reporting
// dlerror()'s text otherwise) and is then given, together with the library
// path, to the custom runtime loader followed by the custom kernel loader.
void LoadCustomDevice(const std::string &library_dir) {
  LOG(INFO) << "Try loading custom device libs from: [" << library_dir << "]";

  const std::vector<std::string> candidates =
      platform::ListAllLibraries(library_dir);
  for (size_t idx = 0; idx < candidates.size(); ++idx) {
    const std::string &lib_path = candidates[idx];
    void *dso_handle = dlopen(lib_path.c_str(), RTLD_NOW);
    PADDLE_ENFORCE_NOT_NULL(
        dso_handle,
        platform::errors::InvalidArgument(
            "Fail to open library: %s with error: %s", lib_path, dlerror()));

    // Runtime registration comes first, kernel registration second.
    platform::LoadCustomRuntimeLib(lib_path, dso_handle);
    framework::LoadCustomKernelLib(lib_path, dso_handle);
  }

  LOG(INFO) << "Finished in LoadCustomDevice with libs_path: [" << library_dir
            << "]";
}
#endif
void InitDevices() {
// CUPTI attribute should be set before any CUDA context is created (see CUPTI
// documentation about CUpti_ActivityAttribute).
......@@ -227,6 +246,7 @@ void InitDevices(const std::vector<int> devices) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
places.emplace_back(platform::CUDAPinnedPlace());
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
const char *custom_kernel_root_p = std::getenv("CUSTOM_DEVICE_ROOT");
if (!custom_kernel_root_p) {
VLOG(3) << "Env [CUSTOM_DEVICE_ROOT] is not set.";
......@@ -234,24 +254,22 @@ void InitDevices(const std::vector<int> devices) {
std::string custom_kernel_root(custom_kernel_root_p);
if (!custom_kernel_root.empty()) {
LOG(INFO) << "ENV [CUSTOM_DEVICE_ROOT]=" << custom_kernel_root;
framework::LoadCustomKernel(custom_kernel_root);
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::LoadCustomDevice(custom_kernel_root)) {
auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
for (auto &dev_type : device_types) {
VLOG(1) << "Device type: " << dev_type << ", visible devices count: "
<< platform::DeviceManager::GetDeviceCount(dev_type);
for (size_t i = 0;
i < platform::DeviceManager::GetDeviceCount(dev_type); i++) {
places.push_back(platform::CustomPlace(dev_type, i));
}
LoadCustomDevice(custom_kernel_root);
auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
for (auto &dev_type : device_types) {
auto device_count = platform::DeviceManager::GetDeviceCount(dev_type);
LOG(INFO) << "CustomDevice: " << dev_type
<< ", visible devices count: " << device_count;
for (size_t i = 0; i < device_count; i++) {
places.push_back(platform::CustomPlace(dev_type, i));
}
}
#endif
} else {
VLOG(3) << "ENV [CUSTOM_DEVICE_ROOT] is empty.";
}
}
#endif
platform::DeviceContextPool::Init(places);
#ifndef PADDLE_WITH_MKLDNN
......
......@@ -633,10 +633,6 @@ class PADDLE_API OpKernelInfoBuilder {
// Call after PD_REGISTER_KERNEL(...)
void RegisterAllCustomKernel();
// Using this api to load compiled custom kernel's dynamic library and
// register custom kernels
void LoadCustomKernelLib(const std::string& dso_name);
//////////////// Custom kernel register macro /////////////////////
// Refer to paddle/pten/core/kernel_registry.h, we can not use
// PT_REGISTER_KERNEL directly, common macros and functions are
......
......@@ -92,12 +92,6 @@ void RegisterAllCustomKernel() {
framework::RegisterKernelWithMetaInfoMap(op_kernel_info_map);
}
// Entry point in namespace paddle for loading a compiled custom kernel
// dynamic library: forwards dso_name unchanged to
// framework::LoadCustomKernelLib.
void LoadCustomKernelLib(const std::string& dso_name) {
  return framework::LoadCustomKernelLib(dso_name);
}
} // namespace paddle
#ifdef __cplusplus
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册