diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py index 0c7eae5f85f9557f6db58af8c4e6a677894ede05..6b6a9536d81bef87a44e7996ea234f4833e058df 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_gather_nd.py @@ -346,7 +346,7 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): return np.random.random([2, 32]).astype(np.float32) def generate_input2(): - return np.ones([2, 2]).astype(np.int32) + return np.array([[0, 3], [1, 9]]).astype(np.int32) ops_config = [{ "op_type": "gather_nd", @@ -408,23 +408,11 @@ class TrtConvertGatherNdTest_dim_2_2(TrtLayerAutoScanTest): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (1, 3), 1e-5 + yield self.create_inference_config(), (0, 4), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 3), 1e-5 - - def add_skip_trt_case(self): - def teller(program_config, predictor_config): - if len(self.dynamic_shape.min_input_shape) != 0: - return True - return False - - self.add_skip_case( - teller, SkipReasons.TRT_NOT_SUPPORT, - "Need to repair the case: the output of trt and GPU has diff when inputs' dim is 1 and 2." - ) + yield self.create_inference_config(), (0, 4), 1e-5 def test(self): - self.add_skip_trt_case() self.run_test() @@ -434,10 +422,11 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest): def sample_program_configs(self): def generate_input1(): - return np.random.random([2, 32, 256]).astype(np.float32) + return np.random.random([16, 32, 256]).astype(np.float32) def generate_input2(): - return np.ones([2, 2, 2]).astype(np.int32) + return np.array( + [[[2, 5], [3, 8]], [[0, 2], [0, 3]]]).astype(np.int32) ops_config = [{ "op_type": "gather_nd", @@ -471,7 +460,7 @@ class TrtConvertGatherNdTest_dim_3_3(TrtLayerAutoScanTest): "index_data": [1, 1, 1] } self.dynamic_shape.max_input_shape = { - "input_data": [4, 64, 512], + "input_data": [16, 64, 512], "index_data": [4, 2, 4] } self.dynamic_shape.opt_input_shape = {