From 021085e30b9bb5e7d49cbe3cc505e5877d53727b Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Tue, 3 Jan 2023 13:12:02 +0800 Subject: [PATCH] forbid ops who have 1D intermediate tensor entering Paddle-TRT (#49378) --- paddle/fluid/inference/tensorrt/op_teller.cc | 29 +++++++++----------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 61c0e0d23b..66bfe56f35 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -119,24 +119,21 @@ struct SimpleOpTypeSetTeller : public Teller { #endif } - // In static shape mode in TRT, we can't allow that op's input is a - // 1D-tensor So we filter it here. Some op like elementwise having "Y" too, - // but that is dealt with in the specified op, here just the common case + // In static shape in Paddle-TRT, we can't allow that one op has a + // 1D intermediate tensor as input. if (!with_dynamic_shape) { - std::string X_name; auto inputs = desc.Inputs(); - if (inputs.count("X") && !desc.Input("X").empty()) { - X_name = desc.Input("X")[0]; - } else if (inputs.count("Input") && !desc.Input("Input").empty()) { - X_name = desc.Input("Input")[0]; - } - auto* block = desc.Block(); - if (block) { - auto* x_var_desc = block->FindVar(X_name); - // Can't get feed op's TensorDesc - if (op_type != "feed" && x_var_desc && !x_var_desc->Persistable()) { - const auto x_shape = x_var_desc->GetShape(); - if (x_shape.size() == 1) return false; + for (auto iter : inputs) { + for (auto var_name : iter.second) { + auto* block = desc.Block(); + if (block) { + auto* var_desc = block->FindVar(var_name); + // Can't get feed op's TensorDesc + if (op_type != "feed" && var_desc && !var_desc->Persistable()) { + const auto shape = var_desc->GetShape(); + if (shape.size() == 1) return false; + } + } } } } -- GitLab