未验证 提交 e568268b 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix_poo2d_trt_convert (#41860) (#41915)

上级 aa6eb0e8
......@@ -256,6 +256,37 @@ class Pool2dOpConverter : public OpConverter {
if (!adaptive) {
if (ceil_mode) {
if (nv_ksize.d[0] % nv_strides.d[0] == 0 &&
nv_ksize.d[1] % nv_strides.d[1] == 0) {
nvinfer1::DimsHW pre_pad(0, 0);
nvinfer1::DimsHW post_pad(0, 0);
// If ceil mode is true, we will pad the appropriate size to the
// input.
DealCeilMode(input_shape, ksize, strides, paddings, &pre_pad,
&post_pad, input_dims);
auto *pad_layer = TRT_ENGINE_ADD_LAYER(engine_, Padding, *input1,
pre_pad, post_pad);
PADDLE_ENFORCE_NOT_NULL(
pad_layer, platform::errors::Fatal(
"Pad layer in poolOp converter could not be "
"created. The pointer to pad layer is `NULL`."));
input1 = pad_layer->getOutput(0);
auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1,
nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL(
pool_layer,
platform::errors::Fatal(
"trt pool layer in converter could not be created."));
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
if (padding_algorithm == "SAME") {
pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer;
} else {
std::vector<int> input_shape_v;
for (int i = 0; i < input_dims; i++) {
input_shape_v.push_back(input_shape.d[i]);
......@@ -269,6 +300,7 @@ class Pool2dOpConverter : public OpConverter {
platform::errors::Fatal(
"trt pool plugin layer in converter could not be created."));
layer = pool_layer;
}
} else {
#if IS_TRT_VERSION_GE(8000)
// Exclude padding pixels from the average mean is not supported well by
......@@ -299,7 +331,6 @@ class Pool2dOpConverter : public OpConverter {
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer;
}
} else {
// Average pooling needs to exclude the padding pixels from the average
// mean.
......@@ -327,5 +358,4 @@ class Pool2dOpConverter : public OpConverter {
} // namespace inference
} // namespace paddle
USE_OP_ITSELF(pool2d);
REGISTER_TRT_OP_CONVERTER(pool2d, Pool2dOpConverter);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册