diff --git a/paddle/fluid/platform/dynload/tensorrt.cc b/paddle/fluid/platform/dynload/tensorrt.cc index c9c3a9456b736ee1afb2efbe9bf092e2ae298372..8ddc9e982bab8cdcb80ce1b27ab6c024e8c6d5ef 100644 --- a/paddle/fluid/platform/dynload/tensorrt.cc +++ b/paddle/fluid/platform/dynload/tensorrt.cc @@ -22,19 +22,15 @@ namespace dynload { std::once_flag tensorrt_dso_flag; void* tensorrt_dso_handle; +std::once_flag tensorrt_plugin_dso_flag; +void* tensorrt_plugin_dso_handle; + #define DEFINE_WRAP(__name) DynLoad__##__name __name TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP); +TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP); -void* GetTensorRtHandle() { -#if defined(__APPLE__) || defined(__OSX__) - std::string dso_name = "libnvinfer.dylib"; -#elif defined(_WIN32) - std::string dso_name = "nvinfer.dll"; -#else - std::string dso_name = "libnvinfer.so"; -#endif - +void* GetDsoHandle(const std::string& dso_name) { #if !defined(_WIN32) int dynload_flags = RTLD_LAZY | RTLD_LOCAL; #else @@ -65,10 +61,31 @@ void* GetTensorRtHandle() { #endif // !_WIN32 std::cerr << string::Sprintf(error_msg, dso_name, errorno); } - return dso_handle; } +void* GetTensorRtHandle() { +#if defined(__APPLE__) || defined(__OSX__) + std::string dso_name = "libnvinfer.dylib"; +#elif defined(_WIN32) + std::string dso_name = "nvinfer.dll"; +#else + std::string dso_name = "libnvinfer.so"; +#endif + return GetDsoHandle(dso_name); +} + +void* GetTensorRtPluginHandle() { +#if defined(__APPLE__) || defined(__OSX__) + std::string dso_name = "libnvinfer_plugin.dylib"; +#elif defined(_WIN32) + std::string dso_name = "nvinfer_plugin.dll"; +#else + std::string dso_name = "libnvinfer_plugin.so"; +#endif + return GetDsoHandle(dso_name); +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h index 60e299385d6a6433d11753c7a0b96958b48a8e2a..c9982274f734591f4f31d64d4afd72057fa7e2f9 100644 --- a/paddle/fluid/platform/dynload/tensorrt.h +++ b/paddle/fluid/platform/dynload/tensorrt.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #if !defined(_WIN32) #include #endif @@ -28,10 +29,14 @@ namespace platform { namespace dynload { void* GetTensorRtHandle(); +void* GetTensorRtPluginHandle(); extern std::once_flag tensorrt_dso_flag; extern void* tensorrt_dso_handle; +extern std::once_flag tensorrt_plugin_dso_flag; +extern void* tensorrt_plugin_dso_handle; + #define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ struct DynLoad__##__name { \ template \ @@ -52,12 +57,38 @@ extern void* tensorrt_dso_handle; }; \ extern DynLoad__##__name __name +#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ + using tensorrt_plugin_func = decltype(&::__name); \ + std::call_once(tensorrt_plugin_dso_flag, []() { \ + tensorrt_plugin_dso_handle = \ + paddle::platform::dynload::GetTensorRtPluginHandle(); \ + PADDLE_ENFORCE_NOT_NULL( \ + tensorrt_plugin_dso_handle, \ + platform::errors::Unavailable("Load tensorrt plugin %s failed", \ + #__name)); \ + }); \ + static void* p_##__name = dlsym(tensorrt_plugin_dso_handle, #__name); \ + PADDLE_ENFORCE_NOT_NULL(p_##__name, \ + platform::errors::Unavailable( \ + "Load tensorrt plugin %s failed", #__name)); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + #define TENSORRT_RAND_ROUTINE_EACH(__macro) \ __macro(createInferBuilder_INTERNAL); \ __macro(createInferRuntime_INTERNAL); \ __macro(getPluginRegistry); +#define TENSORRT_PLUGIN_RAND_ROUTINE_EACH(__macro) \ + __macro(initLibNvInferPlugins); + TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) +TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP) } // namespace dynload } // namespace platform