提交 469e717a 编写于 作者: Z zlsh80826

add nvinfer_plugin dyloader

上级 68c6160e
...@@ -22,19 +22,15 @@ namespace dynload { ...@@ -22,19 +22,15 @@ namespace dynload {
std::once_flag tensorrt_dso_flag; std::once_flag tensorrt_dso_flag;
void* tensorrt_dso_handle; void* tensorrt_dso_handle;
std::once_flag tensorrt_plugin_dso_flag;
void* tensorrt_plugin_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name #define DEFINE_WRAP(__name) DynLoad__##__name __name
TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP); TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP);
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP);
void* GetTensorRtHandle() { void* GetDsoHandle(const std::string& dso_name) {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer.dll";
#else
std::string dso_name = "libnvinfer.so";
#endif
#if !defined(_WIN32) #if !defined(_WIN32)
int dynload_flags = RTLD_LAZY | RTLD_LOCAL; int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
#else #else
...@@ -65,10 +61,31 @@ void* GetTensorRtHandle() { ...@@ -65,10 +61,31 @@ void* GetTensorRtHandle() {
#endif // !_WIN32 #endif // !_WIN32
std::cerr << string::Sprintf(error_msg, dso_name, errorno); std::cerr << string::Sprintf(error_msg, dso_name, errorno);
} }
return dso_handle; return dso_handle;
} }
void* GetTensorRtHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer.dll";
#else
std::string dso_name = "libnvinfer.so";
#endif
return GetDsoHandle(dso_name);
}
void* GetTensorRtPluginHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer_plugin.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer_plugin.dll";
#else
std::string dso_name = "libnvinfer_plugin.so";
#endif
return GetDsoHandle(dso_name);
}
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <NvInfer.h> #include <NvInfer.h>
#include <NvInferPlugin.h>
#if !defined(_WIN32) #if !defined(_WIN32)
#include <dlfcn.h> #include <dlfcn.h>
#endif #endif
...@@ -28,10 +29,14 @@ namespace platform { ...@@ -28,10 +29,14 @@ namespace platform {
namespace dynload { namespace dynload {
void* GetTensorRtHandle(); void* GetTensorRtHandle();
void* GetTensorRtPluginHandle();
extern std::once_flag tensorrt_dso_flag; extern std::once_flag tensorrt_dso_flag;
extern void* tensorrt_dso_handle; extern void* tensorrt_dso_handle;
extern std::once_flag tensorrt_plugin_dso_flag;
extern void* tensorrt_plugin_dso_handle;
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ #define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
...@@ -52,12 +57,38 @@ extern void* tensorrt_dso_handle; ...@@ -52,12 +57,38 @@ extern void* tensorrt_dso_handle;
}; \ }; \
extern DynLoad__##__name __name extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_plugin_func = decltype(&::__name); \
std::call_once(tensorrt_plugin_dso_flag, []() { \
tensorrt_plugin_dso_handle = \
paddle::platform::dynload::GetTensorRtPluginHandle(); \
PADDLE_ENFORCE_NOT_NULL( \
tensorrt_plugin_dso_handle, \
platform::errors::Unavailable("Load tensorrt plugin %s failed", \
#__name)); \
}); \
static void* p_##__name = dlsym(tensorrt_plugin_dso_handle, #__name); \
PADDLE_ENFORCE_NOT_NULL(p_##__name, \
platform::errors::Unavailable( \
"Load tensorrt plugin %s failed", #__name)); \
return reinterpret_cast<tensorrt_plugin_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \ #define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \ __macro(createInferBuilder_INTERNAL); \
__macro(createInferRuntime_INTERNAL); \ __macro(createInferRuntime_INTERNAL); \
__macro(getPluginRegistry); __macro(getPluginRegistry);
#define TENSORRT_PLUGIN_RAND_ROUTINE_EACH(__macro) \
__macro(initLibNvInferPlugins);
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP)
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册