未验证 提交 087c3abe 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix pool2d convert case (#36667)

上级 e2173b68
......@@ -18,6 +18,7 @@ import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertPool2dTest(TrtLayerAutoScanTest):
......@@ -46,16 +47,16 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.random.random([24, 3, 3, 3]).astype(np.float32)
for strides in [[1, 1], [2, 2], [1, 2]]:
for strides in [[1, 1], [1, 2], [2, 2]]:
for paddings in [[0, 2], [0, 3], [0, 1, 2, 3]]:
for pooling_type in ['max', 'avg']:
for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']:
for ksize in [[2, 3], [3, 3]]:
for data_format in ['NCHW']:
for global_pooling in [True, False]:
for exclusive in [True, False]:
for exclusive in [False, True]:
for adaptive in [True, False]:
for ceil_mode in [True, False]:
for ceil_mode in [False, True]:
dics = [{
"pooling_type":
......@@ -157,6 +158,29 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that global_pooling is true for trt now.")
def teller3(program_config, predictor_config):
if self.dynamic_shape.min_input_shape == {} and program_config.ops[
0].attrs['ceil_mode'] == True:
return True
return False
self.add_skip_case(
teller3, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that ceil_mode is true in static mode for trt now."
)
def teller4(program_config, predictor_config):
if self.dynamic_shape.min_input_shape != {} and (
program_config.ops[0].attrs['strides'] == [1, 2] or
program_config.ops[0].attrs['strides'] == [2, 2]):
return True
return False
self.add_skip_case(
teller4, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that strides is not equal [1, 1] in dynamic mode for trt now."
)
def test(self):
self.add_skip_trt_case()
self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册