未验证 提交 4130b640 编写于 作者: B baoachun 提交者: GitHub

update gather_nd trt converter ut (#39584)

* update gather_nd trt converter ut

* update ut
上级 da492a13
......@@ -346,7 +346,7 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
return np.random.random([2, 32]).astype(np.float32)
def generate_input2():
return np.ones([2, 2]).astype(np.int32)
return np.array([[0, 3], [1, 9]]).astype(np.int32)
ops_config = [{
"op_type": "gather_nd",
......@@ -408,23 +408,11 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 3), 1e-5
yield self.create_inference_config(), (0, 4), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0:
return True
return False
self.add_skip_case(
teller, SkipReasons.TRT_NOT_SUPPORT,
"Need to repair the case: the output of trt and GPU has diff when inputs' dim is 1 and 2."
)
yield self.create_inference_config(), (0, 4), 1e-5
def test(self):
self.add_skip_trt_case()
self.run_test()
......@@ -434,10 +422,11 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest):
def sample_program_configs(self):
def generate_input1():
return np.random.random([2, 32, 256]).astype(np.float32)
return np.random.random([16, 32, 256]).astype(np.float32)
def generate_input2():
return np.ones([2, 2, 2]).astype(np.int32)
return np.array(
[[[2, 5], [3, 8]], [[0, 2], [0, 3]]]).astype(np.int32)
ops_config = [{
"op_type": "gather_nd",
......@@ -471,7 +460,7 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest):
"index_data": [1, 1, 1]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 64, 512],
"input_data": [16, 64, 512],
"index_data": [4, 2, 4]
}
self.dynamic_shape.opt_input_shape = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册