diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index cd79b3fcde0ef8eb8627f051fbce05f4fc3b6f8d..60e0864a9be43f5f8cb25068e22e6b2e239ac7f7 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -32,6 +32,7 @@ #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/op_teller.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/utils/io_utils.h" #include "paddle/phi/common/backend.h" #include "paddle/phi/common/data_type.h" @@ -117,6 +118,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( framework::ir::Graph *graph) const { framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph); + static std::once_flag trt_plugin_registered; + std::call_once(trt_plugin_registered, []() { + tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt(); + }); + auto model_precision = static_cast(Get("model_precision")); if (model_precision == phi::DataType::BFLOAT16) { diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index 8c105230d27d4f7b85c6d695c9a485dc1a7288f6..f08a8a75ba4067838642dfc1763db86e2a5ea138 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -38,6 +38,13 @@ namespace inference { namespace tensorrt { namespace plugin { +#if defined(_WIN32) +#define UNUSED +#define __builtin_expect(EXP, C) (EXP) +#else +#define UNUSED __attribute__((unused)) +#endif + class PluginTensorRT; typedef std::function @@ -372,6 +379,26 @@ class TensorRTPluginCreator : public nvinfer1::IPluginCreator { std::vector plugin_attributes_; }; +class TrtPluginRegistry { + public: + static TrtPluginRegistry* 
Global() { + static TrtPluginRegistry registry; + return &registry; + } + bool Regist(const std::string& name, const std::function<void()>& func) { + map.emplace(name, func); + return true; + } + void RegistToTrt() { + for (auto& it : map) { + it.second(); + } + } + + private: + std::unordered_map<std::string, std::function<void()>> map; +}; + template <typename T> class TrtPluginRegistrarV2 { public: @@ -386,9 +413,14 @@ class TrtPluginRegistrarV2 { T creator; }; -#define REGISTER_TRT_PLUGIN_V2(name) \ - static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \ - plugin_registrar_##name {} +#define REGISTER_TRT_PLUGIN_V2(name) REGISTER_TRT_PLUGIN_V2_HELPER(name) + +#define REGISTER_TRT_PLUGIN_V2_HELPER(name) \ + UNUSED static bool REGISTER_TRT_PLUGIN_V2_HELPER##name = \ + TrtPluginRegistry::Global()->Regist(#name, []() -> void { \ + static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \ + plugin_registrar_##name{}; \ + }); } // namespace plugin } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc index 6ac23e32856becf88e1a238d255ad2bdee272316..43e219232d111ee429a63ac24f5ef9d86c61b676 100644 --- a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -284,6 +284,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test { TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { #if IS_TRT_VERSION_GE(8000) + tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt(); auto *attn = engine_->DeclareInput( "attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4}); auto *x = engine_->DeclareInput( diff --git a/paddle/fluid/platform/dynload/tensorrt.cc b/paddle/fluid/platform/dynload/tensorrt.cc index 8d700faac0c1485860cfb49f26a1ae37b155357f..d3f3eeadee1eba593576997498e7a2475a81fb9d 100644 --- a/paddle/fluid/platform/dynload/tensorrt.cc +++ b/paddle/fluid/platform/dynload/tensorrt.cc @@ -41,21 +41,11 @@ void*
GetDsoHandle(const std::string& dso_name) { void* dso_handle = dlopen(dso_name.c_str(), dynload_flags); - if (nullptr == dso_handle) { - auto error_msg = - "You are using Paddle compiled with TensorRT, but TensorRT dynamic " - "library is not found. Ignore this if TensorRT is not needed.\n" - "The TensorRT that Paddle depends on is not configured correctly.\n" - " Suggestions:\n" - " 1. Check if the TensorRT is installed correctly and its version" - " is matched with paddlepaddle you installed.\n" - " 2. Configure environment variables as " - "follows:\n" - " - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n" - " - Windows: set PATH by `set PATH=XXX;%PATH%`\n" - " - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n"; - LOG(WARNING) << error_msg; - } + PADDLE_ENFORCE_NOT_NULL(dso_handle, + paddle::platform::errors::NotFound( + "TensorRT is needed, " + "but TensorRT dynamic library is not found.")); + return dso_handle; } diff --git a/paddle/phi/backends/dynload/tensorrt.cc b/paddle/phi/backends/dynload/tensorrt.cc index 45525701020250f8e1b060613284cf98d1a0bbf6..2e2319a47cc542d37e492bcdb7189d72fb8dfa06 100644 --- a/paddle/phi/backends/dynload/tensorrt.cc +++ b/paddle/phi/backends/dynload/tensorrt.cc @@ -40,21 +40,10 @@ void* GetDsoHandle(const std::string& dso_name) { void* dso_handle = dlopen(dso_name.c_str(), dynload_flags); - if (nullptr == dso_handle) { - auto error_msg = - "You are using Paddle compiled with TensorRT, but TensorRT dynamic " - "library is not found. Ignore this if TensorRT is not needed.\n" - "The TensorRT that Paddle depends on is not configured correctly.\n" - " Suggestions:\n" - " 1. Check if the TensorRT is installed correctly and its version" - " is matched with paddlepaddle you installed.\n" - " 2. 
Configure environment variables as " - "follows:\n" - " - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n" - " - Windows: set PATH by `set PATH=XXX;%PATH%`\n" - " - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n"; - LOG(WARNING) << error_msg; - } + PADDLE_ENFORCE_NOT_NULL(dso_handle, + paddle::platform::errors::NotFound( + "TensorRT is needed, " + "but TensorRT dynamic library is not found.")); return dso_handle; }