未验证 提交 087c3abe 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix pool2d convert case (#36667)

上级 e2173b68
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
import paddle.inference as paddle_infer import paddle.inference as paddle_infer
from functools import partial from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertPool2dTest(TrtLayerAutoScanTest): class TrtConvertPool2dTest(TrtLayerAutoScanTest):
...@@ -46,16 +47,16 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): ...@@ -46,16 +47,16 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
def generate_weight1(attrs: List[Dict[str, Any]]): def generate_weight1(attrs: List[Dict[str, Any]]):
return np.random.random([24, 3, 3, 3]).astype(np.float32) return np.random.random([24, 3, 3, 3]).astype(np.float32)
for strides in [[1, 1], [2, 2], [1, 2]]: for strides in [[1, 1], [1, 2], [2, 2]]:
for paddings in [[0, 2], [0, 3], [0, 1, 2, 3]]: for paddings in [[0, 2], [0, 3], [0, 1, 2, 3]]:
for pooling_type in ['max', 'avg']: for pooling_type in ['max', 'avg']:
for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']: for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']:
for ksize in [[2, 3], [3, 3]]: for ksize in [[2, 3], [3, 3]]:
for data_format in ['NCHW']: for data_format in ['NCHW']:
for global_pooling in [True, False]: for global_pooling in [True, False]:
for exclusive in [True, False]: for exclusive in [False, True]:
for adaptive in [True, False]: for adaptive in [True, False]:
for ceil_mode in [True, False]: for ceil_mode in [False, True]:
dics = [{ dics = [{
"pooling_type": "pooling_type":
...@@ -157,6 +158,29 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): ...@@ -157,6 +158,29 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
teller2, SkipReasons.TRT_NOT_IMPLEMENTED, teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that global_pooling is true for trt now.") "It is not support that global_pooling is true for trt now.")
def teller3(program_config, predictor_config):
if self.dynamic_shape.min_input_shape == {} and program_config.ops[
0].attrs['ceil_mode'] == True:
return True
return False
self.add_skip_case(
teller3, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that ceil_mode is true in static mode for trt now."
)
def teller4(program_config, predictor_config):
if self.dynamic_shape.min_input_shape != {} and (
program_config.ops[0].attrs['strides'] == [1, 2] or
program_config.ops[0].attrs['strides'] == [2, 2]):
return True
return False
self.add_skip_case(
teller4, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that strides is not equal [1, 1] in dynamic mode for trt now."
)
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册