diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 11cad95361867476c6f775af778015da37f1cfb1..73f1b28ddf73403862e55d102a259d7b6cf67b1f 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + bool global_pooling = boost::get(op_desc.GetAttr("global_pooling")); std::string pool_type = boost::get(op_desc.GetAttr("pooling_type")); std::vector ksize = @@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter { std::vector paddings = boost::get>(op_desc.GetAttr("paddings")); - const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + if (global_pooling == true) { + nvinfer1::Dims input_shape = input1->getDimensions(); + int nbDims = input_shape.nbDims; + nv_ksize.d[0] = input_shape.d[nbDims - 2]; + nv_ksize.d[1] = input_shape.d[nbDims - 1]; + } const nvinfer1::DimsHW nv_strides(strides[0], strides[1]); const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); diff --git a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc index c5dddbc8cd37b9fb1ba39382af2da5ad045f3af2..dbdc0dcaf7649b838ea844659415f3cd269cd782 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc @@ -40,11 +40,13 @@ TEST(Pool2dOpConverter, main) { std::vector strides({2, 2}); std::vector paddings({0, 0}); std::string pooling_t = "max"; + bool global_pooling = false; desc.SetAttr("pooling_type", pooling_t); desc.SetAttr("ksize", ksize); desc.SetAttr("strides", strides); desc.SetAttr("paddings", paddings); + desc.SetAttr("global_pooling", global_pooling); LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto());