Commit 469e717a authored by zlsh80826

add nvinfer_plugin dyloader
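
This adds a dynamic loader for libnvinfer_plugin alongside the existing libnvinfer loader, so initLibNvInferPlugins is resolved at runtime via dlopen/dlsym instead of being linked at build time. A minimal sketch of the intended call pattern, assuming the header lives at paddle/fluid/platform/dynload/tensorrt.h; the RegisterTrtPlugins helper below is illustrative and not part of this commit:

    // Sketch only: how the new dynload wrapper is expected to be invoked.
    #include "paddle/fluid/platform/dynload/tensorrt.h"  // assumed header path

    void RegisterTrtPlugins(nvinfer1::ILogger* logger) {
      // The first call dlopen()s libnvinfer_plugin and dlsym()s
      // initLibNvInferPlugins, then registers TensorRT's built-in plugins
      // with the global plugin registry.
      paddle::platform::dynload::initLibNvInferPlugins(logger, "");
    }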

Parent 68c6160e
@@ -22,19 +22,15 @@ namespace dynload {
std::once_flag tensorrt_dso_flag;
void* tensorrt_dso_handle;
std::once_flag tensorrt_plugin_dso_flag;
void* tensorrt_plugin_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP);
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP);
void* GetTensorRtHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer.dll";
#else
std::string dso_name = "libnvinfer.so";
#endif
void* GetDsoHandle(const std::string& dso_name) {
#if !defined(_WIN32)
int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
#else
@@ -65,10 +61,31 @@ void* GetTensorRtHandle() {
#endif // !_WIN32
std::cerr << string::Sprintf(error_msg, dso_name, errorno);
}
return dso_handle;
}
void* GetTensorRtHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer.dll";
#else
std::string dso_name = "libnvinfer.so";
#endif
return GetDsoHandle(dso_name);
}
void* GetTensorRtPluginHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer_plugin.dylib";
#elif defined(_WIN32)
std::string dso_name = "nvinfer_plugin.dll";
#else
std::string dso_name = "libnvinfer_plugin.so";
#endif
return GetDsoHandle(dso_name);
}
} // namespace dynload
} // namespace platform
} // namespace paddle
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <NvInferPlugin.h>
#if !defined(_WIN32)
#include <dlfcn.h>
#endif
@@ -28,10 +29,14 @@ namespace platform {
namespace dynload {
void* GetTensorRtHandle();
void* GetTensorRtPluginHandle();
extern std::once_flag tensorrt_dso_flag;
extern void* tensorrt_dso_handle;
extern std::once_flag tensorrt_plugin_dso_flag;
extern void* tensorrt_plugin_dso_handle;
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
@@ -52,12 +57,38 @@ extern void* tensorrt_dso_handle;
}; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_plugin_func = decltype(&::__name); \
std::call_once(tensorrt_plugin_dso_flag, []() { \
tensorrt_plugin_dso_handle = \
paddle::platform::dynload::GetTensorRtPluginHandle(); \
PADDLE_ENFORCE_NOT_NULL( \
tensorrt_plugin_dso_handle, \
platform::errors::Unavailable("Load tensorrt plugin %s failed", \
#__name)); \
}); \
static void* p_##__name = dlsym(tensorrt_plugin_dso_handle, #__name); \
PADDLE_ENFORCE_NOT_NULL(p_##__name, \
platform::errors::Unavailable( \
"Load tensorrt plugin %s failed", #__name)); \
return reinterpret_cast<tensorrt_plugin_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \
__macro(createInferRuntime_INTERNAL); \
__macro(getPluginRegistry);
#define TENSORRT_PLUGIN_RAND_ROUTINE_EACH(__macro) \
__macro(initLibNvInferPlugins);
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP)
} // namespace dynload
} // namespace platform
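
For reference, the plugin wrapper declared above follows the usual std::call_once + dlopen + dlsym lazy-loading idiom. A self-contained, POSIX-only sketch of the same idea outside Paddle's macros (every name below is illustrative and not part of this patch):

    #include <dlfcn.h>
    #include <mutex>
    #include <stdexcept>
    #include <string>

    static std::once_flag plugin_flag;
    static void* plugin_handle = nullptr;

    // Open libnvinfer_plugin once, then look up symbols on demand.
    void* GetPluginSymbol(const char* symbol) {
      std::call_once(plugin_flag, [] {
        // Same flags as GetDsoHandle uses on non-Windows platforms.
        plugin_handle = dlopen("libnvinfer_plugin.so", RTLD_LAZY | RTLD_LOCAL);
        if (plugin_handle == nullptr) {
          throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
        }
      });
      void* fn = dlsym(plugin_handle, symbol);
      if (fn == nullptr) {
        throw std::runtime_error(std::string("dlsym failed for ") + symbol);
      }
      return fn;
    }

    // Resolve and call initLibNvInferPlugins through the loaded handle.
    bool CallInitLibNvInferPlugins(void* logger) {
      using InitFn = bool (*)(void*, const char*);
      auto init =
          reinterpret_cast<InitFn>(GetPluginSymbol("initLibNvInferPlugins"));
      return init(logger, "");
    }

Paddle's macro version differs mainly in caching the dlsym result in a static local and reporting failures through PADDLE_ENFORCE_NOT_NULL rather than exceptions.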