diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index e7605cceb7b45b95d6bd81f4bf69a9fdb0d7e276..7c887f506358ecc2816f1b1c690504a73adc011f 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -79,18 +79,25 @@ class Pool2dOpConverter : public OpConverter { std::vector paddings = boost::get>(op_desc.GetAttr("paddings")); bool ceil_mode = boost::get(op_desc.GetAttr("ceil_mode")); + bool exclusive = op_desc.HasAttr("exclusive") + ? boost::get(op_desc.GetAttr("exclusive")) + : true; bool adaptive = false; if (op_desc.HasAttr("adaptive")) adaptive = boost::get(op_desc.GetAttr("adaptive")); nvinfer1::PoolingType nv_pool_type = nvinfer1::PoolingType::kMAX; + nvinfer1::ReduceOperation reduce_operation = + nvinfer1::ReduceOperation::kMAX; plugin::PoolPlugin::PoolType plugin_pool_type = plugin::PoolPlugin::PoolType::max; if (pool_type == "max") { nv_pool_type = nvinfer1::PoolingType::kMAX; + reduce_operation = nvinfer1::ReduceOperation::kMAX; plugin_pool_type = plugin::PoolPlugin::PoolType::max; } else if (pool_type == "avg") { nv_pool_type = nvinfer1::PoolingType::kAVERAGE; + reduce_operation = nvinfer1::ReduceOperation::kAVG; plugin_pool_type = plugin::PoolPlugin::PoolType::avg; } else { PADDLE_THROW(platform::errors::Fatal( @@ -113,12 +120,17 @@ class Pool2dOpConverter : public OpConverter { } if (engine_->with_dynamic_shape()) { - if (!adaptive && pool_type == "max" && !global_pooling && !ceil_mode) { + if (!adaptive && !global_pooling && !ceil_mode) { auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1, nv_pool_type, nv_ksize); pool_layer->setStride(nv_strides); pool_layer->setPadding(nv_paddings); + pool_layer->setAverageCountExcludesPadding(exclusive); layer = pool_layer; + } else if (global_pooling) { + auto *reduce_layer = TRT_ENGINE_ADD_LAYER(engine_, Reduce, *input1, + reduce_operation, 12, true); + layer = reduce_layer; } else { #if IS_TRT_VERSION_GE(6000) plugin::PoolPluginDynamic *plugin = @@ -140,23 +152,27 @@ class Pool2dOpConverter : public OpConverter { if (global_pooling == true) { nv_ksize.d[0] = input_shape.d[input_dims - 2]; nv_ksize.d[1] = input_shape.d[input_dims - 1]; - auto *layer = TRT_ENGINE_ADD_LAYER( + auto *pool_layer = TRT_ENGINE_ADD_LAYER( engine_, Pooling, *const_cast(input1), nv_pool_type, nv_ksize); PADDLE_ENFORCE_NOT_NULL( - layer, platform::errors::Fatal( - "trt pool layer in converter could not be created.")); + pool_layer, platform::errors::Fatal( + "trt pool layer in converter could not be created.")); auto output_name = op_desc.Output("Out")[0]; - layer->setName(("pool2d (Output: " + output_name + ")").c_str()); - layer->getOutput(0)->setName(output_name.c_str()); - engine_->SetITensor(output_name, layer->getOutput(0)); + pool_layer->setStride(nv_strides); + pool_layer->setPadding(nv_paddings); + pool_layer->setAverageCountExcludesPadding(exclusive); + pool_layer->setName(("pool2d (Output: " + output_name + ")").c_str()); + pool_layer->getOutput(0)->setName(output_name.c_str()); + engine_->SetITensor(output_name, pool_layer->getOutput(0)); + layer = pool_layer; if (test_mode) { engine_->DeclareOutput(output_name); } return; } - if (!adaptive && pool_type == "max") { + if (!adaptive) { // Under ceil mode, the pre_pad and post_pad are used to // record the the padding size. In some ceil mode cases, // we do not need padding, so we initialize the two vars to 0. @@ -184,6 +200,7 @@ class Pool2dOpConverter : public OpConverter { "trt pool layer in converter could not be created.")); pool_layer->setStride(nv_strides); pool_layer->setPadding(nv_paddings); + pool_layer->setAverageCountExcludesPadding(exclusive); layer = pool_layer; } else { // Average pooling needs to exclude the padding pixels from the average @@ -203,7 +220,6 @@ class Pool2dOpConverter : public OpConverter { "trt pool plugin layer in converter could not be created.")); layer = pool_layer; } - auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode); }