Commit 8c366149 authored by zlsh80826

add USE_NVINFER_PLUGIN define detection

Parent bb6b8480
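What the change does: every use of libnvinfer_plugin (the CustomGeluPluginDynamic creator lookup in GeluOpConverter, the initLibNvInferPlugins call in the TensorRTEngine constructor, and the dynload declarations and definitions for the plugin DSO) is now wrapped in an #ifdef USE_NVINFER_PLUGIN guard, so builds that do not define the macro keep the previous behaviour and fall back to Paddle's own plugin::GeluPluginDynamic. A minimal, self-contained sketch of this compile-time gating pattern follows; the two helper functions are hypothetical stand-ins for illustration, not Paddle APIs.

// gelu_gate_sketch.cc -- illustrates the #ifdef USE_NVINFER_PLUGIN selection
// introduced by this commit. Build with `g++ gelu_gate_sketch.cc` for the
// fallback path, or `g++ -DUSE_NVINFER_PLUGIN gelu_gate_sketch.cc` for the
// plugin-library path.
#include <iostream>

// Hypothetical stand-in for the new path: in the real converter this is
// getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1")
// followed by addPluginV2 on the created plugin object.
static void AddGeluViaNvInferPlugin() {
  std::cout << "gelu layer built from libnvinfer_plugin's CustomGeluPluginDynamic\n";
}

// Hypothetical stand-in for the pre-existing path: in the real converter this
// is a plugin::GeluPluginDynamic handed to engine_->AddPluginV2(...).
static void AddGeluViaPaddlePlugin() {
  std::cout << "gelu layer built from Paddle's built-in GeluPluginDynamic\n";
}

int main() {
#ifdef USE_NVINFER_PLUGIN
  // Taken when the build defines USE_NVINFER_PLUGIN (i.e. libnvinfer_plugin
  // was detected), mirroring the new branch in gelu_op.cc.
  AddGeluViaNvInferPlugin();
#else
  // Fallback: same behaviour as before this commit.
  AddGeluViaPaddlePlugin();
#endif
  return 0;
}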
@@ -47,6 +47,8 @@ class GeluOpConverter : public OpConverter {
    nvinfer1::ILayer* layer = nullptr;
    if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
#ifdef USE_NVINFER_PLUGIN
      auto creator =
          getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1");
      assert(creator != nullptr);
@@ -68,6 +70,11 @@ class GeluOpConverter : public OpConverter {
          creator->createPlugin("CustomGeluPluginDynamic", pluginPtr);
      layer = engine_->network()->addPluginV2(&input, input_num, *pluginObj);
      assert(layer != nullptr);
#else
      plugin::GeluPluginDynamic* plugin = new plugin::GeluPluginDynamic();
      layer = engine_->AddPluginV2(&input, input_num, plugin);
#endif
#else
      PADDLE_THROW(platform::errors::Fatal(
          "You are running the TRT Dynamic Shape mode, need to confirm that "
......
@@ -157,8 +157,9 @@ class TensorRTEngine {
             "version should be at least 6.";
#endif
    }
#ifdef USE_NVINFER_PLUGIN
    dy::initLibNvInferPlugins(&logger, "");
#endif
  }
  ~TensorRTEngine() {}
......
@@ -22,13 +22,18 @@ namespace dynload {
std::once_flag tensorrt_dso_flag;
void* tensorrt_dso_handle;
#ifdef USE_NVINFER_PLUGIN
std::once_flag tensorrt_plugin_dso_flag;
void* tensorrt_plugin_dso_handle;
#endif
#define DEFINE_WRAP(__name) DynLoad__##__name __name
TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP);
#ifdef USE_NVINFER_PLUGIN
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP);
#endif
void* GetDsoHandle(const std::string& dso_name) {
#if !defined(_WIN32)
@@ -75,6 +80,7 @@ void* GetTensorRtHandle() {
  return GetDsoHandle(dso_name);
}
#ifdef USE_NVINFER_PLUGIN
void* GetTensorRtPluginHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  std::string dso_name = "libnvinfer_plugin.dylib";
@@ -85,6 +91,7 @@ void* GetTensorRtPluginHandle() {
#endif
  return GetDsoHandle(dso_name);
}
#endif
}  // namespace dynload
}  // namespace platform
......
@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <NvInfer.h>
#ifdef USE_NVINFER_PLUGIN
#include <NvInferPlugin.h>
#endif
#if !defined(_WIN32)
#include <dlfcn.h>
#endif
@@ -29,13 +31,15 @@ namespace platform {
namespace dynload {
void* GetTensorRtHandle();
void* GetTensorRtPluginHandle();
extern std::once_flag tensorrt_dso_flag;
extern void* tensorrt_dso_handle;
#ifdef USE_NVINFER_PLUGIN
void* GetTensorRtPluginHandle();
extern std::once_flag tensorrt_plugin_dso_flag;
extern void* tensorrt_plugin_dso_handle;
#endif
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
  struct DynLoad__##__name { \
@@ -88,7 +92,9 @@ extern void* tensorrt_plugin_dso_handle;
  __macro(initLibNvInferPlugins);
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
#ifdef USE_NVINFER_PLUGIN
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP)
#endif
}  // namespace dynload
}  // namespace platform
......
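For context on the dynload half of the change: GetTensorRtPluginHandle() resolves libnvinfer_plugin (.dylib on macOS, .so elsewhere, per the hunk above), and DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP presumably generates a lazy wrapper so that dy::initLibNvInferPlugins(&logger, "") in engine.h only binds the symbol on first use. The exact macro body is not part of this diff; the sketch below assumes the usual dlopen/dlsym plus std::call_once scheme, with simplified names.

// plugin_dynload_sketch.cc -- an assumed simplification of the wrapper that
// DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP would generate for
// initLibNvInferPlugins. Not Paddle's actual macro expansion.
#include <dlfcn.h>
#include <mutex>

static std::once_flag plugin_dso_flag;
static void* plugin_dso_handle = nullptr;

// Analogue of GetTensorRtPluginHandle(): open libnvinfer_plugin exactly once.
static void* GetPluginDsoHandle() {
  std::call_once(plugin_dso_flag, [] {
    plugin_dso_handle = dlopen("libnvinfer_plugin.so", RTLD_NOW | RTLD_GLOBAL);
  });
  return plugin_dso_handle;
}

// Analogue of a DynLoad__initLibNvInferPlugins functor: resolve the symbol
// from the plugin DSO and forward the call; report failure if the DSO or the
// symbol is missing instead of crashing.
struct DynLoadInitLibNvInferPlugins {
  template <typename... Args>
  bool operator()(Args... args) {
    using FuncT = bool (*)(Args...);
    void* handle = GetPluginDsoHandle();
    if (handle == nullptr) return false;
    auto fn = reinterpret_cast<FuncT>(dlsym(handle, "initLibNvInferPlugins"));
    return fn != nullptr && fn(args...);
  }
};

int main() {
  // Usage mirroring engine.h: initLibNvInferPlugins(logger, "") registers the
  // bundled TensorRT plugins (including CustomGeluPluginDynamic), which is
  // what lets the getPluginRegistry() lookup in GeluOpConverter succeed.
  DynLoadInitLibNvInferPlugins init_plugins;
  bool ok = init_plugins(static_cast<void*>(nullptr), "");
  return ok ? 0 : 1;  // returns 1 if libnvinfer_plugin is not installed
}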