From 932bbe955b51cbf3dba0849a187e1aa42e8f2ed2 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Wed, 2 Sep 2020 11:05:11 +0800 Subject: [PATCH] fix pool trt plugin bug (#26463) test=develop --- .../tensorrt/plugin/pool_op_plugin.cu | 61 ++++++++++++------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu index 48afcfce347..1fa5b3228e1 100644 --- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu @@ -104,32 +104,51 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions( auto stri_0 = expr_builder.constant(strides_[0]); auto stri_1 = expr_builder.constant(strides_[1]); + auto one_value = expr_builder.constant(1); - auto tmp1_0 = - expr_builder.constant((-ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1); - auto tmp1_1 = - expr_builder.constant((-ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1); + auto v0_tmp = expr_builder.constant(-ksize_[0] + 2 * paddings_[0]); + auto v1_tmp = expr_builder.constant(-ksize_[1] + 2 * paddings_[1]); - auto tmp2_0 = expr_builder.constant( - (-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / strides_[0] + 1); - auto tmp2_1 = expr_builder.constant( - (-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / strides_[1] + 1); - - auto *a_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV, - *inputs[0].d[2], *stri_0); - auto *b_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV, - *inputs[0].d[3], *stri_1); + auto ceil_tmp = + expr_builder.constant(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1); + auto ceil1_tmp = + expr_builder.constant(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1); if (!ceil_mode_) { - output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *a_d, *tmp1_0); - output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *b_d, *tmp1_1); + output.d[2] = expr_builder.operation( + nvinfer1::DimensionOperation::kSUM, + *expr_builder.operation( + nvinfer1::DimensionOperation::kFLOOR_DIV, + *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *inputs[0].d[2], *v0_tmp), + *stri_0), + *one_value); + output.d[3] = expr_builder.operation( + nvinfer1::DimensionOperation::kSUM, + *expr_builder.operation( + nvinfer1::DimensionOperation::kFLOOR_DIV, + *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *inputs[0].d[3], *v1_tmp), + *stri_1), + *one_value); + } else { - output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *a_d, *tmp2_0); - output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, - *b_d, *tmp2_1); + output.d[2] = expr_builder.operation( + nvinfer1::DimensionOperation::kSUM, + *expr_builder.operation( + nvinfer1::DimensionOperation::kFLOOR_DIV, + *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *inputs[0].d[2], *ceil_tmp), + *stri_0), + *one_value); + output.d[3] = expr_builder.operation( + nvinfer1::DimensionOperation::kSUM, + *expr_builder.operation( + nvinfer1::DimensionOperation::kFLOOR_DIV, + *expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *inputs[0].d[3], *ceil1_tmp), + *stri_1), + *one_value); } return output; -- GitLab