From 8c36614985f10b00f7f9607cea124ca1d37c4907 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Fri, 7 Aug 2020 12:18:50 +0800 Subject: [PATCH] add USE_NVINFER_PLUGIN define detection --- paddle/fluid/inference/tensorrt/convert/gelu_op.cc | 7 +++++++ paddle/fluid/inference/tensorrt/engine.h | 3 ++- paddle/fluid/platform/dynload/tensorrt.cc | 7 +++++++ paddle/fluid/platform/dynload/tensorrt.h | 8 +++++++- 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/gelu_op.cc b/paddle/fluid/inference/tensorrt/convert/gelu_op.cc index c10fffc706e..5f64f3771d1 100644 --- a/paddle/fluid/inference/tensorrt/convert/gelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/gelu_op.cc @@ -47,6 +47,8 @@ class GeluOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { #if IS_TRT_VERSION_GE(6000) + +#ifdef USE_NVINFER_PLUGIN auto creator = getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1"); assert(creator != nullptr); @@ -68,6 +70,11 @@ class GeluOpConverter : public OpConverter { creator->createPlugin("CustomGeluPluginDynamic", pluginPtr); layer = engine_->network()->addPluginV2(&input, input_num, *pluginObj); assert(layer != nullptr); +#else + plugin::GeluPluginDynamic* plugin = new plugin::GeluPluginDynamic(); + layer = engine_->AddPluginV2(&input, input_num, plugin); +#endif + #else PADDLE_THROW(platform::errors::Fatal( "You are running the TRT Dynamic Shape mode, need to confirm that " diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 774d0fc4694..1af2a4c5c73 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -157,8 +157,9 @@ class TensorRTEngine { "version should be at least 6."; #endif } - +#ifdef USE_NVINFER_PLUGIN dy::initLibNvInferPlugins(&logger, ""); +#endif } ~TensorRTEngine() {} diff --git a/paddle/fluid/platform/dynload/tensorrt.cc b/paddle/fluid/platform/dynload/tensorrt.cc index 8ddc9e982ba..ecca415b998 100644 --- a/paddle/fluid/platform/dynload/tensorrt.cc +++ b/paddle/fluid/platform/dynload/tensorrt.cc @@ -22,13 +22,18 @@ namespace dynload { std::once_flag tensorrt_dso_flag; void* tensorrt_dso_handle; +#ifdef USE_NVINFER_PLUGIN std::once_flag tensorrt_plugin_dso_flag; void* tensorrt_plugin_dso_handle; +#endif #define DEFINE_WRAP(__name) DynLoad__##__name __name TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP); + +#ifdef USE_NVINFER_PLUGIN TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP); +#endif void* GetDsoHandle(const std::string& dso_name) { #if !defined(_WIN32) @@ -75,6 +80,7 @@ void* GetTensorRtHandle() { return GetDsoHandle(dso_name); } +#ifdef USE_NVINFER_PLUGIN void* GetTensorRtPluginHandle() { #if defined(__APPLE__) || defined(__OSX__) std::string dso_name = "libnvinfer_plugin.dylib"; @@ -85,6 +91,7 @@ void* GetTensorRtPluginHandle() { #endif return GetDsoHandle(dso_name); } +#endif } // namespace dynload } // namespace platform diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h index c9982274f73..4521a4f2e19 100644 --- a/paddle/fluid/platform/dynload/tensorrt.h +++ b/paddle/fluid/platform/dynload/tensorrt.h @@ -14,7 +14,9 @@ limitations under the License. */ #pragma once #include +#ifdef USE_NVINFER_PLUGIN #include +#endif #if !defined(_WIN32) #include #endif @@ -29,13 +31,15 @@ namespace platform { namespace dynload { void* GetTensorRtHandle(); -void* GetTensorRtPluginHandle(); extern std::once_flag tensorrt_dso_flag; extern void* tensorrt_dso_handle; +#ifdef USE_NVINFER_PLUGIN +void* GetTensorRtPluginHandle(); extern std::once_flag tensorrt_plugin_dso_flag; extern void* tensorrt_plugin_dso_handle; +#endif #define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ struct DynLoad__##__name { \ @@ -88,7 +92,9 @@ extern void* tensorrt_plugin_dso_handle; __macro(initLibNvInferPlugins); TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) +#ifdef USE_NVINFER_PLUGIN TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP) +#endif } // namespace dynload } // namespace platform -- GitLab