未验证 提交 7e4ed848 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice]Improved custom device initialization (#39634)

上级 d6d0820e
...@@ -354,25 +354,15 @@ void RegisterKernelWithMetaInfoMap( ...@@ -354,25 +354,15 @@ void RegisterKernelWithMetaInfoMap(
} }
} }
void LoadCustomKernelLib(const std::string& dso_lib_path) { void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) {
#ifdef _LINUX #ifdef _LINUX
void* dso_handle = nullptr;
int dynload_flags = RTLD_NOW | RTLD_LOCAL;
dso_handle = dlopen(dso_lib_path.c_str(), dynload_flags);
// MUST valid dso_lib_path
PADDLE_ENFORCE_NOT_NULL(
dso_handle,
platform::errors::InvalidArgument(
"Fail to open library: %s with error: %s", dso_lib_path, dlerror()));
typedef OpKernelInfoMap& get_op_kernel_info_map_t(); typedef OpKernelInfoMap& get_op_kernel_info_map_t();
auto* func = reinterpret_cast<get_op_kernel_info_map_t*>( auto* func = reinterpret_cast<get_op_kernel_info_map_t*>(
dlsym(dso_handle, "PD_GetOpKernelInfoMap")); dlsym(dso_handle, "PD_GetOpKernelInfoMap"));
if (func == nullptr) { if (func == nullptr) {
LOG(INFO) << "Skipped lib [" << dso_lib_path << "]: fail to find " LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetOpKernelInfoMap symbol in this lib."; << "PD_GetOpKernelInfoMap symbol in this lib.";
return; return;
} }
auto& op_kernel_info_map = func(); auto& op_kernel_info_map = func();
...@@ -384,42 +374,5 @@ void LoadCustomKernelLib(const std::string& dso_lib_path) { ...@@ -384,42 +374,5 @@ void LoadCustomKernelLib(const std::string& dso_lib_path) {
return; return;
} }
// List all libs with given path
std::vector<std::string> ListAllLib(const std::string& libs_path) {
DIR* dir = nullptr;
dir = opendir(libs_path.c_str());
// MUST valid libs_path
PADDLE_ENFORCE_NOT_NULL(dir, platform::errors::InvalidArgument(
"Fail to open path: %s", libs_path));
dirent* ptr = nullptr;
std::vector<std::string> libs;
std::regex express(".*\\.so");
std::match_results<std::string::iterator> results;
while ((ptr = readdir(dir)) != nullptr) {
std::string filename(ptr->d_name);
if (std::regex_match(filename.begin(), filename.end(), results, express)) {
libs.emplace_back(libs_path + '/' + filename);
LOG(INFO) << "Found lib [" << filename << "]";
} else {
VLOG(3) << "Skipped file [" << filename << "] without .so postfix";
}
}
closedir(dir);
return libs;
}
// Load custom kernels with given path
void LoadCustomKernel(const std::string& libs_path) {
VLOG(3) << "Try loading custom libs from: [" << libs_path << "]";
std::vector<std::string> libs = ListAllLib(libs_path);
for (auto& lib_path : libs) {
LoadCustomKernelLib(lib_path);
}
LOG(INFO) << "Finished in LoadCustomKernel with libs_path: [" << libs_path
<< "]";
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -19,10 +19,7 @@ limitations under the License. */ ...@@ -19,10 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Load custom kernel lib from giwen path void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle);
void LoadCustomKernel(const std::string& libs_path);
void LoadCustomKernelLib(const std::string& dso_lib_path);
// Load custom kernel api: register kernel after user compiled // Load custom kernel api: register kernel after user compiled
void LoadOpKernelInfoAndRegister(const std::string& dso_name); void LoadOpKernelInfoAndRegister(const std::string& dso_name);
......
...@@ -621,28 +621,26 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { ...@@ -621,28 +621,26 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params); typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params);
bool LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params, void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface, std::unique_ptr<C_DeviceInterface> device_interface,
void* dso_handle) { const std::string& dso_lib_path, void* dso_handle) {
if (ValidCustomCustomRuntimeParams(&runtime_params)) { if (ValidCustomCustomRuntimeParams(&runtime_params)) {
auto device = auto device =
std::make_unique<CustomDevice>(runtime_params.device_type, 255, true, std::make_unique<CustomDevice>(runtime_params.device_type, 255, true,
std::move(device_interface), dso_handle); std::move(device_interface), dso_handle);
if (false == DeviceManager::Register(std::move(device))) { if (false == DeviceManager::Register(std::move(device))) {
LOG(WARNING) << "Skip this library. Register failed!!! there may be a " LOG(WARNING) << "Skipped lib [" << dso_lib_path
<< "]. Register failed!!! there may be a "
"Custom Runtime with the same name."; "Custom Runtime with the same name.";
return false;
} }
} else { } else {
LOG(WARNING) LOG(WARNING) << "Skipped lib [" << dso_lib_path
<< "Skip this library. Wrong parameters!!! please check the version " << "]. Wrong parameters!!! please check the version "
"compatibility between PaddlePaddle and Custom Runtime."; "compatibility between PaddlePaddle and Custom Runtime.";
return false;
} }
return true;
} }
bool LoadCustomRuntimeLib(void* dso_handle) { void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) {
CustomRuntimeParams runtime_params; CustomRuntimeParams runtime_params;
std::memset(&runtime_params, 0, sizeof(CustomRuntimeParams)); std::memset(&runtime_params, 0, sizeof(CustomRuntimeParams));
runtime_params.size = sizeof(CustomRuntimeParams); runtime_params.size = sizeof(CustomRuntimeParams);
...@@ -653,19 +651,23 @@ bool LoadCustomRuntimeLib(void* dso_handle) { ...@@ -653,19 +651,23 @@ bool LoadCustomRuntimeLib(void* dso_handle) {
RegisterDevicePluginFn init_plugin_fn = RegisterDevicePluginFn init_plugin_fn =
reinterpret_cast<RegisterDevicePluginFn>(dlsym(dso_handle, "InitPlugin")); reinterpret_cast<RegisterDevicePluginFn>(dlsym(dso_handle, "InitPlugin"));
if (!init_plugin_fn) {
LOG(WARNING) << "Skip this library. InitPlugin symbol not found."; if (init_plugin_fn == nullptr) {
return false; LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "InitPlugin symbol in this lib.";
return;
} }
init_plugin_fn(&runtime_params); init_plugin_fn(&runtime_params);
if (runtime_params.device_type == nullptr) { if (runtime_params.device_type == nullptr) {
LOG(WARNING) LOG(WARNING) << "Skipped lib [" << dso_lib_path
<< "Skip this library. InitPlugin failed!!! please check the version " << "]: InitPlugin failed, please check the version "
"compatibility between PaddlePaddle and Custom Runtime."; "compatibility between PaddlePaddle and Custom Runtime.";
return false; return;
} }
return LoadCustomRuntimeLib(runtime_params, std::move(device_interface), LoadCustomRuntimeLib(runtime_params, std::move(device_interface),
dso_handle); dso_lib_path, dso_handle);
LOG(INFO) << "Successed in loading custom runtime in lib: " << dso_lib_path;
} }
} // namespace platform } // namespace platform
......
...@@ -30,8 +30,8 @@ void RegisterDevice() { ...@@ -30,8 +30,8 @@ void RegisterDevice() {
runtime_params.interface->size = sizeof(C_DeviceInterface); runtime_params.interface->size = sizeof(C_DeviceInterface);
InitFakeCPUDevice(&runtime_params); InitFakeCPUDevice(&runtime_params);
EXPECT_TRUE(paddle::platform::LoadCustomRuntimeLib( paddle::platform::LoadCustomRuntimeLib(
runtime_params, std::move(device_interface), nullptr)); runtime_params, std::move(device_interface), "", nullptr);
} }
void InitDevice() { void InitDevice() {
......
...@@ -389,15 +389,14 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) { ...@@ -389,15 +389,14 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
dir = opendir(library_dir.c_str()); dir = opendir(library_dir.c_str());
if (dir == nullptr) { if (dir == nullptr) {
VLOG(4) << "open CustomDevice library_dir: " << library_dir << " failed"; VLOG(4) << "Failed to open path: " << library_dir;
} else { } else {
while ((ptr = readdir(dir)) != nullptr) { while ((ptr = readdir(dir)) != nullptr) {
std::string filename(ptr->d_name); std::string filename(ptr->d_name);
if (std::regex_match(filename.begin(), filename.end(), results, if (std::regex_match(filename.begin(), filename.end(), results,
express)) { express)) {
libraries.push_back(library_dir + '/' + filename); libraries.push_back(library_dir + '/' + filename);
VLOG(4) << "found CustomDevice library: " << libraries.back() VLOG(4) << "Found lib: " << libraries.back();
<< std::endl;
} }
} }
closedir(dir); closedir(dir);
...@@ -406,15 +405,6 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) { ...@@ -406,15 +405,6 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
return libraries; return libraries;
} }
bool LoadCustomDevice(const std::string& library_dir) {
std::vector<std::string> libs = ListAllLibraries(library_dir);
for (const auto& lib_path : libs) {
auto dso_handle = dlopen(lib_path.c_str(), RTLD_NOW);
LoadCustomRuntimeLib(dso_handle);
}
return true;
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
#endif #endif
...@@ -162,13 +162,13 @@ class DeviceManager { ...@@ -162,13 +162,13 @@ class DeviceManager {
device_map_; device_map_;
}; };
bool LoadCustomRuntimeLib(void* dso_handle); std::vector<std::string> ListAllLibraries(const std::string& library_dir);
bool LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params, void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle);
std::unique_ptr<C_DeviceInterface> device_interface,
void* dso_handle);
bool LoadCustomDevice(const std::string& library_path); void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface,
const std::string& dso_lib_path, void* dso_handle);
class Registrar { class Registrar {
public: public:
......
...@@ -141,6 +141,25 @@ void InitCupti() { ...@@ -141,6 +141,25 @@ void InitCupti() {
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void LoadCustomDevice(const std::string &library_dir) {
LOG(INFO) << "Try loading custom device libs from: [" << library_dir << "]";
std::vector<std::string> libs = platform::ListAllLibraries(library_dir);
for (const auto &lib_path : libs) {
auto dso_handle = dlopen(lib_path.c_str(), RTLD_NOW);
PADDLE_ENFORCE_NOT_NULL(
dso_handle,
platform::errors::InvalidArgument(
"Fail to open library: %s with error: %s", lib_path, dlerror()));
platform::LoadCustomRuntimeLib(lib_path, dso_handle);
framework::LoadCustomKernelLib(lib_path, dso_handle);
}
LOG(INFO) << "Finished in LoadCustomDevice with libs_path: [" << library_dir
<< "]";
}
#endif
void InitDevices() { void InitDevices() {
// CUPTI attribute should be set before any CUDA context is created (see CUPTI // CUPTI attribute should be set before any CUDA context is created (see CUPTI
// documentation about CUpti_ActivityAttribute). // documentation about CUpti_ActivityAttribute).
...@@ -227,6 +246,7 @@ void InitDevices(const std::vector<int> devices) { ...@@ -227,6 +246,7 @@ void InitDevices(const std::vector<int> devices) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
places.emplace_back(platform::CUDAPinnedPlace()); places.emplace_back(platform::CUDAPinnedPlace());
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
const char *custom_kernel_root_p = std::getenv("CUSTOM_DEVICE_ROOT"); const char *custom_kernel_root_p = std::getenv("CUSTOM_DEVICE_ROOT");
if (!custom_kernel_root_p) { if (!custom_kernel_root_p) {
VLOG(3) << "Env [CUSTOM_DEVICE_ROOT] is not set."; VLOG(3) << "Env [CUSTOM_DEVICE_ROOT] is not set.";
...@@ -234,24 +254,22 @@ void InitDevices(const std::vector<int> devices) { ...@@ -234,24 +254,22 @@ void InitDevices(const std::vector<int> devices) {
std::string custom_kernel_root(custom_kernel_root_p); std::string custom_kernel_root(custom_kernel_root_p);
if (!custom_kernel_root.empty()) { if (!custom_kernel_root.empty()) {
LOG(INFO) << "ENV [CUSTOM_DEVICE_ROOT]=" << custom_kernel_root; LOG(INFO) << "ENV [CUSTOM_DEVICE_ROOT]=" << custom_kernel_root;
framework::LoadCustomKernel(custom_kernel_root); LoadCustomDevice(custom_kernel_root);
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::LoadCustomDevice(custom_kernel_root)) { auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes(); for (auto &dev_type : device_types) {
for (auto &dev_type : device_types) { auto device_count = platform::DeviceManager::GetDeviceCount(dev_type);
VLOG(1) << "Device type: " << dev_type << ", visible devices count: " LOG(INFO) << "CustomDevice: " << dev_type
<< platform::DeviceManager::GetDeviceCount(dev_type); << ", visible devices count: " << device_count;
for (size_t i = 0; for (size_t i = 0; i < device_count; i++) {
i < platform::DeviceManager::GetDeviceCount(dev_type); i++) { places.push_back(platform::CustomPlace(dev_type, i));
places.push_back(platform::CustomPlace(dev_type, i));
}
} }
} }
#endif
} else { } else {
VLOG(3) << "ENV [CUSTOM_DEVICE_ROOT] is empty."; VLOG(3) << "ENV [CUSTOM_DEVICE_ROOT] is empty.";
} }
} }
#endif
platform::DeviceContextPool::Init(places); platform::DeviceContextPool::Init(places);
#ifndef PADDLE_WITH_MKLDNN #ifndef PADDLE_WITH_MKLDNN
......
...@@ -633,10 +633,6 @@ class PADDLE_API OpKernelInfoBuilder { ...@@ -633,10 +633,6 @@ class PADDLE_API OpKernelInfoBuilder {
// Call after PD_REGISTER_KERNEL(...) // Call after PD_REGISTER_KERNEL(...)
void RegisterAllCustomKernel(); void RegisterAllCustomKernel();
// Using this api to load compiled custom kernel's dynamic library and
// register custom kernels
void LoadCustomKernelLib(const std::string& dso_name);
//////////////// Custom kernel register macro ///////////////////// //////////////// Custom kernel register macro /////////////////////
// Refer to paddle/pten/core/kernel_registry.h, we can not use // Refer to paddle/pten/core/kernel_registry.h, we can not use
// PT_REGISTER_KERNEL directly, common macros and functions are // PT_REGISTER_KERNEL directly, common macros and functions are
......
...@@ -92,12 +92,6 @@ void RegisterAllCustomKernel() { ...@@ -92,12 +92,6 @@ void RegisterAllCustomKernel() {
framework::RegisterKernelWithMetaInfoMap(op_kernel_info_map); framework::RegisterKernelWithMetaInfoMap(op_kernel_info_map);
} }
// Using this api to load compiled custom kernel's dynamic library and
// register custom kernels
void LoadCustomKernelLib(const std::string& dso_name) {
framework::LoadCustomKernelLib(dso_name);
}
} // namespace paddle } // namespace paddle
#ifdef __cplusplus #ifdef __cplusplus
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册