From f07b8cbec5d2c1b0395c314238e39f3b9f5998bf Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Tue, 25 Jan 2022 16:52:00 +0800 Subject: [PATCH] fix:the axis must be 1(channel), when the dims of bias is 1 (#39052) --- .../inference/test_trt_convert_elementwise.py | 50 ++----------------- 1 file changed, 4 insertions(+), 46 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py index 1c5b640fe4b..505060e31a0 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py @@ -36,7 +36,7 @@ class TrtConvertElementwiseTest_one_input(TrtLayerAutoScanTest): for shape in [[32], [batch, 32], [batch, 32, 32], [batch, 32, 16, 32]]: for op_type in ["elementwise_add", "elementwise_mul"]: - for axis in [len(shape) - 1, -1]: + for axis in [-1 if len(shape) == 1 else 1]: self.dims = len(shape) dics = [{"axis": axis}] ops_config = [{ @@ -129,33 +129,7 @@ class TrtConvertElementwiseTest_one_input(TrtLayerAutoScanTest): True), 1e-5 def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if self.dims == 2 and len(self.dynamic_shape.max_input_shape) == 0: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "The output shape are not equal between gpu and tensorrt when input dim is 2." - ) - - def teller2(program_config, predictor_config): - if self.dims == 3: - return True - return False - - self.add_skip_case( - teller2, SkipReasons.TRT_NOT_IMPLEMENTED, - "The output has diff between gpu and tensorrt when input dim is 3.") - - def teller3(program_config, predictor_config): - if self.dims == 4: - return True - return False - - self.add_skip_case( - teller3, SkipReasons.TRT_NOT_IMPLEMENTED, - "The output has diff between gpu and tensorrt when input dim is 4.") + pass def test(self): self.add_skip_trt_case() @@ -287,15 +261,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast( yield self.create_inference_config(), (1, 3), 1e-5 def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if self.dims == 2: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "The output shape are not equal between gpu and tensorrt when input dim is 2." - ) + pass def test(self): self.add_skip_trt_case() @@ -418,15 +384,7 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest): yield self.create_inference_config(), (1, 3), 1e-5 def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if len(self.shape1) == 2: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "The output shape are not equal between gpu and tensorrt when input dim is 2." - ) + pass def test(self): self.add_skip_trt_case() -- GitLab