未验证 提交 27f8460a 编写于 作者: W wenbin 提交者: GitHub

modify transpose params check (#39006)

* modify params check

* correct compile
上级 a17e51dd
...@@ -487,7 +487,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -487,7 +487,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axis")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axis"));
if (!with_dynamic_shape && axis[0] != 0) return false; if (!with_dynamic_shape && axis[0] != 0) return false;
if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false; if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false;
if (axis[0] == 0 && axis.size() == 2) return false;
auto* block = desc.Block(); auto* block = desc.Block();
if (block == nullptr) { if (block == nullptr) {
...@@ -499,7 +498,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -499,7 +498,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto x_var_name = desc.Input("X")[0]; auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name); auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
if (axis.size() != x_shape.size()) return false;
int dims = x_shape.size(); int dims = x_shape.size();
std::vector<int> perm(nvinfer1::Dims::MAX_DIMS); std::vector<int> perm(nvinfer1::Dims::MAX_DIMS);
for (int i = 0; i < dims; i++) { for (int i = 0; i < dims; i++) {
perm[i] = axis[i]; perm[i] = axis[i];
...@@ -518,6 +519,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -518,6 +519,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (!is_valid_permutation(dims, perm)) { if (!is_valid_permutation(dims, perm)) {
VLOG(3) << "Invalid permutation dimensions for trt transpose op " VLOG(3) << "Invalid permutation dimensions for trt transpose op "
"converter: duplicate or out of bound."; "converter: duplicate or out of bound.";
return false;
} }
} }
if (op_type == "flatten2" || op_type == "flatten") { if (op_type == "flatten2" || op_type == "flatten") {
......
...@@ -146,19 +146,7 @@ class TrtConvertTransposeTest(TrtLayerAutoScanTest): ...@@ -146,19 +146,7 @@ class TrtConvertTransposeTest(TrtLayerAutoScanTest):
yield self.create_inference_config(), generate_trt_nodes_num(attrs, yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5 True), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['axis'] == [0, 1]:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"INPUT AXIS [0, 1] NOT SUPPORT: we need to add support in the future"
)
def test(self): def test(self):
self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册