未验证 提交 7612bf1c 编写于 作者: J JingZhuangzhuang 提交者: GitHub

Add pool2d test convert (#36338) (#36663)

上级 a9b7d1d2
...@@ -107,6 +107,9 @@ class Pool2dOpConverter : public OpConverter { ...@@ -107,6 +107,9 @@ class Pool2dOpConverter : public OpConverter {
plugin_pool_type = plugin::PoolPlugin::PoolType::avg; plugin_pool_type = plugin::PoolPlugin::PoolType::avg;
} }
if (padding_algorithm == "VALID") {
std::fill(paddings.begin(), paddings.end(), 0);
}
nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]); nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
...@@ -149,6 +152,30 @@ class Pool2dOpConverter : public OpConverter { ...@@ -149,6 +152,30 @@ class Pool2dOpConverter : public OpConverter {
input1 = pad_layer->getOutput(0); input1 = pad_layer->getOutput(0);
} }
auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1,
nv_pool_type, nv_ksize);
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
pool_layer->setAverageCountExcludesPadding(exclusive);
if (padding_algorithm == "SAME") {
pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
layer = pool_layer;
} else if (!adaptive && !global_pooling && ceil_mode) {
nvinfer1::DimsHW pre_pad(0, 0);
nvinfer1::DimsHW post_pad(0, 0);
// If ceil mode is true, we will pad the appropriate size to the input.
DealCeilMode(input_shape, ksize, strides, paddings, &pre_pad, &post_pad,
input_dims);
auto *pad_layer = TRT_ENGINE_ADD_LAYER(
engine_, Padding, *const_cast<nvinfer1::ITensor *>(input1), pre_pad,
post_pad);
PADDLE_ENFORCE_NOT_NULL(
pad_layer, platform::errors::Fatal(
"Pad layer in poolOp converter could not be "
"created. The pointer to pad layer is `NULL`."));
input1 = pad_layer->getOutput(0);
auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1, auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1,
nv_pool_type, nv_ksize); nv_pool_type, nv_ksize);
pool_layer->setStride(nv_strides); pool_layer->setStride(nv_strides);
......
...@@ -178,22 +178,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -178,22 +178,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (op_type == "pool2d") { if (op_type == "pool2d") {
std::vector<int> paddings = std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
if (paddings.size() > 2) return false; if (paddings.size() > 2) {
if (desc.HasAttr("exclusive")) { return false;
if (BOOST_GET_CONST(bool, desc.GetAttr("exclusive"))) {
std::vector<int> ksize =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("ksize"));
for (size_t i = 0; i < ksize.size(); i++) {
if (ksize[i] <= paddings[i]) {
VLOG(3) << "the padding size should be less than the filter size "
"for exclusive-counting pooling.";
return false;
}
}
}
}
if (desc.HasAttr("ceil_mode")) {
if (BOOST_GET_CONST(bool, desc.GetAttr("ceil_mode"))) return false;
} }
if (desc.Input("X").size() != 1) { if (desc.Input("X").size() != 1) {
VLOG(3) << "TRT Pool2d expect 1 input, but got " VLOG(3) << "TRT Pool2d expect 1 input, but got "
...@@ -215,18 +201,32 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -215,18 +201,32 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<< pool_type << " pool type."; << pool_type << " pool type.";
return false; return false;
} }
if (pool_type == "avg") {
if (desc.HasAttr("global_pooling")) {
if (!BOOST_GET_CONST(bool, desc.GetAttr("global_pooling"))) {
if (desc.HasAttr("exclusive")) {
if (BOOST_GET_CONST(bool, desc.GetAttr("exclusive"))) {
std::vector<int> ksize =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("ksize"));
for (size_t i = 0; i < ksize.size(); i++) {
if (ksize[i] <= paddings[i]) {
VLOG(3) << "the padding size should be less than the "
"filter size "
"for exclusive-counting pooling.";
return false;
}
}
}
}
}
}
}
} }
} }
if (op_type == "conv2d" || op_type == "conv2d_transpose" || if (op_type == "conv2d" || op_type == "conv2d_transpose" ||
op_type == "conv2d_fusion" || op_type == "depthwise_conv2d" || op_type == "conv2d_fusion" || op_type == "depthwise_conv2d" ||
op_type == "depthwise_conv2d_transpose") { op_type == "depthwise_conv2d_transpose") {
std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
// conv2d and conv2d_transpose need padding check
if (paddings.size() > 2 && op_type != "conv2d_fusion") return false;
if (desc.Input("Input").size() != 1) { if (desc.Input("Input").size() != 1) {
VLOG(3) << "TRT Conv2d expect 1 input, but got " VLOG(3) << "TRT Conv2d expect 1 input, but got "
<< desc.Input("Input").size() << " input."; << desc.Input("Input").size() << " input.";
......
...@@ -21,9 +21,22 @@ from typing import Optional, List, Callable, Dict, Any, Set ...@@ -21,9 +21,22 @@ from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertPool2dTest(TrtLayerAutoScanTest): class TrtConvertPool2dTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_paddings_valid(self, program_config: ProgramConfig) -> bool:
exclusive = program_config.ops[0].attrs['exclusive']
paddings = program_config.ops[0].attrs['paddings']
ksize = program_config.ops[0].attrs['ksize']
pooling_type = program_config.ops[0].attrs['pooling_type']
global_pooling = program_config.ops[0].attrs['global_pooling']
if global_pooling == False:
if pooling_type == 'avg':
for index in range(len(ksize)):
if ksize[index] <= paddings[index]:
return False
return True return True
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return self.is_paddings_valid(program_config)
def sample_program_configs(self): def sample_program_configs(self):
self.trt_param.workspace_size = 1073741824 self.trt_param.workspace_size = 1073741824
...@@ -34,7 +47,7 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): ...@@ -34,7 +47,7 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
return np.random.random([24, 3, 3, 3]).astype(np.float32) return np.random.random([24, 3, 3, 3]).astype(np.float32)
for strides in [[1, 1], [2, 2], [1, 2]]: for strides in [[1, 1], [2, 2], [1, 2]]:
for paddings in [[0, 2], [0, 3], [1, 2, 3, 4]]: for paddings in [[0, 2], [0, 3], [0, 1, 2, 3]]:
for pooling_type in ['max', 'avg']: for pooling_type in ['max', 'avg']:
for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']: for padding_algotithm in ['EXPLICIT', 'SAME', 'VAILD']:
for ksize in [[2, 3], [3, 3]]: for ksize in [[2, 3], [3, 3]]:
...@@ -43,7 +56,6 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): ...@@ -43,7 +56,6 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
for exclusive in [True, False]: for exclusive in [True, False]:
for adaptive in [True, False]: for adaptive in [True, False]:
for ceil_mode in [True, False]: for ceil_mode in [True, False]:
self.paddings = paddings
dics = [{ dics = [{
"pooling_type": "pooling_type":
...@@ -102,9 +114,6 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): ...@@ -102,9 +114,6 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
self.dynamic_shape.opt_input_shape = {} self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape): def generate_trt_nodes_num(attrs, dynamic_shape):
if self.paddings == [0, 3] or attrs[0][
'global_pooling'] == True or attrs[0]['ceil_mode'] == True:
return 0, 3
return 1, 2 return 1, 2
attrs = [ attrs = [
...@@ -139,6 +148,15 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest): ...@@ -139,6 +148,15 @@ class TrtConvertPool2dTest(TrtLayerAutoScanTest):
self.add_skip_case(teller1, SkipReasons.TRT_NOT_IMPLEMENTED, self.add_skip_case(teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"4-dims paddings are not support for trt now.") "4-dims paddings are not support for trt now.")
def teller2(program_config, predictor_config):
if program_config.ops[0].attrs['global_pooling'] == True:
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
"It is not support that global_pooling is true for trt now.")
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册