提交 8c366149 编写于 作者: Z zlsh80826

add USE_NVINFER_PLUGIN define detection

上级 bb6b8480
......@@ -47,6 +47,8 @@ class GeluOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
#ifdef USE_NVINFER_PLUGIN
auto creator =
getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1");
assert(creator != nullptr);
......@@ -68,6 +70,11 @@ class GeluOpConverter : public OpConverter {
creator->createPlugin("CustomGeluPluginDynamic", pluginPtr);
layer = engine_->network()->addPluginV2(&input, input_num, *pluginObj);
assert(layer != nullptr);
#else
plugin::GeluPluginDynamic* plugin = new plugin::GeluPluginDynamic();
layer = engine_->AddPluginV2(&input, input_num, plugin);
#endif
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
......
......@@ -157,8 +157,9 @@ class TensorRTEngine {
"version should be at least 6.";
#endif
}
#ifdef USE_NVINFER_PLUGIN
dy::initLibNvInferPlugins(&logger, "");
#endif
}
~TensorRTEngine() {}
......
......@@ -22,13 +22,18 @@ namespace dynload {
std::once_flag tensorrt_dso_flag;
void* tensorrt_dso_handle;
#ifdef USE_NVINFER_PLUGIN
std::once_flag tensorrt_plugin_dso_flag;
void* tensorrt_plugin_dso_handle;
#endif
#define DEFINE_WRAP(__name) DynLoad__##__name __name
TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP);
#ifdef USE_NVINFER_PLUGIN
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP);
#endif
void* GetDsoHandle(const std::string& dso_name) {
#if !defined(_WIN32)
......@@ -75,6 +80,7 @@ void* GetTensorRtHandle() {
return GetDsoHandle(dso_name);
}
#ifdef USE_NVINFER_PLUGIN
void* GetTensorRtPluginHandle() {
#if defined(__APPLE__) || defined(__OSX__)
std::string dso_name = "libnvinfer_plugin.dylib";
......@@ -85,6 +91,7 @@ void* GetTensorRtPluginHandle() {
#endif
return GetDsoHandle(dso_name);
}
#endif
} // namespace dynload
} // namespace platform
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <NvInfer.h>
#ifdef USE_NVINFER_PLUGIN
#include <NvInferPlugin.h>
#endif
#if !defined(_WIN32)
#include <dlfcn.h>
#endif
......@@ -29,13 +31,15 @@ namespace platform {
namespace dynload {
void* GetTensorRtHandle();
void* GetTensorRtPluginHandle();
extern std::once_flag tensorrt_dso_flag;
extern void* tensorrt_dso_handle;
#ifdef USE_NVINFER_PLUGIN
void* GetTensorRtPluginHandle();
extern std::once_flag tensorrt_plugin_dso_flag;
extern void* tensorrt_plugin_dso_handle;
#endif
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
......@@ -88,7 +92,9 @@ extern void* tensorrt_plugin_dso_handle;
__macro(initLibNvInferPlugins);
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
#ifdef USE_NVINFER_PLUGIN
TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP)
#endif
} // namespace dynload
} // namespace platform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册