未验证 提交 27f8460a 编写于 作者: W wenbin 提交者: GitHub

modify transpose params check (#39006)

* modify params check

* correct compile
上级 a17e51dd
......@@ -487,7 +487,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axis"));
if (!with_dynamic_shape && axis[0] != 0) return false;
if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false;
if (axis[0] == 0 && axis.size() == 2) return false;
auto* block = desc.Block();
if (block == nullptr) {
......@@ -499,7 +498,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (axis.size() != x_shape.size()) return false;
int dims = x_shape.size();
std::vector<int> perm(nvinfer1::Dims::MAX_DIMS);
for (int i = 0; i < dims; i++) {
perm[i] = axis[i];
......@@ -518,6 +519,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (!is_valid_permutation(dims, perm)) {
VLOG(3) << "Invalid permutation dimensions for trt transpose op "
"converter: duplicate or out of bound.";
return false;
}
}
if (op_type == "flatten2" || op_type == "flatten") {
......
......@@ -146,19 +146,7 @@ class TrtConvertTransposeTest(TrtLayerAutoScanTest):
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['axis'] == [0, 1]:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"INPUT AXIS [0, 1] NOT SUPPORT: we need to add support in the future"
)
def test(self):
self.add_skip_trt_case()
self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册