diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6663103d4ca37445b96ce53fa39ddc3474988999..2d4b2ef659b9a862fc09fade6ab3f05285bed513 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -487,7 +487,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, BOOST_GET_CONST(std::vector, desc.GetAttr("axis")); if (!with_dynamic_shape && axis[0] != 0) return false; if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false; - if (axis[0] == 0 && axis.size() == 2) return false; auto* block = desc.Block(); if (block == nullptr) { @@ -499,7 +498,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); + if (axis.size() != x_shape.size()) return false; int dims = x_shape.size(); + std::vector perm(nvinfer1::Dims::MAX_DIMS); for (int i = 0; i < dims; i++) { perm[i] = axis[i]; @@ -518,6 +519,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (!is_valid_permutation(dims, perm)) { VLOG(3) << "Invalid permutation dimensions for trt transpose op " "converter: duplicate or out of bound."; + return false; } } if (op_type == "flatten2" || op_type == "flatten") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py index 31b4d027f1780b604ed39cdb0a1ae56e0daee74c..87e81396ab4112c17b90080a9dd586bac920dbc9 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py @@ -146,19 +146,7 @@ class TrtConvertTransposeTest(TrtLayerAutoScanTest): yield self.create_inference_config(), generate_trt_nodes_num(attrs, True), 1e-5 - def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if program_config.ops[0].attrs['axis'] == [0, 1]: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "INPUT AXIS [0, 1] NOT SUPPORT: we need to add support in the future" - ) - def test(self): - self.add_skip_trt_case() self.run_test()