diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index bf0d87da91f534ef8470636448a698074485be55..158c834c256f59bbb89c2938b268d077181285eb 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -151,9 +151,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
 
   std::set<std::string> output_names;
   std::set<std::string> output_names_with_id;
+  std::vector<int> origin_output_dims;
   for (auto *x : node->outputs) {
     output_names.insert(x->Name());
     output_names_with_id.insert(x->Name() + std::to_string(x->id()));
+    origin_output_dims.push_back(x->Var()->GetShape().size());
   }
 
   std::unordered_map<std::string, std::string> output_name_map;
@@ -224,6 +226,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   op_desc->SetAttr("workspace_size", Get<int>("workspace_size"));
   op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id"));
   op_desc->SetAttr("output_name_mapping", output_mapping);
+  op_desc->SetAttr("origin_output_dims", origin_output_dims);
   op_desc->SetAttr("parameters", params);
 
   // we record all inputs' shapes in attr to check if they are consistent
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
index 792737865ba179666cbf2f1a012698246fe91367..b8805c025a768e3f0d4565d08e176dfe904b42ef 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
@@ -288,6 +288,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
 
     // Bind output tensor to TRT.
     int output_index = 0;
+    std::vector<int> origin_output_dims =
+        Attr<std::vector<int>>("origin_output_dims");
     VLOG(4) << "TensorRT Engine Op Outputs:";
     for (const auto &y : Outputs("Ys")) {
       const int bind_index =
@@ -306,7 +308,10 @@
       auto dims = trt_context->getBindingDimensions(bind_index);
       int nb_dims = dims.nbDims;
       for (; nb_dims > 0; nb_dims--) {
-        if (dims.d[nb_dims - 1] != 1) break;
+        // some 'x 1' of shape is normal, no need to remove it
+        if (dims.d[nb_dims - 1] != 1 ||
+            nb_dims == origin_output_dims[output_index])
+          break;
       }
       for (int i = 0; i < nb_dims; i++) ddim.push_back(dims.d[i]);
 #endif
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
index e813e9ca7579f154b91db851c70286e5f4405820..1dcaccd6e926411c37acbaa0f58d0b3eb1438f3a 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc
@@ -109,6 +109,7 @@ TEST(TensorRTEngineOp, manual) {
   engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
   engine_op_desc.SetAttr("output_name_mapping",
                          std::vector<std::string>({"z0"}));
+  engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
   engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
   engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
   int device_id = 0;
@@ -210,6 +211,7 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
   engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
   engine_op_desc.SetAttr("output_name_mapping",
                          std::vector<std::string>({"z3"}));
+  engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
   engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
   engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
   int device_id = 0;
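For reference, below is a minimal standalone sketch (not Paddle code) of the trimming rule this patch changes. TensorRT can report a binding with extra trailing dimensions of size 1; the engine op previously stripped every trailing 1, which could also drop a 1-sized dimension that genuinely belonged to the original graph output. Recording origin_output_dims lets the op stop trimming once the original output rank is reached. The helper name TrimTrailingOnes and the main() driver are hypothetical, for illustration only.

// Standalone sketch of the trailing-dimension trimming logic; the real code
// lives in TensorRTEngineOp and reads origin_output_dims from an op attribute.
#include <cassert>
#include <vector>

// Drop trailing 1-sized dims reported by TensorRT, but never trim below the
// rank the output had in the original graph (origin_rank).
std::vector<int> TrimTrailingOnes(const std::vector<int> &trt_dims,
                                  int origin_rank) {
  int nb_dims = static_cast<int>(trt_dims.size());
  for (; nb_dims > 0; nb_dims--) {
    // some 'x 1' of shape is normal, no need to remove it
    if (trt_dims[nb_dims - 1] != 1 || nb_dims == origin_rank) break;
  }
  return std::vector<int>(trt_dims.begin(), trt_dims.begin() + nb_dims);
}

int main() {
  // The old behaviour would strip both trailing 1s and return {8, 16};
  // with the original rank (3) known, only the padding beyond it is removed.
  assert((TrimTrailingOnes({8, 16, 1, 1}, 3) == std::vector<int>{8, 16, 1}));
  // When the binding already matches the original rank, nothing is trimmed.
  assert((TrimTrailingOnes({8, 16, 1}, 3) == std::vector<int>{8, 16, 1}));
  return 0;
}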