提交 591d2c67 编写于 作者: R root

according to version of develop, modify some tensorrt error in v1.8

上级 5260e473
...@@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) { ...@@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IRuntime*>( return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION)); dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
} }
static nvinfer1::IPluginRegistry* getPluginRegistry() { #if IS_TRT_VERSION_GE(6000)
static nvinfer1::IPluginRegistry* GetPluginRegistry() {
return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry()); return static_cast<nvinfer1::IPluginRegistry*>(dy::getPluginRegistry());
} }
#endif
// A logger for create TensorRT infer builder. // A logger for create TensorRT infer builder.
class NaiveLogger : public nvinfer1::ILogger { class NaiveLogger : public nvinfer1::ILogger {
......
...@@ -178,7 +178,6 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt { ...@@ -178,7 +178,6 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
std::string name_space_; std::string name_space_;
std::string plugin_base_; std::string plugin_base_;
}; };
#endif
template <typename T> template <typename T>
class TrtPluginRegistrarV2 { class TrtPluginRegistrarV2 {
...@@ -197,6 +196,7 @@ class TrtPluginRegistrarV2 { ...@@ -197,6 +196,7 @@ class TrtPluginRegistrarV2 {
#define REGISTER_TRT_PLUGIN_V2(name) \ #define REGISTER_TRT_PLUGIN_V2(name) \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \ static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name {} plugin_registrar_##name {}
#endif
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
......
...@@ -36,7 +36,6 @@ extern void* tensorrt_dso_handle; ...@@ -36,7 +36,6 @@ extern void* tensorrt_dso_handle;
struct DynLoad__##__name { \ struct DynLoad__##__name { \
template <typename... Args> \ template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_func = decltype(&::__name); \
std::call_once(tensorrt_dso_flag, []() { \ std::call_once(tensorrt_dso_flag, []() { \
tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \ tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \
}); \ }); \
...@@ -44,15 +43,22 @@ extern void* tensorrt_dso_handle; ...@@ -44,15 +43,22 @@ extern void* tensorrt_dso_handle;
if (p_##__name == nullptr) { \ if (p_##__name == nullptr) { \
return nullptr; \ return nullptr; \
} \ } \
using tensorrt_func = decltype(&::__name); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \ return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
} \ } \
}; \ }; \
extern DynLoad__##__name __name extern DynLoad__##__name __name
#if (NV_TENSORRT_MAJOR >= 6)
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \ #define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \ __macro(createInferBuilder_INTERNAL); \
__macro(createInferRuntime_INTERNAL); \ __macro(createInferRuntime_INTERNAL); \
__macro(getPluginRegistry); __macro(getPluginRegistry);
#else
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \
__macro(createInferRuntime_INTERNAL);
#endif
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册