未验证 提交 08dcea18 编写于 作者: W wenbin 提交者: GitHub

roi_align aligned supported (#38905)

roi_align aligned supported
上级 fc6eed5b
...@@ -51,6 +51,7 @@ class RoiAlignOpConverter : public OpConverter { ...@@ -51,6 +51,7 @@ class RoiAlignOpConverter : public OpConverter {
BOOST_GET_CONST(float, op_desc.GetAttr("spatial_scale")); BOOST_GET_CONST(float, op_desc.GetAttr("spatial_scale"));
const auto sampling_ratio = const auto sampling_ratio =
BOOST_GET_CONST(int, op_desc.GetAttr("sampling_ratio")); BOOST_GET_CONST(int, op_desc.GetAttr("sampling_ratio"));
const auto aligned = BOOST_GET_CONST(bool, op_desc.GetAttr("aligned"));
const auto input_tensor = engine_->GetITensor(input_name); const auto input_tensor = engine_->GetITensor(input_name);
const auto rois_tensor = engine_->GetITensor(rois_name); const auto rois_tensor = engine_->GetITensor(rois_name);
...@@ -63,7 +64,8 @@ class RoiAlignOpConverter : public OpConverter { ...@@ -63,7 +64,8 @@ class RoiAlignOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
auto* roi_align_plugin = new plugin::RoiAlignPluginDynamic( auto* roi_align_plugin = new plugin::RoiAlignPluginDynamic(
data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio); data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio,
aligned);
auto roi_align_layer = engine_->network()->addPluginV2( auto roi_align_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *roi_align_plugin); inputs.data(), inputs.size(), *roi_align_plugin);
layer = roi_align_layer; layer = roi_align_layer;
......
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/tensorrt/op_teller.h" #include "paddle/fluid/inference/tensorrt/op_teller.h"
#include <bitset> #include <bitset>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
...@@ -737,28 +735,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -737,28 +735,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
} }
if (op_type == "roi_align") {
if (!with_dynamic_shape) return false;
std::vector<std::string> attrs{"pooled_height", "pooled_width",
"spatial_scale", "sampling_ratio"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false;
}
const auto pooled_height =
BOOST_GET_CONST(int, desc.GetAttr("pooled_height"));
if (pooled_height <= 0) return false;
const auto pooled_width =
BOOST_GET_CONST(int, desc.GetAttr("pooled_width"));
if (pooled_width <= 0) return false;
const auto spatial_scale =
BOOST_GET_CONST(float, desc.GetAttr("spatial_scale"));
if (spatial_scale <= 0.f) return false;
}
if (op_type == "hard_swish") { if (op_type == "hard_swish") {
if (desc.Input("X").size() != 1) { if (desc.Input("X").size() != 1) {
VLOG(3) << "HardSwish op has only 1 input, but got " VLOG(3) << "HardSwish op has only 1 input, but got "
...@@ -1303,12 +1279,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1303,12 +1279,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(float, desc.GetAttr("spatial_scale")); BOOST_GET_CONST(float, desc.GetAttr("spatial_scale"));
if (spatial_scale <= 0.f) return false; if (spatial_scale <= 0.f) return false;
const auto sampling_ratio =
BOOST_GET_CONST(int, desc.GetAttr("sampling_ratio"));
const auto aligned = BOOST_GET_CONST(bool, desc.GetAttr("aligned"));
if (sampling_ratio == -1 && aligned == true) return false;
auto roi_align_inputs = desc.Inputs(); auto roi_align_inputs = desc.Inputs();
if (roi_align_inputs.find("RoisNum") != roi_align_inputs.end()) { if (roi_align_inputs.find("RoisNum") != roi_align_inputs.end()) {
if (desc.Input("RoisNum").size() >= 1) { if (desc.Input("RoisNum").size() >= 1) {
......
...@@ -58,14 +58,12 @@ __inline__ __device__ T BilinearInterpolate(const T* input_data, ...@@ -58,14 +58,12 @@ __inline__ __device__ T BilinearInterpolate(const T* input_data,
} }
template <typename T, typename OutT, bool USE_SMEM> template <typename T, typename OutT, bool USE_SMEM>
__global__ void GPUROIAlignOpt(const int nthreads, __global__ void GPUROIAlignOpt(
const T* __restrict__ input_data, const int nthreads, const T* __restrict__ input_data,
const T* __restrict__ input_rois, const T* __restrict__ input_rois, const float spatial_scale,
const float spatial_scale, const int channels, const int channels, const int height, const int width,
const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio,
const int pooled_height, const int pooled_width, const int num_rois, const bool aligned, OutT* __restrict__ output_data) {
const int sampling_ratio, const int num_rois,
OutT* __restrict__ output_data) {
const int batch = blockIdx.x; const int batch = blockIdx.x;
const int channel = blockIdx.y; const int channel = blockIdx.y;
const T* offset_input_data = const T* offset_input_data =
...@@ -84,21 +82,28 @@ __global__ void GPUROIAlignOpt(const int nthreads, ...@@ -84,21 +82,28 @@ __global__ void GPUROIAlignOpt(const int nthreads,
const int roi_idx = (idx / pooled_width / pooled_height) % num_rois; const int roi_idx = (idx / pooled_width / pooled_height) % num_rois;
const int n = batch * num_rois + roi_idx; const int n = batch * num_rois + roi_idx;
const float4 rois_offset = reinterpret_cast<const float4*>(input_rois)[n]; const float4 rois_offset = reinterpret_cast<const float4*>(input_rois)[n];
const T roi_xmin = rois_offset.x * spatial_scale; const T roi_offset = aligned ? static_cast<T>(0.5) : 0;
const T roi_ymin = rois_offset.y * spatial_scale; const T roi_xmin = rois_offset.x * spatial_scale - roi_offset;
const T roi_xmax = rois_offset.z * spatial_scale; const T roi_ymin = rois_offset.y * spatial_scale - roi_offset;
const T roi_ymax = rois_offset.w * spatial_scale; const T roi_xmax = rois_offset.z * spatial_scale - roi_offset;
const T roi_width = max(roi_xmax - roi_xmin, static_cast<T>(1.f)); const T roi_ymax = rois_offset.w * spatial_scale - roi_offset;
const T roi_height = max(roi_ymax - roi_ymin, static_cast<T>(1.f));
const T bin_size_h = roi_height / static_cast<T>(pooled_height); T roi_width = roi_xmax - roi_xmin;
const T bin_size_w = roi_width / static_cast<T>(pooled_width); T roi_height = roi_ymax - roi_ymin;
if (!aligned) {
roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
const T bin_size_h =
static_cast<T>(roi_height) / static_cast<T>(pooled_height);
const T bin_size_w =
static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const int roi_bin_grid_h = (sampling_ratio > 0) const int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio ? sampling_ratio
: ceil(roi_height / pooled_height); : ceil(roi_height / pooled_height);
const int roi_bin_grid_w = const int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w; const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
T output_val = 0.f; T output_val = 0.f;
for (int iy = 0; iy < roi_bin_grid_h; ++iy) { for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
const T y = roi_ymin + ph * bin_size_h + const T y = roi_ymin + ph * bin_size_h +
...@@ -132,12 +137,13 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(const nvinfer1::DataType data_type, ...@@ -132,12 +137,13 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(const nvinfer1::DataType data_type,
const int pooled_height, const int pooled_height,
const int pooled_width, const int pooled_width,
float spatial_scale, float spatial_scale,
int sampling_ratio) int sampling_ratio, bool aligned)
: data_type_(data_type), : data_type_(data_type),
pooled_height_(pooled_height), pooled_height_(pooled_height),
pooled_width_(pooled_width), pooled_width_(pooled_width),
spatial_scale_(spatial_scale), spatial_scale_(spatial_scale),
sampling_ratio_(sampling_ratio) { sampling_ratio_(sampling_ratio),
aligned_(aligned) {
bool data_type_is_valid = data_type_ == nvinfer1::DataType::kFLOAT || bool data_type_is_valid = data_type_ == nvinfer1::DataType::kFLOAT ||
data_type_ == nvinfer1::DataType::kHALF; data_type_ == nvinfer1::DataType::kHALF;
PADDLE_ENFORCE_EQ(data_type_is_valid, true, PADDLE_ENFORCE_EQ(data_type_is_valid, true,
...@@ -187,6 +193,7 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(void const* data, size_t length) { ...@@ -187,6 +193,7 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(void const* data, size_t length) {
DeserializeValue(&data, &length, &pooled_width_); DeserializeValue(&data, &length, &pooled_width_);
DeserializeValue(&data, &length, &spatial_scale_); DeserializeValue(&data, &length, &spatial_scale_);
DeserializeValue(&data, &length, &sampling_ratio_); DeserializeValue(&data, &length, &sampling_ratio_);
DeserializeValue(&data, &length, &aligned_);
int smem_per_block = -1; int smem_per_block = -1;
int device = -1; int device = -1;
cudaGetDevice(&device); cudaGetDevice(&device);
...@@ -204,7 +211,7 @@ nvinfer1::IPluginV2DynamicExt* RoiAlignPluginDynamic::clone() const ...@@ -204,7 +211,7 @@ nvinfer1::IPluginV2DynamicExt* RoiAlignPluginDynamic::clone() const
TRT_NOEXCEPT { TRT_NOEXCEPT {
auto* plugin = auto* plugin =
new RoiAlignPluginDynamic(data_type_, pooled_height_, pooled_width_, new RoiAlignPluginDynamic(data_type_, pooled_height_, pooled_width_,
spatial_scale_, sampling_ratio_); spatial_scale_, sampling_ratio_, aligned_);
plugin->setPluginNamespace(namespace_.c_str()); plugin->setPluginNamespace(namespace_.c_str());
return plugin; return plugin;
} }
...@@ -272,14 +279,15 @@ int RoiAlignPluginDynamic::enqueue_impl( ...@@ -272,14 +279,15 @@ int RoiAlignPluginDynamic::enqueue_impl(
output_size, static_cast<const T*>(inputs[0]), output_size, static_cast<const T*>(inputs[0]),
static_cast<const T*>(inputs[1]), spatial_scale_, channels, height, static_cast<const T*>(inputs[1]), spatial_scale_, channels, height,
width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch, width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch,
static_cast<OutT*>(outputs[0])); aligned_, static_cast<OutT*>(outputs[0]));
} else { } else {
GPUROIAlignOpt< GPUROIAlignOpt<
T, OutT, true><<<blocks, threads, width * height * sizeof(T), stream>>>( T, OutT,
false><<<blocks, threads, width * height * sizeof(T), stream>>>(
output_size, static_cast<const T*>(inputs[0]), output_size, static_cast<const T*>(inputs[0]),
static_cast<const T*>(inputs[1]), spatial_scale_, channels, height, static_cast<const T*>(inputs[1]), spatial_scale_, channels, height,
width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch, width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch,
static_cast<OutT*>(outputs[0])); aligned_, static_cast<OutT*>(outputs[0]));
} }
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
...@@ -313,6 +321,10 @@ const char* RoiAlignPluginDynamic::getPluginType() const TRT_NOEXCEPT { ...@@ -313,6 +321,10 @@ const char* RoiAlignPluginDynamic::getPluginType() const TRT_NOEXCEPT {
return "roi_align_plugin_dynamic"; return "roi_align_plugin_dynamic";
} }
const char* RoiAlignPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
return "2";
}
int RoiAlignPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } int RoiAlignPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }
int RoiAlignPluginDynamic::initialize() TRT_NOEXCEPT { return 0; } int RoiAlignPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
...@@ -326,6 +338,7 @@ size_t RoiAlignPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { ...@@ -326,6 +338,7 @@ size_t RoiAlignPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
serialize_size += SerializedSize(pooled_width_); serialize_size += SerializedSize(pooled_width_);
serialize_size += SerializedSize(spatial_scale_); serialize_size += SerializedSize(spatial_scale_);
serialize_size += SerializedSize(sampling_ratio_); serialize_size += SerializedSize(sampling_ratio_);
serialize_size += SerializedSize(aligned_);
return serialize_size; return serialize_size;
} }
...@@ -335,6 +348,7 @@ void RoiAlignPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT { ...@@ -335,6 +348,7 @@ void RoiAlignPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, pooled_width_); SerializeValue(&buffer, pooled_width_);
SerializeValue(&buffer, spatial_scale_); SerializeValue(&buffer, spatial_scale_);
SerializeValue(&buffer, sampling_ratio_); SerializeValue(&buffer, sampling_ratio_);
SerializeValue(&buffer, aligned_);
} }
void RoiAlignPluginDynamic::destroy() TRT_NOEXCEPT {} void RoiAlignPluginDynamic::destroy() TRT_NOEXCEPT {}
...@@ -357,7 +371,7 @@ const char* RoiAlignPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { ...@@ -357,7 +371,7 @@ const char* RoiAlignPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
const char* RoiAlignPluginDynamicCreator::getPluginVersion() const const char* RoiAlignPluginDynamicCreator::getPluginVersion() const
TRT_NOEXCEPT { TRT_NOEXCEPT {
return "1"; return "2";
} }
const nvinfer1::PluginFieldCollection* const nvinfer1::PluginFieldCollection*
......
...@@ -31,7 +31,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT { ...@@ -31,7 +31,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT {
explicit RoiAlignPluginDynamic(const nvinfer1::DataType data_type, explicit RoiAlignPluginDynamic(const nvinfer1::DataType data_type,
const int pooled_height, const int pooled_height,
const int pooled_width, float spatial_scale, const int pooled_width, float spatial_scale,
int sampling_ratio); int sampling_ratio, bool aligned);
RoiAlignPluginDynamic(void const* data, size_t length); RoiAlignPluginDynamic(void const* data, size_t length);
~RoiAlignPluginDynamic() = default; ~RoiAlignPluginDynamic() = default;
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
...@@ -66,6 +66,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT { ...@@ -66,6 +66,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT {
size_t getSerializationSize() const TRT_NOEXCEPT override; size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override; void serialize(void* buffer) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override; void destroy() TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
private: private:
template <typename T, typename OutT> template <typename T, typename OutT>
...@@ -80,6 +81,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT { ...@@ -80,6 +81,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT {
float spatial_scale_; float spatial_scale_;
int sampling_ratio_; int sampling_ratio_;
int smem_per_block_; int smem_per_block_;
bool aligned_;
std::string namespace_; std::string namespace_;
}; };
......
...@@ -176,16 +176,6 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest): ...@@ -176,16 +176,6 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest):
self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT, self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
"INPUT RoisNum NOT SUPPORT") "INPUT RoisNum NOT SUPPORT")
def teller2(program_config, predictor_config):
if (program_config.ops[0].attrs['sampling_ratio'] == -1 and
program_config.ops[0].attrs['aligned'] == True):
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_SUPPORT,
"SAMPLING_RATIO EQUAL TO - 1 WHEN ALIGNED IS TRUE IS NOT SUPPORT")
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册