Unverified commit beb0ca5f, authored by Pei Yang, committed by GitHub

Fix TRT plugin registry without TRT lib (#25982)

* fix trt plugin registry without trt lib

* support trt4

* refine code style
Parent 2191a083
......
@@ -56,9 +56,11 @@
 static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
       dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-static nvinfer1::IPluginRegistry* getPluginRegistry() {
+#if IS_TRT_VERSION_GE(6000)
+static nvinfer1::IPluginRegistry* GetPluginRegistry() {
   return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry());
 }
+#endif
 // A logger for create TensorRT infer builder.
 class NaiveLogger : public nvinfer1::ILogger {
......
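For reference, the IS_TRT_VERSION_GE guard introduced above is defined elsewhere in Paddle's TensorRT headers and is not part of this diff. A minimal sketch of how such a check is commonly built from the version macros in NvInfer.h (an assumption, not copied from this commit):

// Hedged sketch: fold the NvInfer.h version macros into one comparable number,
// so IS_TRT_VERSION_GE(6000) holds for TensorRT 6.0.0.0 and newer.
#define IS_TRT_VERSION_GE(version)                       \
  ((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
    NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version)

With this guard, GetPluginRegistry() is only declared when the installed TensorRT headers are at least version 6, matching the NV_TENSORRT_MAJOR >= 6 guard added in the dynload header further down in this commit.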
......
@@ -178,12 +178,16 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
   std::string name_space_;
   std::string plugin_base_;
 };
-#endif

 template <typename T>
 class TrtPluginRegistrarV2 {
  public:
-  TrtPluginRegistrarV2() { getPluginRegistry()->registerCreator(creator, ""); }
+  TrtPluginRegistrarV2() {
+    static auto func_ptr = GetPluginRegistry();
+    if (func_ptr != nullptr) {
+      func_ptr->registerCreator(creator, "");
+    }
+  }

  private:
   T creator;
@@ -193,6 +197,8 @@ class TrtPluginRegistrarV2 {
   static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
       plugin_registrar_##name {}

+#endif
+
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
......
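Usage note: with the guarded constructor above, plugin registration degrades gracefully. A static TrtPluginRegistrarV2 instance (normally produced through the registrar macro in the second hunk above) simply skips registerCreator when GetPluginRegistry() returns nullptr, i.e. when the TensorRT shared library could not be loaded. A hedged sketch; FooPluginCreator is a hypothetical nvinfer1::IPluginCreator implementation, not part of this commit:

// Hypothetical creator type standing in for any nvinfer1::IPluginCreator.
class FooPluginCreator : public nvinfer1::IPluginCreator { /* ... */ };

// The registrar's constructor runs during static initialization. Without the
// TensorRT library, GetPluginRegistry() yields nullptr and registration is a
// no-op instead of a null-pointer dereference at startup.
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<FooPluginCreator>
    foo_plugin_registrar{};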
......
@@ -36,26 +36,29 @@ extern void* tensorrt_dso_handle
   struct DynLoad__##__name {                                                  \
     template <typename... Args>                                               \
     auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {          \
-      using tensorrt_func = decltype(&::__name);                              \
       std::call_once(tensorrt_dso_flag, []() {                                \
         tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \
-        PADDLE_ENFORCE_NOT_NULL(tensorrt_dso_handle,                          \
-                                platform::errors::Unavailable(                \
-                                    "Load tensorrt %s failed", #__name));     \
       });                                                                     \
       static void* p_##__name = dlsym(tensorrt_dso_handle, #__name);          \
-      PADDLE_ENFORCE_NOT_NULL(                                                \
-          p_##__name,                                                         \
-          platform::errors::Unavailable("Load tensorrt %s failed", #__name)); \
+      if (p_##__name == nullptr) {                                            \
+        return nullptr;                                                       \
+      }                                                                       \
+      using tensorrt_func = decltype(&::__name);                              \
       return reinterpret_cast<tensorrt_func>(p_##__name)(args...);            \
     }                                                                         \
   };                                                                          \
   extern DynLoad__##__name __name

+#if (NV_TENSORRT_MAJOR >= 6)
 #define TENSORRT_RAND_ROUTINE_EACH(__macro) \
   __macro(createInferBuilder_INTERNAL);     \
   __macro(createInferRuntime_INTERNAL);     \
   __macro(getPluginRegistry);
+#else
+#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
+  __macro(createInferBuilder_INTERNAL);     \
+  __macro(createInferRuntime_INTERNAL);
+#endif

 TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
......
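The dynload change above makes symbol resolution fail softly: instead of throwing when either the library handle or the symbol is missing, the wrapper now returns nullptr, which is what lets GetPluginRegistry() be probed safely at runtime. A standalone sketch of the same pattern; the name LoadTrtSymbol and the hard-coded libnvinfer.so path are illustrative assumptions, not Paddle's actual dynload API:

#include <dlfcn.h>
#include <mutex>

// Resolve a TensorRT symbol once; return nullptr rather than aborting when the
// library or the symbol is absent, mirroring the behavior of the macro above.
static void* LoadTrtSymbol(const char* name) {
  static void* handle = nullptr;
  static std::once_flag flag;
  std::call_once(flag, [] { handle = dlopen("libnvinfer.so", RTLD_LAZY); });
  return handle != nullptr ? dlsym(handle, name) : nullptr;
}

Callers then treat a null result as "TensorRT unavailable" and skip the feature, exactly as TrtPluginRegistrarV2 does when it receives a null plugin registry.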