From 27f8460a11cfc31301481f580ef66e0fa43fc6a1 Mon Sep 17 00:00:00 2001 From: wenbin Date: Tue, 18 Jan 2022 13:59:21 +0800 Subject: [PATCH] modify transpose params check (#39006) * modify params check * correct compile --- paddle/fluid/inference/tensorrt/op_teller.cc | 4 +++- .../ir/inference/test_trt_convert_transpose.py | 12 ------------ 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6663103d4ca..2d4b2ef659b 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -487,7 +487,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, BOOST_GET_CONST(std::vector, desc.GetAttr("axis")); if (!with_dynamic_shape && axis[0] != 0) return false; if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false; - if (axis[0] == 0 && axis.size() == 2) return false; auto* block = desc.Block(); if (block == nullptr) { @@ -499,7 +498,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); + if (axis.size() != x_shape.size()) return false; int dims = x_shape.size(); + std::vector perm(nvinfer1::Dims::MAX_DIMS); for (int i = 0; i < dims; i++) { perm[i] = axis[i]; @@ -518,6 +519,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (!is_valid_permutation(dims, perm)) { VLOG(3) << "Invalid permutation dimensions for trt transpose op " "converter: duplicate or out of bound."; + return false; } } if (op_type == "flatten2" || op_type == "flatten") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py index 31b4d027f17..87e81396ab4 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py @@ -146,19 +146,7 @@ class TrtConvertTransposeTest(TrtLayerAutoScanTest): yield self.create_inference_config(), generate_trt_nodes_num(attrs, True), 1e-5 - def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if program_config.ops[0].attrs['axis'] == [0, 1]: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "INPUT AXIS [0, 1] NOT SUPPORT: we need to add support in the future" - ) - def test(self): - self.add_skip_trt_case() self.run_test() -- GitLab