未验证 提交 4130b640 编写于 作者: B baoachun 提交者: GitHub

update gather_nd trt converter ut (#39584)

* update gather_nd trt converter ut

* update ut
上级 da492a13
...@@ -346,7 +346,7 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): ...@@ -346,7 +346,7 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
return np.random.random([2, 32]).astype(np.float32) return np.random.random([2, 32]).astype(np.float32)
def generate_input2(): def generate_input2():
return np.ones([2, 2]).astype(np.int32) return np.array([[0, 3], [1, 9]]).astype(np.int32)
ops_config = [{ ops_config = [{
"op_type": "gather_nd", "op_type": "gather_nd",
...@@ -408,23 +408,11 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): ...@@ -408,23 +408,11 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
# for dynamic_shape # for dynamic_shape
generate_dynamic_shape(attrs) generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 3), 1e-5 yield self.create_inference_config(), (0, 4), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), 1e-5 yield self.create_inference_config(), (0, 4), 1e-5
def add_skip_trt_case(self):
def teller(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0:
return True
return False
self.add_skip_case(
teller, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: the output of trt and GPU has diff when inputs' dim is 1 and 2."
)
def test(self): def test(self):
self.add_skip_trt_case()
self.run_test() self.run_test()
...@@ -434,10 +422,11 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest): ...@@ -434,10 +422,11 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest):
def sample_program_configs(self): def sample_program_configs(self):
def generate_input1(): def generate_input1():
return np.random.random([2, 32, 256]).astype(np.float32) return np.random.random([16, 32, 256]).astype(np.float32)
def generate_input2(): def generate_input2():
return np.ones([2, 2, 2]).astype(np.int32) return np.array(
[[[2, 5], [3, 8]], [[0, 2], [0, 3]]]).astype(np.int32)
ops_config = [{ ops_config = [{
"op_type": "gather_nd", "op_type": "gather_nd",
...@@ -471,7 +460,7 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest): ...@@ -471,7 +460,7 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest):
"index_data": [1, 1, 1] "index_data": [1, 1, 1]
} }
self.dynamic_shape.max_input_shape = { self.dynamic_shape.max_input_shape = {
"input_data": [4, 64, 512], "input_data": [16, 64, 512],
"index_data": [4, 2, 4] "index_data": [4, 2, 4]
} }
self.dynamic_shape.opt_input_shape = { self.dynamic_shape.opt_input_shape = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册