From 08dcea18edaf19ef1eeea1a8905e28d6f318d211 Mon Sep 17 00:00:00 2001 From: wenbin Date: Thu, 13 Jan 2022 14:00:27 +0800 Subject: [PATCH] roi_align aligned supported (#38905) roi_align aligned supported --- .../tensorrt/convert/roi_align_op.cc | 4 +- paddle/fluid/inference/tensorrt/op_teller.cc | 30 --------- .../tensorrt/plugin/roi_align_op_plugin.cu | 64 +++++++++++-------- .../tensorrt/plugin/roi_align_op_plugin.h | 4 +- .../inference/test_trt_convert_roi_align.py | 10 --- 5 files changed, 45 insertions(+), 67 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/roi_align_op.cc b/paddle/fluid/inference/tensorrt/convert/roi_align_op.cc index 654fe7e0133..54f7937d837 100644 --- a/paddle/fluid/inference/tensorrt/convert/roi_align_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/roi_align_op.cc @@ -51,6 +51,7 @@ class RoiAlignOpConverter : public OpConverter { BOOST_GET_CONST(float, op_desc.GetAttr("spatial_scale")); const auto sampling_ratio = BOOST_GET_CONST(int, op_desc.GetAttr("sampling_ratio")); + const auto aligned = BOOST_GET_CONST(bool, op_desc.GetAttr("aligned")); const auto input_tensor = engine_->GetITensor(input_name); const auto rois_tensor = engine_->GetITensor(rois_name); @@ -63,7 +64,8 @@ class RoiAlignOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; auto* roi_align_plugin = new plugin::RoiAlignPluginDynamic( - data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio); + data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio, + aligned); auto roi_align_layer = engine_->network()->addPluginV2( inputs.data(), inputs.size(), *roi_align_plugin); layer = roi_align_layer; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 878eef016e7..ddee4e0d682 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -13,9 +13,7 @@ // limitations under the License. #include "paddle/fluid/inference/tensorrt/op_teller.h" - #include - #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/data_layout.h" @@ -737,28 +735,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } - if (op_type == "roi_align") { - if (!with_dynamic_shape) return false; - - std::vector attrs{"pooled_height", "pooled_width", - "spatial_scale", "sampling_ratio"}; - for (auto const attr : attrs) { - if (!desc.HasAttr(attr)) return false; - } - - const auto pooled_height = - BOOST_GET_CONST(int, desc.GetAttr("pooled_height")); - if (pooled_height <= 0) return false; - - const auto pooled_width = - BOOST_GET_CONST(int, desc.GetAttr("pooled_width")); - if (pooled_width <= 0) return false; - - const auto spatial_scale = - BOOST_GET_CONST(float, desc.GetAttr("spatial_scale")); - if (spatial_scale <= 0.f) return false; - } - if (op_type == "hard_swish") { if (desc.Input("X").size() != 1) { VLOG(3) << "HardSwish op has only 1 input, but got " @@ -1303,12 +1279,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, BOOST_GET_CONST(float, desc.GetAttr("spatial_scale")); if (spatial_scale <= 0.f) return false; - const auto sampling_ratio = - BOOST_GET_CONST(int, desc.GetAttr("sampling_ratio")); - const auto aligned = BOOST_GET_CONST(bool, desc.GetAttr("aligned")); - - if (sampling_ratio == -1 && aligned == true) return false; - auto roi_align_inputs = desc.Inputs(); if (roi_align_inputs.find("RoisNum") != roi_align_inputs.end()) { if (desc.Input("RoisNum").size() >= 1) { diff --git a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu index 06540b36260..7dc31fb4471 100644 --- a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.cu @@ -58,14 +58,12 @@ __inline__ __device__ T BilinearInterpolate(const T* input_data, } template -__global__ void GPUROIAlignOpt(const int nthreads, - const T* __restrict__ input_data, - const T* __restrict__ input_rois, - const float spatial_scale, const int channels, - const int height, const int width, - const int pooled_height, const int pooled_width, - const int sampling_ratio, const int num_rois, - OutT* __restrict__ output_data) { +__global__ void GPUROIAlignOpt( + const int nthreads, const T* __restrict__ input_data, + const T* __restrict__ input_rois, const float spatial_scale, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, const int sampling_ratio, + const int num_rois, const bool aligned, OutT* __restrict__ output_data) { const int batch = blockIdx.x; const int channel = blockIdx.y; const T* offset_input_data = @@ -84,21 +82,28 @@ __global__ void GPUROIAlignOpt(const int nthreads, const int roi_idx = (idx / pooled_width / pooled_height) % num_rois; const int n = batch * num_rois + roi_idx; const float4 rois_offset = reinterpret_cast(input_rois)[n]; - const T roi_xmin = rois_offset.x * spatial_scale; - const T roi_ymin = rois_offset.y * spatial_scale; - const T roi_xmax = rois_offset.z * spatial_scale; - const T roi_ymax = rois_offset.w * spatial_scale; - const T roi_width = max(roi_xmax - roi_xmin, static_cast(1.f)); - const T roi_height = max(roi_ymax - roi_ymin, static_cast(1.f)); - const T bin_size_h = roi_height / static_cast(pooled_height); - const T bin_size_w = roi_width / static_cast(pooled_width); + const T roi_offset = aligned ? static_cast(0.5) : 0; + const T roi_xmin = rois_offset.x * spatial_scale - roi_offset; + const T roi_ymin = rois_offset.y * spatial_scale - roi_offset; + const T roi_xmax = rois_offset.z * spatial_scale - roi_offset; + const T roi_ymax = rois_offset.w * spatial_scale - roi_offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + if (!aligned) { + roi_width = max(roi_width, static_cast(1.)); + roi_height = max(roi_height, static_cast(1.)); + } + const T bin_size_h = + static_cast(roi_height) / static_cast(pooled_height); + const T bin_size_w = + static_cast(roi_width) / static_cast(pooled_width); const int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); const int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - const T count = roi_bin_grid_h * roi_bin_grid_w; - + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); T output_val = 0.f; for (int iy = 0; iy < roi_bin_grid_h; ++iy) { const T y = roi_ymin + ph * bin_size_h + @@ -132,12 +137,13 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(const nvinfer1::DataType data_type, const int pooled_height, const int pooled_width, float spatial_scale, - int sampling_ratio) + int sampling_ratio, bool aligned) : data_type_(data_type), pooled_height_(pooled_height), pooled_width_(pooled_width), spatial_scale_(spatial_scale), - sampling_ratio_(sampling_ratio) { + sampling_ratio_(sampling_ratio), + aligned_(aligned) { bool data_type_is_valid = data_type_ == nvinfer1::DataType::kFLOAT || data_type_ == nvinfer1::DataType::kHALF; PADDLE_ENFORCE_EQ(data_type_is_valid, true, @@ -187,6 +193,7 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(void const* data, size_t length) { DeserializeValue(&data, &length, &pooled_width_); DeserializeValue(&data, &length, &spatial_scale_); DeserializeValue(&data, &length, &sampling_ratio_); + DeserializeValue(&data, &length, &aligned_); int smem_per_block = -1; int device = -1; cudaGetDevice(&device); @@ -204,7 +211,7 @@ nvinfer1::IPluginV2DynamicExt* RoiAlignPluginDynamic::clone() const TRT_NOEXCEPT { auto* plugin = new RoiAlignPluginDynamic(data_type_, pooled_height_, pooled_width_, - spatial_scale_, sampling_ratio_); + spatial_scale_, sampling_ratio_, aligned_); plugin->setPluginNamespace(namespace_.c_str()); return plugin; } @@ -272,14 +279,15 @@ int RoiAlignPluginDynamic::enqueue_impl( output_size, static_cast(inputs[0]), static_cast(inputs[1]), spatial_scale_, channels, height, width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch, - static_cast(outputs[0])); + aligned_, static_cast(outputs[0])); } else { GPUROIAlignOpt< - T, OutT, true><<>>( + T, OutT, + false><<>>( output_size, static_cast(inputs[0]), static_cast(inputs[1]), spatial_scale_, channels, height, width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch, - static_cast(outputs[0])); + aligned_, static_cast(outputs[0])); } return cudaGetLastError() != cudaSuccess; @@ -313,6 +321,10 @@ const char* RoiAlignPluginDynamic::getPluginType() const TRT_NOEXCEPT { return "roi_align_plugin_dynamic"; } +const char* RoiAlignPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { + return "2"; +} + int RoiAlignPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } int RoiAlignPluginDynamic::initialize() TRT_NOEXCEPT { return 0; } @@ -326,6 +338,7 @@ size_t RoiAlignPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { serialize_size += SerializedSize(pooled_width_); serialize_size += SerializedSize(spatial_scale_); serialize_size += SerializedSize(sampling_ratio_); + serialize_size += SerializedSize(aligned_); return serialize_size; } @@ -335,6 +348,7 @@ void RoiAlignPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT { SerializeValue(&buffer, pooled_width_); SerializeValue(&buffer, spatial_scale_); SerializeValue(&buffer, sampling_ratio_); + SerializeValue(&buffer, aligned_); } void RoiAlignPluginDynamic::destroy() TRT_NOEXCEPT {} @@ -357,7 +371,7 @@ const char* RoiAlignPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { const char* RoiAlignPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { - return "1"; + return "2"; } const nvinfer1::PluginFieldCollection* diff --git a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h index 44d2b630698..9f4723da9e1 100644 --- a/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/roi_align_op_plugin.h @@ -31,7 +31,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT { explicit RoiAlignPluginDynamic(const nvinfer1::DataType data_type, const int pooled_height, const int pooled_width, float spatial_scale, - int sampling_ratio); + int sampling_ratio, bool aligned); RoiAlignPluginDynamic(void const* data, size_t length); ~RoiAlignPluginDynamic() = default; nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; @@ -66,6 +66,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT { size_t getSerializationSize() const TRT_NOEXCEPT override; void serialize(void* buffer) const TRT_NOEXCEPT override; void destroy() TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; private: template @@ -80,6 +81,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT { float spatial_scale_; int sampling_ratio_; int smem_per_block_; + bool aligned_; std::string namespace_; }; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py index 56efdb91959..b2d754337fe 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_roi_align.py @@ -176,16 +176,6 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest): self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT, "INPUT RoisNum NOT SUPPORT") - def teller2(program_config, predictor_config): - if (program_config.ops[0].attrs['sampling_ratio'] == -1 and - program_config.ops[0].attrs['aligned'] == True): - return True - return False - - self.add_skip_case( - teller2, SkipReasons.TRT_NOT_SUPPORT, - "SAMPLING_RATIO EQUAL TO - 1 WHEN ALIGNED IS TRUE IS NOT SUPPORT") - def test(self): self.add_skip_trt_case() self.run_test() -- GitLab