未验证 提交 da43e065 编写于 作者: F feng_shuai 提交者: GitHub

delete gather_ut skip_case (#39657)

* delete gather_ut skip_case

* add trt version limit
上级 f33ae206
...@@ -560,12 +560,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -560,12 +560,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
"the pass."; "the pass.";
return false; return false;
} }
#if !IS_TRT_VERSION_GE(7000)
auto* x_var_desc = block->FindVar(desc.Input("X")[0]); auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) { if (x_shape.size() == 1) {
VLOG(3) << "Gather does not support 1-dimensional input in tensorrt"; VLOG(3) << "Gather does not support 1-dimensional input in tensorrt";
return false; return false;
} }
#endif
} }
} }
......
...@@ -155,7 +155,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -155,7 +155,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
if self.input_num == 3: if self.input_num == 3:
return 0, 5 return 0, 5
else: else:
if dynamic_shape and self.axis == 0: if dynamic_shape:
return 1, 3 return 1, 3
else: else:
return 0, 4 return 0, 4
...@@ -179,9 +179,12 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -179,9 +179,12 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5 yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5 yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-3
def add_skip_trt_case(self): def add_skip_trt_case(self):
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
def teller1(program_config, predictor_config): def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0: if len(self.dynamic_shape.min_input_shape) != 0:
inputs = program_config.inputs inputs = program_config.inputs
...@@ -192,19 +195,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest): ...@@ -192,19 +195,9 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
self.add_skip_case( self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT, teller1, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: trt reshape out failed for dynamic shape mode when inputs' dims==1." "Need to repair the case: trt reshape out failed for dynamic shape mode when inputs' dims==1. under trt7.0 "
) )
def teller2(program_config, predictor_config):
inputs = program_config.inputs
if "axis_data" in inputs.keys():
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: trt do not support axis tensor input.")
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册