未验证 提交 da43e065 编写于 作者: F feng_shuai 提交者: GitHub

delete gather_ut skip_case (#39657)

* delete gather_ut skip_case

* add trt version limit
上级 f33ae206
......@@ -560,12 +560,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
"the pass.";
return false;
}
#if !IS_TRT_VERSION_GE(7000)
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "Gather does not support 1-dimensional input in tensorrt";
return false;
}
#endif
}
}
......
......@@ -155,7 +155,7 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
if self.input_num == 3:
return 0, 5
else:
if dynamic_shape and self.axis == 0:
if dynamic_shape:
return 1, 3
else:
return 0, 4
......@@ -179,31 +179,24 @@ class TrtConvertGatherTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(True), 1e-3
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0:
inputs = program_config.inputs
if len(inputs['input_data'].shape) == 1 or len(inputs[
'index_data'].shape) == 1:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: trt reshape out failed for dynamic shape mode when inputs' dims==1."
)
def teller2(program_config, predictor_config):
inputs = program_config.inputs
if "axis_data" in inputs.keys():
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: trt do not support axis tensor input.")
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
def teller1(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0:
inputs = program_config.inputs
if len(inputs['input_data'].shape) == 1 or len(inputs[
'index_data'].shape) == 1:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: trt reshape out failed for dynamic shape mode when inputs' dims==1. under trt7.0 "
)
def test(self):
self.add_skip_trt_case()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册