diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 11cad95361867476c6f775af778015da37f1cfb1..73f1b28ddf73403862e55d102a259d7b6cf67b1f 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + bool global_pooling = boost::get(op_desc.GetAttr("global_pooling")); std::string pool_type = boost::get(op_desc.GetAttr("pooling_type")); std::vector ksize = @@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter { std::vector paddings = boost::get>(op_desc.GetAttr("paddings")); - const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + if (global_pooling == true) { + nvinfer1::Dims input_shape = input1->getDimensions(); + int nbDims = input_shape.nbDims; + nv_ksize.d[0] = input_shape.d[nbDims - 2]; + nv_ksize.d[1] = input_shape.d[nbDims - 1]; + } const nvinfer1::DimsHW nv_strides(strides[0], strides[1]); const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); diff --git a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc index c5dddbc8cd37b9fb1ba39382af2da5ad045f3af2..aedd6b62df040eeee4e48f628128511cd8bf4439 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc @@ -20,7 +20,7 @@ namespace paddle { namespace inference { namespace tensorrt { -TEST(Pool2dOpConverter, main) { +void test_pool2d(bool global_pooling) { framework::Scope scope; std::unordered_set parameters; TRTConvertValidation validator(5, parameters, scope, 1 << 15); @@ -28,7 +28,10 @@ TEST(Pool2dOpConverter, main) { // The ITensor's Dims should not contain the batch size. // So, the ITensor's Dims of input and output should be C * H * W. validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); - validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); + if (global_pooling) + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); + else + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); // Prepare Op description framework::OpDesc desc; @@ -45,6 +48,7 @@ TEST(Pool2dOpConverter, main) { desc.SetAttr("ksize", ksize); desc.SetAttr("strides", strides); desc.SetAttr("paddings", paddings); + desc.SetAttr("global_pooling", global_pooling); LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); @@ -53,6 +57,10 @@ TEST(Pool2dOpConverter, main) { validator.Execute(3); } +TEST(Pool2dOpConverter, normal) { test_pool2d(false); } + +TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); } + } // namespace tensorrt } // namespace inference } // namespace paddle