From de6da1029d1888329bd1e3354714dd6aa75a201c Mon Sep 17 00:00:00 2001 From: nhzlx Date: Fri, 17 Aug 2018 05:02:33 +0000 Subject: [PATCH] add support for global pooling for trt --- paddle/fluid/inference/tensorrt/convert/pool2d_op.cc | 9 ++++++++- .../fluid/inference/tensorrt/convert/test_pool2d_op.cc | 2 ++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 11cad9536..73f1b28dd 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + bool global_pooling = boost::get(op_desc.GetAttr("global_pooling")); std::string pool_type = boost::get(op_desc.GetAttr("pooling_type")); std::vector ksize = @@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter { std::vector paddings = boost::get>(op_desc.GetAttr("paddings")); - const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + if (global_pooling == true) { + nvinfer1::Dims input_shape = input1->getDimensions(); + int nbDims = input_shape.nbDims; + nv_ksize.d[0] = input_shape.d[nbDims - 2]; + nv_ksize.d[1] = input_shape.d[nbDims - 1]; + } const nvinfer1::DimsHW nv_strides(strides[0], strides[1]); const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); diff --git a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc index c5dddbc8c..dbdc0dcaf 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc @@ -40,11 +40,13 @@ TEST(Pool2dOpConverter, main) { std::vector strides({2, 2}); std::vector paddings({0, 0}); std::string pooling_t = "max"; + bool global_pooling = false; desc.SetAttr("pooling_type", pooling_t); desc.SetAttr("ksize", ksize); desc.SetAttr("strides", strides); desc.SetAttr("paddings", paddings); + desc.SetAttr("global_pooling", global_pooling); LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); -- GitLab