未验证 提交 beb0ca5f 编写于 作者: P Pei Yang 提交者: GitHub

Fix TRT plugin registry without TRT lib (#25982)

* fix trt plugin registry without trt lib

* support trt4

* refine code style
上级 2191a083
...@@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) { ...@@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IRuntime*>( return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION)); dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
} }
static nvinfer1::IPluginRegistry* getPluginRegistry() { #if IS_TRT_VERSION_GE(6000)
static nvinfer1::IPluginRegistry* GetPluginRegistry() {
return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry()); return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry());
} }
#endif
// A logger for create TensorRT infer builder. // A logger for create TensorRT infer builder.
class NaiveLogger : public nvinfer1::ILogger { class NaiveLogger : public nvinfer1::ILogger {
......
...@@ -178,12 +178,16 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt { ...@@ -178,12 +178,16 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
std::string name_space_; std::string name_space_;
std::string plugin_base_; std::string plugin_base_;
}; };
#endif
template <typename T> template <typename T>
class TrtPluginRegistrarV2 { class TrtPluginRegistrarV2 {
public: public:
TrtPluginRegistrarV2() { getPluginRegistry()->registerCreator(creator, ""); } TrtPluginRegistrarV2() {
static auto func_ptr = GetPluginRegistry();
if (func_ptr != nullptr) {
func_ptr->registerCreator(creator, "");
}
}
private: private:
T creator; T creator;
...@@ -193,6 +197,8 @@ class TrtPluginRegistrarV2 { ...@@ -193,6 +197,8 @@ class TrtPluginRegistrarV2 {
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \ static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name {} plugin_registrar_##name {}
#endif
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -36,26 +36,29 @@ extern void* tensorrt_dso_handle; ...@@ -36,26 +36,29 @@ extern void* tensorrt_dso_handle;
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_func = decltype(&::__name); \
std::call_once(tensorrt_dso_flag, []() { \ std::call_once(tensorrt_dso_flag, []() { \
tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \ tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \
PADDLE_ENFORCE_NOT_NULL(tensorrt_dso_handle, \
platform::errors::Unavailable( \
"Load tensorrt %s failed", #__name)); \
}); \ }); \
static void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \ static void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \
PADDLE_ENFORCE_NOT_NULL( \ if (p_##__name == nullptr) { \
p_##__name, \ return nullptr; \
platform::errors::Unavailable("Load tensorrt %s failed", #__name)); \ } \
using tensorrt_func = decltype(&::__name); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \ return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
} \ } \
}; \ }; \
extern DynLoad__##__name __name extern DynLoad__##__name __name
#if (NV_TENSORRT_MAJOR >= 6)
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \ #define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \ __macro(createInferBuilder_INTERNAL); \
__macro(createInferRuntime_INTERNAL); \ __macro(createInferRuntime_INTERNAL); \
__macro(getPluginRegistry); __macro(getPluginRegistry);
#else
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \
__macro(createInferRuntime_INTERNAL);
#endif
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册