From 087c3abe21aaddce466b7e5fa6313774ba5a2316 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Sun, 24 Oct 2021 21:13:26 -0500 Subject: [PATCH] fix pool2d convert case (#36667) --- .../ir/inference/test_trt_convert_pool2d.py | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py index 9ec2f83fa5..05545f0b0e 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_pool2d.py @@ -18,6 +18,7 @@ import numpy as np import paddle.inference as paddle_infer from functools import partial from typing import Optional, List, Callable, Dict, Any, Set +import unittest class TrtConvertPool2dTest(TrtLayerAutoScanTest): @@ -46,16 +47,16 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): def generate_weight1(attrs: List[Dict[str, Any]]): return np.random.random([24, 3, 3, 3]).astype(np.float32) - for strides in [[1, 1], [2, 2], [1, 2]]: + for strides in [[1, 1], [1, 2], [2, 2]]: for paddings in [[0, 2], [0, 3], [0, 1, 2, 3]]: for pooling_type in ['max', 'avg']: for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']: for ksize in [[2, 3], [3, 3]]: for data_format in ['NCHW']: for global_pooling in [True, False]: - for exclusive in [True, False]: + for exclusive in [False, True]: for adaptive in [True, False]: - for ceil_mode in [True, False]: + for ceil_mode in [False, True]: dics = [{ "pooling_type": @@ -157,6 +158,29 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): teller2, SkipReasons.TRT_NOT_IMPLEMENTED, "It is not support that global_pooling is true for trt now.") + def teller3(program_config, predictor_config): + if self.dynamic_shape.min_input_shape == {} and program_config.ops[ + 0].attrs['ceil_mode'] == True: + return True + return False + + self.add_skip_case( + teller3, SkipReasons.TRT_NOT_IMPLEMENTED, + "It is not support that ceil_mode is true in static mode for trt now." + ) + + def teller4(program_config, predictor_config): + if self.dynamic_shape.min_input_shape != {} and ( + program_config.ops[0].attrs['strides'] == [1, 2] or + program_config.ops[0].attrs['strides'] == [2, 2]): + return True + return False + + self.add_skip_case( + teller4, SkipReasons.TRT_NOT_IMPLEMENTED, + "It is not support that strides is not equal [1, 1] in dynamic mode for trt now." + ) + def test(self): self.add_skip_trt_case() self.run_test() -- GitLab