From ebf689197d61af28110fa6b45e91527c47f68076 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Thu, 3 Dec 2020 16:57:38 +0800 Subject: [PATCH] fix tensorrt output shape error (#29308) * fix tensorrt output shape error * fix unittest tensorrt_engine_op_test * fix code style for unitest --- .../inference/analysis/ir_passes/tensorrt_subgraph_pass.cc | 3 +++ paddle/fluid/operators/tensorrt/tensorrt_engine_op.h | 7 ++++++- paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc | 2 ++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index bf0d87da91f..158c834c256 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -151,9 +151,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( std::set output_names; std::set output_names_with_id; + std::vector origin_output_dims; for (auto *x : node->outputs) { output_names.insert(x->Name()); output_names_with_id.insert(x->Name() + std::to_string(x->id())); + origin_output_dims.push_back(x->Var()->GetShape().size()); } std::unordered_map output_name_map; @@ -224,6 +226,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( op_desc->SetAttr("workspace_size", Get("workspace_size")); op_desc->SetAttr("gpu_id", Get("gpu_device_id")); op_desc->SetAttr("output_name_mapping", output_mapping); + op_desc->SetAttr("origin_output_dims", origin_output_dims); op_desc->SetAttr("parameters", params); // we record all inputs' shapes in attr to check if they are consistent diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 792737865ba..b8805c025a7 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -288,6 +288,8 @@ class TensorRTEngineOp : public framework::OperatorBase { // Bind output tensor to TRT. int output_index = 0; + std::vector origin_output_dims = + Attr>("origin_output_dims"); VLOG(4) << "TensorRT Engine Op Outputs:"; for (const auto &y : Outputs("Ys")) { const int bind_index = @@ -306,7 +308,10 @@ class TensorRTEngineOp : public framework::OperatorBase { auto dims = trt_context->getBindingDimensions(bind_index); int nb_dims = dims.nbDims; for (; nb_dims > 0; nb_dims--) { - if (dims.d[nb_dims - 1] != 1) break; + // some 'x 1' of shape is normal, no need to remove it + if (dims.d[nb_dims - 1] != 1 || + nb_dims == origin_output_dims[output_index]) + break; } for (int i = 0; i < nb_dims; i++) ddim.push_back(dims.d[i]); #endif diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index e813e9ca757..1dcaccd6e92 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -109,6 +109,7 @@ TEST(TensorRTEngineOp, manual) { engine_op_desc.SetAttr("use_calib_mode", static_cast(false)); engine_op_desc.SetAttr("output_name_mapping", std::vector({"z0"})); + engine_op_desc.SetAttr("origin_output_dims", std::vector({2})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); int device_id = 0; @@ -210,6 +211,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { engine_op_desc.SetAttr("use_calib_mode", static_cast(false)); engine_op_desc.SetAttr("output_name_mapping", std::vector({"z3"})); + engine_op_desc.SetAttr("origin_output_dims", std::vector({2})); engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString())); engine_op_desc.SetAttr("engine_serialized_data", std::string("")); int device_id = 0; -- GitLab