From 15d5f6b9efa864f1c5c21afd88193654a4576f9f Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Fri, 1 Apr 2022 19:52:28 +0800 Subject: [PATCH] reshape_opteller (#41090) fix_reshape: for paddle-trt --- paddle/fluid/inference/tensorrt/op_teller.cc | 21 +++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index fe0332025ed..13c16ab6897 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1479,8 +1479,27 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, std::vector shape = BOOST_GET_CONST(std::vector, desc.GetAttr("shape")); if (shape.size() >= nvinfer1::Dims::MAX_DIMS) return false; - if (!with_dynamic_shape && (shape[0] == -1 || shape.size() == 1)) + if (!with_dynamic_shape) { + if (shape.size() == 1) { + return false; + } + if (shape[0] == 0) { + return true; + } else { + auto* block = desc.Block(); + auto x_var_name = desc.Input("X")[0]; + auto* x_var_desc = block->FindVar(x_var_name); + const auto x_shape = x_var_desc->GetShape(); + int input_num = std::accumulate(x_shape.begin() + 1, x_shape.end(), 1, + std::multiplies()); + int shape_num = std::accumulate(shape.begin() + 1, shape.end(), 1, + std::multiplies()); + if (input_num == shape_num) { + return true; + } + } return false; + } } if (op_type == "clip") { -- GitLab