未验证 提交 d7d35ff8 编写于 作者: J JingZhuangzhuang 提交者: GitHub

delay tensorrt registry (#45824)

* Delay TensorRT registry
* Add unused define
* Fix TensorRT test
* fix function to reference
* Update trt_plugin.h
上级 6891a4fe
......@@ -32,6 +32,7 @@
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
......@@ -117,6 +118,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);
static std::once_flag trt_plugin_registered;
std::call_once(trt_plugin_registered, []() {
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
});
auto model_precision =
static_cast<phi::DataType>(Get<int>("model_precision"));
if (model_precision == phi::DataType::BFLOAT16) {
......
......@@ -38,6 +38,13 @@ namespace inference {
namespace tensorrt {
namespace plugin {
#if defined(_WIN32)
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define UNUSED __attribute__((unused))
#endif
class PluginTensorRT;
typedef std::function<PluginTensorRT*(const void*, size_t)>
......@@ -372,6 +379,26 @@ class TensorRTPluginCreator : public nvinfer1::IPluginCreator {
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
class TrtPluginRegistry {
public:
static TrtPluginRegistry* Global() {
static TrtPluginRegistry registry;
return &registry;
}
bool Regist(const std::string& name, const std::function<void()>& func) {
map.emplace(name, func);
return true;
}
void RegistToTrt() {
for (auto& it : map) {
it.second();
}
}
private:
std::unordered_map<std::string, std::function<void()>> map;
};
template <typename T>
class TrtPluginRegistrarV2 {
public:
......@@ -386,9 +413,14 @@ class TrtPluginRegistrarV2 {
T creator;
};
#define REGISTER_TRT_PLUGIN_V2(name) \
#define REGISTER_TRT_PLUGIN_V2(name) REGISTER_TRT_PLUGIN_V2_HELPER(name)
#define REGISTER_TRT_PLUGIN_V2_HELPER(name) \
UNUSED static bool REGISTER_TRT_PLUGIN_V2_HELPER##name = \
TrtPluginRegistry::Global()->Regist(#name, []() -> void { \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
plugin_registrar_##name {}
plugin_registrar_##name{}; \
});
} // namespace plugin
} // namespace tensorrt
......
......@@ -284,6 +284,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt::plugin::TrtPluginRegistry::Global()->RegistToTrt();
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
auto *x = engine_->DeclareInput(
......
......@@ -41,21 +41,11 @@ void* GetDsoHandle(const std::string& dso_name) {
void* dso_handle = dlopen(dso_name.c_str(), dynload_flags);
if (nullptr == dso_handle) {
auto error_msg =
"You are using Paddle compiled with TensorRT, but TensorRT dynamic "
"library is not found. Ignore this if TensorRT is not needed.\n"
"The TensorRT that Paddle depends on is not configured correctly.\n"
" Suggestions:\n"
" 1. Check if the TensorRT is installed correctly and its version"
" is matched with paddlepaddle you installed.\n"
" 2. Configure environment variables as "
"follows:\n"
" - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n"
" - Windows: set PATH by `set PATH=XXX;%PATH%`\n"
" - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n";
LOG(WARNING) << error_msg;
}
PADDLE_ENFORCE_NOT_NULL(dso_handle,
paddle::platform::errors::NotFound(
"TensorRT is needed, "
"but TensorRT dynamic library is not found."));
return dso_handle;
}
......
......@@ -40,21 +40,10 @@ void* GetDsoHandle(const std::string& dso_name) {
void* dso_handle = dlopen(dso_name.c_str(), dynload_flags);
if (nullptr == dso_handle) {
auto error_msg =
"You are using Paddle compiled with TensorRT, but TensorRT dynamic "
"library is not found. Ignore this if TensorRT is not needed.\n"
"The TensorRT that Paddle depends on is not configured correctly.\n"
" Suggestions:\n"
" 1. Check if the TensorRT is installed correctly and its version"
" is matched with paddlepaddle you installed.\n"
" 2. Configure environment variables as "
"follows:\n"
" - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n"
" - Windows: set PATH by `set PATH=XXX;%PATH%`\n"
" - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...`\n";
LOG(WARNING) << error_msg;
}
PADDLE_ENFORCE_NOT_NULL(dso_handle,
paddle::platform::errors::NotFound(
"TensorRT is needed, "
"but TensorRT dynamic library is not found."));
return dso_handle;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册