未验证 提交 08dcea18 编写于 作者: W wenbin 提交者: GitHub

roi_align aligned supported (#38905)

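Add support for the `aligned` attribute of roi_align in the TensorRT converter and plugin, and drop the op-teller and unit-test restriction that rejected `aligned == true` with `sampling_ratio == -1`.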
上级 fc6eed5b
......@@ -51,6 +51,7 @@ class RoiAlignOpConverter : public OpConverter {
BOOST_GET_CONST(float, op_desc.GetAttr("spatial_scale"));
const auto sampling_ratio =
BOOST_GET_CONST(int, op_desc.GetAttr("sampling_ratio"));
const auto aligned = BOOST_GET_CONST(bool, op_desc.GetAttr("aligned"));
const auto input_tensor = engine_->GetITensor(input_name);
const auto rois_tensor = engine_->GetITensor(rois_name);
......@@ -63,7 +64,8 @@ class RoiAlignOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
auto* roi_align_plugin = new plugin::RoiAlignPluginDynamic(
data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio);
data_type_, pooled_height, pooled_width, spatial_scale, sampling_ratio,
aligned);
auto roi_align_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *roi_align_plugin);
layer = roi_align_layer;
......
......@@ -13,9 +13,7 @@
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include <bitset>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/data_layout.h"
......@@ -737,28 +735,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if (op_type == "roi_align") {
if (!with_dynamic_shape) return false;
std::vector<std::string> attrs{"pooled_height", "pooled_width",
"spatial_scale", "sampling_ratio"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false;
}
const auto pooled_height =
BOOST_GET_CONST(int, desc.GetAttr("pooled_height"));
if (pooled_height <= 0) return false;
const auto pooled_width =
BOOST_GET_CONST(int, desc.GetAttr("pooled_width"));
if (pooled_width <= 0) return false;
const auto spatial_scale =
BOOST_GET_CONST(float, desc.GetAttr("spatial_scale"));
if (spatial_scale <= 0.f) return false;
}
if (op_type == "hard_swish") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "HardSwish op has only 1 input, but got "
......@@ -1303,12 +1279,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(float, desc.GetAttr("spatial_scale"));
if (spatial_scale <= 0.f) return false;
const auto sampling_ratio =
BOOST_GET_CONST(int, desc.GetAttr("sampling_ratio"));
const auto aligned = BOOST_GET_CONST(bool, desc.GetAttr("aligned"));
if (sampling_ratio == -1 && aligned == true) return false;
auto roi_align_inputs = desc.Inputs();
if (roi_align_inputs.find("RoisNum") != roi_align_inputs.end()) {
if (desc.Input("RoisNum").size() >= 1) {
......
......@@ -58,14 +58,12 @@ __inline__ __device__ T BilinearInterpolate(const T* input_data,
}
template <typename T, typename OutT, bool USE_SMEM>
__global__ void GPUROIAlignOpt(const int nthreads,
const T* __restrict__ input_data,
const T* __restrict__ input_rois,
const float spatial_scale, const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const int sampling_ratio, const int num_rois,
OutT* __restrict__ output_data) {
__global__ void GPUROIAlignOpt(
const int nthreads, const T* __restrict__ input_data,
const T* __restrict__ input_rois, const float spatial_scale,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, const int sampling_ratio,
const int num_rois, const bool aligned, OutT* __restrict__ output_data) {
const int batch = blockIdx.x;
const int channel = blockIdx.y;
const T* offset_input_data =
......@@ -84,21 +82,28 @@ __global__ void GPUROIAlignOpt(const int nthreads,
const int roi_idx = (idx / pooled_width / pooled_height) % num_rois;
const int n = batch * num_rois + roi_idx;
const float4 rois_offset = reinterpret_cast<const float4*>(input_rois)[n];
const T roi_xmin = rois_offset.x * spatial_scale;
const T roi_ymin = rois_offset.y * spatial_scale;
const T roi_xmax = rois_offset.z * spatial_scale;
const T roi_ymax = rois_offset.w * spatial_scale;
const T roi_width = max(roi_xmax - roi_xmin, static_cast<T>(1.f));
const T roi_height = max(roi_ymax - roi_ymin, static_cast<T>(1.f));
const T bin_size_h = roi_height / static_cast<T>(pooled_height);
const T bin_size_w = roi_width / static_cast<T>(pooled_width);
const T roi_offset = aligned ? static_cast<T>(0.5) : 0;
const T roi_xmin = rois_offset.x * spatial_scale - roi_offset;
const T roi_ymin = rois_offset.y * spatial_scale - roi_offset;
const T roi_xmax = rois_offset.z * spatial_scale - roi_offset;
const T roi_ymax = rois_offset.w * spatial_scale - roi_offset;
T roi_width = roi_xmax - roi_xmin;
T roi_height = roi_ymax - roi_ymin;
if (!aligned) {
roi_width = max(roi_width, static_cast<T>(1.));
roi_height = max(roi_height, static_cast<T>(1.));
}
const T bin_size_h =
static_cast<T>(roi_height) / static_cast<T>(pooled_height);
const T bin_size_w =
static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
const int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
T output_val = 0.f;
for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
const T y = roi_ymin + ph * bin_size_h +
......@@ -132,12 +137,13 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(const nvinfer1::DataType data_type,
const int pooled_height,
const int pooled_width,
float spatial_scale,
int sampling_ratio)
int sampling_ratio, bool aligned)
: data_type_(data_type),
pooled_height_(pooled_height),
pooled_width_(pooled_width),
spatial_scale_(spatial_scale),
sampling_ratio_(sampling_ratio) {
sampling_ratio_(sampling_ratio),
aligned_(aligned) {
bool data_type_is_valid = data_type_ == nvinfer1::DataType::kFLOAT ||
data_type_ == nvinfer1::DataType::kHALF;
PADDLE_ENFORCE_EQ(data_type_is_valid, true,
......@@ -187,6 +193,7 @@ RoiAlignPluginDynamic::RoiAlignPluginDynamic(void const* data, size_t length) {
DeserializeValue(&data, &length, &pooled_width_);
DeserializeValue(&data, &length, &spatial_scale_);
DeserializeValue(&data, &length, &sampling_ratio_);
DeserializeValue(&data, &length, &aligned_);
int smem_per_block = -1;
int device = -1;
cudaGetDevice(&device);
......@@ -204,7 +211,7 @@ nvinfer1::IPluginV2DynamicExt* RoiAlignPluginDynamic::clone() const
TRT_NOEXCEPT {
auto* plugin =
new RoiAlignPluginDynamic(data_type_, pooled_height_, pooled_width_,
spatial_scale_, sampling_ratio_);
spatial_scale_, sampling_ratio_, aligned_);
plugin->setPluginNamespace(namespace_.c_str());
return plugin;
}
......@@ -272,14 +279,15 @@ int RoiAlignPluginDynamic::enqueue_impl(
output_size, static_cast<const T*>(inputs[0]),
static_cast<const T*>(inputs[1]), spatial_scale_, channels, height,
width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch,
static_cast<OutT*>(outputs[0]));
aligned_, static_cast<OutT*>(outputs[0]));
} else {
GPUROIAlignOpt<
T, OutT, true><<<blocks, threads, width * height * sizeof(T), stream>>>(
T, OutT,
false><<<blocks, threads, width * height * sizeof(T), stream>>>(
output_size, static_cast<const T*>(inputs[0]),
static_cast<const T*>(inputs[1]), spatial_scale_, channels, height,
width, pooled_height_, pooled_width_, sampling_ratio_, rois_num / batch,
static_cast<OutT*>(outputs[0]));
aligned_, static_cast<OutT*>(outputs[0]));
}
return cudaGetLastError() != cudaSuccess;
......@@ -313,6 +321,10 @@ const char* RoiAlignPluginDynamic::getPluginType() const TRT_NOEXCEPT {
return "roi_align_plugin_dynamic";
}
// Returns the version string reported by this plugin instance ("2").
// NOTE(review): presumably this has to stay in sync with the string returned
// by RoiAlignPluginDynamicCreator::getPluginVersion() so that a serialized
// engine resolves to the matching creator — confirm against the TensorRT
// plugin registry rules.
const char* RoiAlignPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
return "2";
}
// Number of output tensors this plugin reports: always one.
int RoiAlignPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }
// Performs no setup work and unconditionally returns 0.
// NOTE(review): 0 is presumably the "success" status expected by TensorRT —
// confirm against the IPluginV2::initialize contract.
int RoiAlignPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
......@@ -326,6 +338,7 @@ size_t RoiAlignPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
serialize_size += SerializedSize(pooled_width_);
serialize_size += SerializedSize(spatial_scale_);
serialize_size += SerializedSize(sampling_ratio_);
serialize_size += SerializedSize(aligned_);
return serialize_size;
}
......@@ -335,6 +348,7 @@ void RoiAlignPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, pooled_width_);
SerializeValue(&buffer, spatial_scale_);
SerializeValue(&buffer, sampling_ratio_);
SerializeValue(&buffer, aligned_);
}
// Empty override: nothing is released here.
// NOTE(review): plugin instances are created with `new` in clone(); since
// this method does not delete the object, ownership must be handled
// elsewhere — verify that the instance is actually freed.
void RoiAlignPluginDynamic::destroy() TRT_NOEXCEPT {}
......@@ -357,7 +371,7 @@ const char* RoiAlignPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
const char* RoiAlignPluginDynamicCreator::getPluginVersion() const
TRT_NOEXCEPT {
return "1";
return "2";
}
const nvinfer1::PluginFieldCollection*
......
......@@ -31,7 +31,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT {
explicit RoiAlignPluginDynamic(const nvinfer1::DataType data_type,
const int pooled_height,
const int pooled_width, float spatial_scale,
int sampling_ratio);
int sampling_ratio, bool aligned);
RoiAlignPluginDynamic(void const* data, size_t length);
~RoiAlignPluginDynamic() = default;
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
......@@ -66,6 +66,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT {
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;
const char* getPluginVersion() const TRT_NOEXCEPT override;
private:
template <typename T, typename OutT>
......@@ -80,6 +81,7 @@ class RoiAlignPluginDynamic : public DynamicPluginTensorRT {
float spatial_scale_;
int sampling_ratio_;
int smem_per_block_;
bool aligned_;
std::string namespace_;
};
......
......@@ -176,16 +176,6 @@ class TrtConvertRoiAlignTest(TrtLayerAutoScanTest):
self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
"INPUT RoisNum NOT SUPPORT")
def teller2(program_config, predictor_config):
if (program_config.ops[0].attrs['sampling_ratio'] == -1 and
program_config.ops[0].attrs['aligned'] == True):
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_SUPPORT,
"SAMPLING_RATIO EQUAL TO - 1 WHEN ALIGNED IS TRUE IS NOT SUPPORT")
def test(self):
self.add_skip_trt_case()
self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册