From 179d4264cae53014c256cde6e1eeb8f29a10c2f3 Mon Sep 17 00:00:00 2001
From: ming1753 <61511741+ming1753@users.noreply.github.com>
Date: Mon, 4 Sep 2023 13:09:47 +0800
Subject: [PATCH] Modify MarkTrtEngineOutputs API (#56858)

* Modify MarkTrtEngineOutputs API
---
 paddle/fluid/inference/analysis/argument.h    |  1 +
 .../inference/analysis/ir_pass_manager.cc     |  2 ++
 .../ir_passes/tensorrt_subgraph_pass.cc       | 25 ++++++++++---------
 paddle/fluid/inference/api/analysis_config.cc |  5 +++-
 .../fluid/inference/api/analysis_predictor.cc |  1 +
 .../inference/api/paddle_analysis_config.h    |  4 ++-
 paddle/fluid/pybind/inference_api.cc          |  3 ++-
 .../api/trt_mark_trt_engine_outputs_test.cc   |  8 +++---
 8 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index b3757886e2f..73bd1cb5e6c 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -241,6 +241,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int64_t);
   DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
   DECL_ARGUMENT_FIELD(trt_mark_output, TRTMarkOutput, bool);
+  DECL_ARGUMENT_FIELD(trt_mark_output_with_id, TRTMarkOutputWithId, bool);
   DECL_ARGUMENT_FIELD(trt_output_tensor_names,
                       TRTOutputTensorNames,
                       std::vector<std::string>);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 47091b347c6..ab3f1de01bd 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -163,6 +163,8 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("min_subgraph_size",
                 new int(argument->tensorrt_min_subgraph_size()));
       pass->Set("mark_output", new bool(argument->trt_mark_output()));
+      pass->Set("mark_output_with_id",
+                new bool(argument->trt_mark_output_with_id()));
       pass->Set(
           "output_tensor_names",
           new std::vector<std::string>(argument->trt_output_tensor_names()));
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index e65aff11180..1fc24afe494 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -376,29 +376,30 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
   std::vector<int> origin_outputs_dtype;
   std::map<std::string, int> map_origin_outputs_dtype;
 
-  // Whether to mark Outpus
+  // Mark TensorRT output nodes as trt outputs
   auto mark_output = Get<bool>("mark_output");
   auto output_tensor_name =
       Get<std::vector<std::string>>("output_tensor_names");
-  VLOG(1) << "mark Output: " << mark_output;
+  auto mark_output_with_id = Get<bool>("mark_output_with_id");
 
-  if (mark_output == 1) {
+  if (mark_output) {
     VLOG(1) << "begin to mark output ...";
     for (auto node : subgraph) {
       if (node->NodeType() == Node::Type::kOperation) {
-        if (node->Op()->Outputs().count("Xshape")) continue;
         for (auto *x : node->outputs) {
           if (std::count(parameters.begin(), parameters.end(), x->Name()) > 0)
             continue;
-          if (!output_tensor_name.empty() &&
-              std::count(output_tensor_name.begin(),
-                         output_tensor_name.end(),
-                         x->Name())) {
-            VLOG(1) << "output " << x->Name() << " has been marked";
-            std::string output_name_withid =
-                x->Name() + std::to_string(x->id());
+          std::string name_with_id = x->Name() + std::to_string(x->id());
+          if (((!mark_output_with_id && std::count(output_tensor_name.begin(),
+                                                   output_tensor_name.end(),
+                                                   x->Name()) > 0) ||
+               (mark_output_with_id && std::count(output_tensor_name.begin(),
+                                                  output_tensor_name.end(),
+                                                  name_with_id) > 0)) &&
+              !x->outputs.empty()) {
+            VLOG(3) << "output " << x->Name() << " has been marked";
             output_names.insert(x->Name());
-            output_names_with_id.insert(output_name_withid);
+            output_names_with_id.insert(name_with_id);
             origin_name_output_rank[x->Name()] = x->Var()->GetShape().size();
             trt_outputs.insert(x);
             map_origin_outputs_dtype[x->Name()] =
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 3f9ca0a58ed..efd450f2bf5 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -461,6 +461,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(tensorrt_min_subgraph_size_);
   CP_MEMBER(tensorrt_precision_mode_);
   CP_MEMBER(trt_mark_output_);
+  CP_MEMBER(trt_mark_output_with_id_);
   CP_MEMBER(trt_output_tensor_names_);
   CP_MEMBER(trt_disabled_ops_);
   CP_MEMBER(trt_use_dla_);
@@ -762,8 +763,10 @@ void AnalysisConfig::EnableTensorRtEngine(int64_t workspace_size,
 }
 
 void AnalysisConfig::MarkTrtEngineOutputs(
-    const std::vector<std::string> &output_tensor_names) {
+    const std::vector<std::string> &output_tensor_names,
+    const bool mark_output_with_id) {
   trt_mark_output_ = true;
+  trt_mark_output_with_id_ = mark_output_with_id;
   trt_output_tensor_names_ = output_tensor_names;
 }
 
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 1fb7e2c1571..83f75c1ae07 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1392,6 +1392,7 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
     argument_->SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
     argument_->SetTRTMarkOutput(config_.trt_mark_output_);
+    argument_->SetTRTMarkOutputWithId(config_.trt_mark_output_with_id_);
     argument_->SetTRTOutputTensorNames(config_.trt_output_tensor_names_);
     argument_->SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
     argument_->SetTensorRtUseDLA(config_.trt_use_dla_);
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index f1d193d0640..7348418d6e5 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -695,7 +695,8 @@ struct PD_INFER_DECL AnalysisConfig {
   /// \param output_tensor_names The name of the Tensor that needs to be marked
   ///
   void MarkTrtEngineOutputs(
-      const std::vector<std::string>& output_tensor_names = {});
+      const std::vector<std::string>& output_tensor_names = {},
+      const bool trt_mark_output_with_id = false);
   ///
   /// \brief Turn on the TensorRT memory optimization.
   ///
@@ -1237,6 +1238,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool trt_use_varseqlen_{false};
   bool trt_with_interleaved_{false};
   bool trt_mark_output_{false};
+  bool trt_mark_output_with_id_{false};
   std::vector<std::string> trt_output_tensor_names_{};
   std::string tensorrt_transformer_posid_{""};
   std::string tensorrt_transformer_maskid_{""};
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 16131ad12ea..1690d738a2c 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -896,7 +896,8 @@ void BindAnalysisConfig(py::module *m) {
            &AnalysisConfig::tensorrt_dynamic_shape_enabled)
       .def("mark_trt_engine_outputs",
            &AnalysisConfig::MarkTrtEngineOutputs,
-           py::arg("output_tensor_names") = std::vector<std::string>({}))
+           py::arg("output_tensor_names") = std::vector<std::string>({}),
+           py::arg("mark_output_with_id") = false)
       .def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen)
       .def("tensorrt_varseqlen_enabled",
            &AnalysisConfig::tensorrt_varseqlen_enabled)
diff --git a/test/cpp/inference/api/trt_mark_trt_engine_outputs_test.cc b/test/cpp/inference/api/trt_mark_trt_engine_outputs_test.cc
index d34d640cfaf..7157f442ae6 100644
--- a/test/cpp/inference/api/trt_mark_trt_engine_outputs_test.cc
+++ b/test/cpp/inference/api/trt_mark_trt_engine_outputs_test.cc
@@ -24,9 +24,11 @@ TEST(TensorRT, mark_trt_engine_outputs) {
   config.EnableUseGpu(100, 0);
   config.EnableTensorRtEngine(
       1 << 30, 1, 5, AnalysisConfig::Precision::kFloat32, false, false);
-  // The name of the tensor that needs to be marked, the default is empty (all
-  // marks)
-  std::vector<std::string> markOutput = {"fc_0.tmp_0", "fc_0.tmp_1"};
+  // The name of the tensor that needs to be marked
+  std::vector<std::string> markOutput = {"pool2d_0.tmp_0",
+                                         "elementwise_add_0.tmp_0",
+                                         "conv2d_5.tmp_0",
+                                         "batch_norm_6.tmp_2"};
  config.MarkTrtEngineOutputs(markOutput);
 
   std::vector<std::vector<PaddleTensor>> inputs_all;
-- 
GitLab
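
For context, a minimal sketch of how the modified MarkTrtEngineOutputs API can be
called after this patch. The model directory and tensor names are illustrative
placeholders, not values taken from this change; the only confirmed behavior is
the new optional second parameter, which switches matching from plain tensor
names to "<name><node id>" strings as implemented in tensorrt_subgraph_pass.cc.

    #include "paddle_inference_api.h"  // installed Paddle Inference header

    int main() {
      paddle::AnalysisConfig config;
      config.SetModel("./model_dir");  // placeholder model directory
      config.EnableUseGpu(100, 0);
      config.EnableTensorRtEngine(
          1 << 30, 1, 5, paddle::AnalysisConfig::Precision::kFloat32,
          false, false);

      // Default behavior: marked tensors are matched by name only.
      config.MarkTrtEngineOutputs({"conv2d_5.tmp_0"});

      // New behavior: match against "<name><node id>" so one specific graph
      // node can be targeted. "conv2d_5.tmp_042" is a hypothetical
      // name-with-id; the suffix comes from the node's id() in the graph.
      // Note this second call replaces the name list set above.
      config.MarkTrtEngineOutputs({"conv2d_5.tmp_042"},
                                  /*mark_output_with_id=*/true);
      return 0;
    }

The same switch is exposed to Python through the updated pybind signature,
e.g. config.mark_trt_engine_outputs(output_tensor_names=[...],
mark_output_with_id=True).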