Unverified commit 5a69ddb9, authored by Zhang Jun, committed by GitHub

[cherrypick][inference][trt]remove trt sparse weights flags (#53562) (#53850)

* remove kSPARSE_WEIGHTS

* remove kFASTER_DYNAMIC_SHAPES_0805 and add 'TrtMajorVersion' function
Parent: 2992f787
@@ -556,17 +556,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
     opt_input_shape = {};
   }
-  auto to_major_version = [&](int full_version) -> float {
-    return (full_version / 100) / 10.0;
-  };
-  const float compile_time_trt_version = to_major_version(TRT_VERSION);
-  const float run_time_trt_version =
-      to_major_version(tensorrt::GetInferLibVersion());
-  if (compile_time_trt_version != run_time_trt_version) {
+  const float trt_compile_version = tensorrt::TrtMajorVersion(TRT_VERSION);
+  const float trt_runtime_version =
+      tensorrt::TrtMajorVersion(tensorrt::GetInferLibVersion());
+  if (trt_compile_version != trt_runtime_version) {
     LOG_FIRST_N(WARNING, 1)
         << "The Paddle Inference library is compiled with "
-        << compile_time_trt_version << " version TensorRT, "
-        << "but the runtime TensorRT you are using is " << run_time_trt_version
+        << trt_compile_version << " version TensorRT, "
+        << "but the runtime TensorRT you are using is " << trt_runtime_version
         << " version. "
            "This might cause serious compatibility issues. We strongly "
            "recommend using the same TRT version at runtime.";
......
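To illustrate the refactored check, here is a minimal, self-contained sketch; the version literals (8601 for a TensorRT 8.6.1 build, 8501 for an 8.5.1 runtime) are assumptions for the example, not values taken from this PR:

```cpp
// Standalone sketch of the version check above, with assumed inputs.
#include <iostream>

// Same arithmetic as the TrtMajorVersion helper added in this PR.
static float TrtMajorVersion(int full_version) {
  return (full_version / 100) / 10.0;
}

int main() {
  const float trt_compile_version = TrtMajorVersion(8601);  // -> 8.6
  const float trt_runtime_version = TrtMajorVersion(8501);  // -> 8.5
  // Float equality is safe here: both sides come from the same
  // integer-truncating expression, so equal majors compare equal.
  if (trt_compile_version != trt_runtime_version) {
    std::cout << "Compiled against TRT " << trt_compile_version
              << ", running TRT " << trt_runtime_version
              << "; versions should match.\n";
  }
  return 0;
}
```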
@@ -158,12 +158,6 @@ void TensorRTEngine::FreezeNetwork() {
   infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
 #endif
-#if IS_TRT_VERSION_GE(8500)
-  infer_builder_config_->setPreviewFeature(
-      nvinfer1::PreviewFeature::kFASTER_DYNAMIC_SHAPES_0805, true);
-#else
-#endif
-
   bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
   if (enable_fp16) {
     bool support_fp16 = infer_builder_->platformHasFastFp16();
@@ -325,7 +319,6 @@ void TensorRTEngine::FreezeNetwork() {
   infer_engine_.reset(infer_builder_->buildEngineWithConfig(
       *network(), *infer_builder_config_));
 #else
-  infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
   ihost_memory_.reset(infer_builder_->buildSerializedNetwork(
       *network(), *infer_builder_config_));
   infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
......
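For reference, both deleted settings are plain TensorRT builder-config APIs. Below is a hedged sketch of what the removed lines did, written outside Paddle; `ConfigureBuilder` is a hypothetical helper, and the preprocessor guard mirrors `IS_TRT_VERSION_GE(8500)`:

```cpp
// Sketch of the two builder settings this commit removes, shown
// against the raw TensorRT API. ConfigureBuilder is hypothetical.
#include <NvInfer.h>

void ConfigureBuilder(nvinfer1::IBuilderConfig* config) {
#if NV_TENSORRT_MAJOR * 100 + NV_TENSORRT_MINOR >= 805
  // Removed: opted every engine into the TRT 8.5 preview path for
  // dynamic shapes (reported to be on by default in later releases).
  config->setPreviewFeature(
      nvinfer1::PreviewFeature::kFASTER_DYNAMIC_SHAPES_0805, true);
#endif
  // Removed: asked the builder to try 2:4 structured-sparsity tactics
  // for all weights, which can lengthen engine build time.
  config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
}
```

Dropping both settings leaves the defaults of whichever TensorRT the user links against, which appears to be the intent of the cleanup.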
@@ -96,6 +96,10 @@ static std::tuple<int, int, int> GetTrtCompileVersion() {
       NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH};
 }
 
+static float TrtMajorVersion(int full_version) {
+  return (full_version / 100) / 10.0;
+}
+
 template <typename T>
 struct Destroyer {
   void operator()(T* x) {
......
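A worked example of the new helper's arithmetic, with assumed inputs in the usual major*1000 + minor*100 + … version encoding: integer division by 100 discards the trailing patch digits, and dividing by 10.0 restores the decimal point between major and minor, so one helper serves both the compile-time TRT_VERSION and the value returned by GetInferLibVersion().

```cpp
// Worked examples of TrtMajorVersion; the inputs are assumed encodings.
#include <cstdio>

static float TrtMajorVersion(int full_version) {
  return (full_version / 100) / 10.0;  // e.g. 8601 -> 86 -> 8.6
}

int main() {
  std::printf("%.1f\n", TrtMajorVersion(8601));  // 8.6 (e.g. TRT 8.6.1)
  std::printf("%.1f\n", TrtMajorVersion(7103));  // 7.1 (e.g. TRT 7.1.3)
  return 0;
}
```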