From 932bbe955b51cbf3dba0849a187e1aa42e8f2ed2 Mon Sep 17 00:00:00 2001
From: Zhaolong Xing <nhzlx.dragon@gmail.com>
Date: Wed, 2 Sep 2020 11:05:11 +0800
Subject: [PATCH] fix pool trt plugin bug (#26463)

test=develop
---
 .../tensorrt/plugin/pool_op_plugin.cu         | 61 ++++++++++++-------
 1 file changed, 40 insertions(+), 21 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu
index 48afcfce347..1fa5b3228e1 100644
--- a/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/pool_op_plugin.cu
@@ -104,32 +104,51 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions(
 
   auto stri_0 = expr_builder.constant(strides_[0]);
   auto stri_1 = expr_builder.constant(strides_[1]);
+  auto one_value = expr_builder.constant(1);
 
-  auto tmp1_0 =
-      expr_builder.constant((-ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1);
-  auto tmp1_1 =
-      expr_builder.constant((-ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1);
+  auto v0_tmp = expr_builder.constant(-ksize_[0] + 2 * paddings_[0]);
+  auto v1_tmp = expr_builder.constant(-ksize_[1] + 2 * paddings_[1]);
 
-  auto tmp2_0 = expr_builder.constant(
-      (-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / strides_[0] + 1);
-  auto tmp2_1 = expr_builder.constant(
-      (-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / strides_[1] + 1);
-
-  auto *a_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV,
-                                     *inputs[0].d[2], *stri_0);
-  auto *b_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV,
-                                     *inputs[0].d[3], *stri_1);
+  auto ceil_tmp =
+      expr_builder.constant(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1);
+  auto ceil1_tmp =
+      expr_builder.constant(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1);
 
   if (!ceil_mode_) {
-    output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
-                                         *a_d, *tmp1_0);
-    output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
-                                         *b_d, *tmp1_1);
+    output.d[2] = expr_builder.operation(
+        nvinfer1::DimensionOperation::kSUM,
+        *expr_builder.operation(
+            nvinfer1::DimensionOperation::kFLOOR_DIV,
+            *expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
+                                    *inputs[0].d[2], *v0_tmp),
+            *stri_0),
+        *one_value);
+    output.d[3] = expr_builder.operation(
+        nvinfer1::DimensionOperation::kSUM,
+        *expr_builder.operation(
+            nvinfer1::DimensionOperation::kFLOOR_DIV,
+            *expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
+                                    *inputs[0].d[3], *v1_tmp),
+            *stri_1),
+        *one_value);
+
   } else {
-    output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
-                                         *a_d, *tmp2_0);
-    output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
-                                         *b_d, *tmp2_1);
+    output.d[2] = expr_builder.operation(
+        nvinfer1::DimensionOperation::kSUM,
+        *expr_builder.operation(
+            nvinfer1::DimensionOperation::kFLOOR_DIV,
+            *expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
+                                    *inputs[0].d[2], *ceil_tmp),
+            *stri_0),
+        *one_value);
+    output.d[3] = expr_builder.operation(
+        nvinfer1::DimensionOperation::kSUM,
+        *expr_builder.operation(
+            nvinfer1::DimensionOperation::kFLOOR_DIV,
+            *expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
+                                    *inputs[0].d[3], *ceil1_tmp),
+            *stri_1),
+        *one_value);
   }
 
   return output;
-- 
GitLab