diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index fe0332025ed4aaa538f2541496dba3069b7e1b32..13c16ab6897e378eca113e3e408c4a4455f049e5 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1479,8 +1479,27 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, std::vector shape = BOOST_GET_CONST(std::vector, desc.GetAttr("shape")); if (shape.size() >= nvinfer1::Dims::MAX_DIMS) return false; - if (!with_dynamic_shape && (shape[0] == -1 || shape.size() == 1)) + if (!with_dynamic_shape) { + if (shape.size() == 1) { + return false; + } + if (shape[0] == 0) { + return true; + } else { + auto* block = desc.Block(); + auto x_var_name = desc.Input("X")[0]; + auto* x_var_desc = block->FindVar(x_var_name); + const auto x_shape = x_var_desc->GetShape(); + int input_num = std::accumulate(x_shape.begin() + 1, x_shape.end(), 1, + std::multiplies()); + int shape_num = std::accumulate(shape.begin() + 1, shape.end(), 1, + std::multiplies()); + if (input_num == shape_num) { + return true; + } + } return false; + } } if (op_type == "clip") {