From beb0ca5fab7e791e1fbccf550fccd93617318ab2 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Fri, 7 Aug 2020 10:45:45 +0800 Subject: [PATCH] Fix TRT plugin registry without TRT lib (#25982) * fix trt plugin registry without trt lib * support trt4 * refine code style --- paddle/fluid/inference/tensorrt/helper.h | 4 +++- .../inference/tensorrt/plugin/trt_plugin.h | 10 ++++++++-- paddle/fluid/platform/dynload/tensorrt.h | 17 ++++++++++------- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h index 55a57caf9a..971f99e691 100644 --- a/paddle/fluid/inference/tensorrt/helper.h +++ b/paddle/fluid/inference/tensorrt/helper.h @@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) { return static_cast( dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION)); } -static nvinfer1::IPluginRegistry* getPluginRegistry() { +#if IS_TRT_VERSION_GE(6000) +static nvinfer1::IPluginRegistry* GetPluginRegistry() { return static_cast(dy::getPluginRegistry()); } +#endif // A logger for create TensorRT infer builder. class NaiveLogger : public nvinfer1::ILogger { diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index f4424b8b78..528adacb27 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -178,12 +178,16 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt { std::string name_space_; std::string plugin_base_; }; -#endif template class TrtPluginRegistrarV2 { public: - TrtPluginRegistrarV2() { getPluginRegistry()->registerCreator(creator, ""); } + TrtPluginRegistrarV2() { + static auto func_ptr = GetPluginRegistry(); + if (func_ptr != nullptr) { + func_ptr->registerCreator(creator, ""); + } + } private: T creator; @@ -193,6 +197,8 @@ class TrtPluginRegistrarV2 { static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2 \ plugin_registrar_##name {} +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h index 60e299385d..67a79ce4bb 100644 --- a/paddle/fluid/platform/dynload/tensorrt.h +++ b/paddle/fluid/platform/dynload/tensorrt.h @@ -36,26 +36,29 @@ extern void* tensorrt_dso_handle; struct DynLoad__##__name { \ template \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using tensorrt_func = decltype(&::__name); \ std::call_once(tensorrt_dso_flag, []() { \ tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \ - PADDLE_ENFORCE_NOT_NULL(tensorrt_dso_handle, \ - platform::errors::Unavailable( \ - "Load tensorrt %s failed", #__name)); \ }); \ static void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \ - PADDLE_ENFORCE_NOT_NULL( \ - p_##__name, \ - platform::errors::Unavailable("Load tensorrt %s failed", #__name)); \ + if (p_##__name == nullptr) { \ + return nullptr; \ + } \ + using tensorrt_func = decltype(&::__name); \ return reinterpret_cast(p_##__name)(args...); \ } \ }; \ extern DynLoad__##__name __name +#if (NV_TENSORRT_MAJOR >= 6) #define TENSORRT_RAND_ROUTINE_EACH(__macro) \ __macro(createInferBuilder_INTERNAL); \ __macro(createInferRuntime_INTERNAL); \ __macro(getPluginRegistry); +#else +#define TENSORRT_RAND_ROUTINE_EACH(__macro) \ + __macro(createInferBuilder_INTERNAL); \ + __macro(createInferRuntime_INTERNAL); +#endif TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) -- GitLab