未验证 提交 2ca65904 编写于 作者: J JingZhuangzhuang 提交者: GitHub

cherry pick delay tensorrt log (#45958)

* cherry pick delay tensorrt log
* Update trt_plugin.h
上级 2fac8abb
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h" #include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/io_utils.h" #include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/common/backend.h" #include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
...@@ -117,6 +118,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -117,6 +118,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const { framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph); framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);
static std::once_flag trt_plugin_registered;
std::call_once(trt_plugin_registered, []() {
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
});
auto model_precision = auto model_precision =
static_cast<phi::DataType>(Get<int>("model_precision")); static_cast<phi::DataType>(Get<int>("model_precision"));
if (model_precision == phi::DataType::BFLOAT16) { if (model_precision == phi::DataType::BFLOAT16) {
......
...@@ -38,6 +38,13 @@ namespace inference { ...@@ -38,6 +38,13 @@ namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
// Portability shims for compiler-specific extensions.
// UNUSED is placed in front of the static bool declared by
// REGISTER_TRT_PLUGIN_V2_HELPER.
#if defined(_WIN32)
// On Windows builds UNUSED expands to nothing.
#define UNUSED
// On Windows builds __builtin_expect(EXP, C) is reduced to the bare
// expression (EXP); the second argument C is discarded.
#define __builtin_expect(EXP, C) (EXP)
#else
// On every other platform UNUSED expands to the `unused` attribute.
#define UNUSED __attribute__((unused))
#endif
class PluginTensorRT; class PluginTensorRT;
typedef std::function<PluginTensorRT*(const void*, size_t)> typedef std::function<PluginTensorRT*(const void*, size_t)>
...@@ -372,6 +379,26 @@ class TensorRTPluginCreator : public nvinfer1::IPluginCreator { ...@@ -372,6 +379,26 @@ class TensorRTPluginCreator : public nvinfer1::IPluginCreator {
std::vector<nvinfer1::PluginField> plugin_attributes_; std::vector<nvinfer1::PluginField> plugin_attributes_;
}; };
// Table of named, deferred registration callbacks.
//
// Regist() only stores a callback under a string key; nothing is executed
// until RegistToTrt() is called, which runs every stored callback.
class TrtPluginRegistry {
 public:
  // Returns a pointer to the single function-local static registry instance.
  static TrtPluginRegistry* Global() {
    static TrtPluginRegistry instance;
    return &instance;
  }

  // Stores `func` under `name`.
  //
  // If a callback is already stored under `name`, the existing one is kept
  // and `func` is dropped. The return value is always true, whether or not a
  // new entry was added.
  bool Regist(const std::string& name, const std::function<void()>& func) {
    const bool already_known = callbacks_.find(name) != callbacks_.end();
    if (!already_known) {
      callbacks_.emplace(name, func);
    }
    return true;
  }

  // Invokes every stored callback once per call. Entries are not removed, so
  // calling this again runs all of them again.
  void RegistToTrt() {
    for (auto entry = callbacks_.begin(); entry != callbacks_.end(); ++entry) {
      entry->second();
    }
  }

 private:
  // Callback storage, keyed by the name passed to Regist().
  std::unordered_map<std::string, std::function<void()>> callbacks_;
};
template <typename T> template <typename T>
class TrtPluginRegistrarV2 { class TrtPluginRegistrarV2 {
public: public:
...@@ -386,9 +413,14 @@ class TrtPluginRegistrarV2 { ...@@ -386,9 +413,14 @@ class TrtPluginRegistrarV2 {
T creator; T creator;
}; };
#define REGISTER_TRT_PLUGIN_V2(name) \ #define REGISTER_TRT_PLUGIN_V2(name) REGISTER_TRT_PLUGIN_V2_HELPER(name)
#define REGISTER_TRT_PLUGIN_V2_HELPER(name) \
UNUSED static bool REGISTER_TRT_PLUGIN_V2_HELPER##name = \
TrtPluginRegistry::Global()->Regist(#name, []() -> void { \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \ static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name {} plugin_registrar_##name{}; \
});
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
......
...@@ -284,6 +284,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test { ...@@ -284,6 +284,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000) #if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput( auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4}); "attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
auto *x = engine_->DeclareInput( auto *x = engine_->DeclareInput(
......
...@@ -41,21 +41,11 @@ void* GetDsoHandle(const std::string& dso_name) { ...@@ -41,21 +41,11 @@ void* GetDsoHandle(const std::string& dso_name) {
void* dso_handle = dlopen(dso_name.c_str(), dynload_flags); void* dso_handle = dlopen(dso_name.c_str(), dynload_flags);
if (nullptr == dso_handle) { PADDLE_ENFORCE_NOT_NULL(dso_handle,
auto error_msg = paddle::platform::errors::NotFound(
"You are using Paddle compiled with TensorRT, but TensorRT dynamic " "TensorRT is needed, "
"library is not found. Ignore this if TensorRT is not needed.\n" "but TensorRT dynamic library is not found."));
"The TensorRT that Paddle depends on is not configured correctly.\n"
" Suggestions:\n"
" 1. Check if the TensorRT is installed correctly and its version"
" is matched with paddlepaddle you installed.\n"
" 2. Configure environment variables as "
"follows:\n"
" - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n"
" - Windows: set PATH by `set PATH=XXX;%PATH%`\n"
" - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n";
LOG(WARNING) << error_msg;
}
return dso_handle; return dso_handle;
} }
......
...@@ -40,21 +40,10 @@ void* GetDsoHandle(const std::string& dso_name) { ...@@ -40,21 +40,10 @@ void* GetDsoHandle(const std::string& dso_name) {
void* dso_handle = dlopen(dso_name.c_str(), dynload_flags); void* dso_handle = dlopen(dso_name.c_str(), dynload_flags);
if (nullptr == dso_handle) { PADDLE_ENFORCE_NOT_NULL(dso_handle,
auto error_msg = paddle::platform::errors::NotFound(
"You are using Paddle compiled with TensorRT, but TensorRT dynamic " "TensorRT is needed, "
"library is not found. Ignore this if TensorRT is not needed.\n" "but TensorRT dynamic library is not found."));
"The TensorRT that Paddle depends on is not configured correctly.\n"
" Suggestions:\n"
" 1. Check if the TensorRT is installed correctly and its version"
" is matched with paddlepaddle you installed.\n"
" 2. Configure environment variables as "
"follows:\n"
" - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n"
" - Windows: set PATH by `set PATH=XXX;%PATH%`\n"
" - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n";
LOG(WARNING) << error_msg;
}
return dso_handle; return dso_handle;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册