未验证 提交 71cb3ff8 编写于 作者: W wangxinxin08 提交者: GitHub

enhance yolobox trt plugin (#34128)

* enhance yolobox plugin
上级 7850f7ce
...@@ -48,13 +48,20 @@ class YoloBoxOpConverter : public OpConverter { ...@@ -48,13 +48,20 @@ class YoloBoxOpConverter : public OpConverter {
float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh")); float conf_thresh = BOOST_GET_CONST(float, op_desc.GetAttr("conf_thresh"));
bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox")); bool clip_bbox = BOOST_GET_CONST(bool, op_desc.GetAttr("clip_bbox"));
float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y")); float scale_x_y = BOOST_GET_CONST(float, op_desc.GetAttr("scale_x_y"));
bool iou_aware = op_desc.HasAttr("iou_aware")
? BOOST_GET_CONST(bool, op_desc.GetAttr("iou_aware"))
: false;
float iou_aware_factor =
op_desc.HasAttr("iou_aware_factor")
? BOOST_GET_CONST(float, op_desc.GetAttr("iou_aware_factor"))
: 0.5;
int type_id = static_cast<int>(engine_->WithFp16()); int type_id = static_cast<int>(engine_->WithFp16());
auto input_dim = X_tensor->getDimensions(); auto input_dim = X_tensor->getDimensions();
auto* yolo_box_plugin = new plugin::YoloBoxPlugin( auto* yolo_box_plugin = new plugin::YoloBoxPlugin(
type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, anchors, class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y,
input_dim.d[1], input_dim.d[2]); iou_aware, iou_aware_factor, input_dim.d[1], input_dim.d[2]);
std::vector<nvinfer1::ITensor*> yolo_box_inputs; std::vector<nvinfer1::ITensor*> yolo_box_inputs;
yolo_box_inputs.push_back(X_tensor); yolo_box_inputs.push_back(X_tensor);
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
...@@ -29,7 +27,8 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type, ...@@ -29,7 +27,8 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
const std::vector<int>& anchors, const std::vector<int>& anchors,
const int class_num, const float conf_thresh, const int class_num, const float conf_thresh,
const int downsample_ratio, const bool clip_bbox, const int downsample_ratio, const bool clip_bbox,
const float scale_x_y, const int input_h, const float scale_x_y, const bool iou_aware,
const float iou_aware_factor, const int input_h,
const int input_w) const int input_w)
: data_type_(data_type), : data_type_(data_type),
class_num_(class_num), class_num_(class_num),
...@@ -37,6 +36,8 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type, ...@@ -37,6 +36,8 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
downsample_ratio_(downsample_ratio), downsample_ratio_(downsample_ratio),
clip_bbox_(clip_bbox), clip_bbox_(clip_bbox),
scale_x_y_(scale_x_y), scale_x_y_(scale_x_y),
iou_aware_(iou_aware),
iou_aware_factor_(iou_aware_factor),
input_h_(input_h), input_h_(input_h),
input_w_(input_w) { input_w_(input_w) {
anchors_.insert(anchors_.end(), anchors.cbegin(), anchors.cend()); anchors_.insert(anchors_.end(), anchors.cbegin(), anchors.cend());
...@@ -45,6 +46,7 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type, ...@@ -45,6 +46,7 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
assert(class_num_ > 0); assert(class_num_ > 0);
assert(input_h_ > 0); assert(input_h_ > 0);
assert(input_w_ > 0); assert(input_w_ > 0);
assert((iou_aware_factor_ > 0 && iou_aware_factor_ < 1));
cudaMalloc(&anchors_device_, anchors.size() * sizeof(int)); cudaMalloc(&anchors_device_, anchors.size() * sizeof(int));
cudaMemcpy(anchors_device_, anchors.data(), anchors.size() * sizeof(int), cudaMemcpy(anchors_device_, anchors.data(), anchors.size() * sizeof(int),
...@@ -59,6 +61,8 @@ YoloBoxPlugin::YoloBoxPlugin(const void* data, size_t length) { ...@@ -59,6 +61,8 @@ YoloBoxPlugin::YoloBoxPlugin(const void* data, size_t length) {
DeserializeValue(&data, &length, &downsample_ratio_); DeserializeValue(&data, &length, &downsample_ratio_);
DeserializeValue(&data, &length, &clip_bbox_); DeserializeValue(&data, &length, &clip_bbox_);
DeserializeValue(&data, &length, &scale_x_y_); DeserializeValue(&data, &length, &scale_x_y_);
DeserializeValue(&data, &length, &iou_aware_);
DeserializeValue(&data, &length, &iou_aware_factor_);
DeserializeValue(&data, &length, &input_h_); DeserializeValue(&data, &length, &input_h_);
DeserializeValue(&data, &length, &input_w_); DeserializeValue(&data, &length, &input_w_);
} }
...@@ -133,8 +137,19 @@ __device__ inline void GetYoloBox(float* box, const T* x, const int* anchors, ...@@ -133,8 +137,19 @@ __device__ inline void GetYoloBox(float* box, const T* x, const int* anchors,
__device__ inline int GetEntryIndex(int batch, int an_idx, int hw_idx, __device__ inline int GetEntryIndex(int batch, int an_idx, int hw_idx,
int an_num, int an_stride, int stride, int an_num, int an_stride, int stride,
int entry) { int entry, bool iou_aware) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; if (iou_aware) {
return (batch * an_num + an_idx) * an_stride +
(batch * an_num + an_num + entry) * stride + hw_idx;
} else {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
}
__device__ inline int GetIoUIndex(int batch, int an_idx, int hw_idx, int an_num,
int an_stride, int stride) {
return batch * an_num * an_stride + (batch * an_num + an_idx) * stride +
hw_idx;
} }
template <typename T> template <typename T>
...@@ -178,7 +193,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize, ...@@ -178,7 +193,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
const int w, const int an_num, const int class_num, const int w, const int an_num, const int class_num,
const int box_num, int input_size_h, const int box_num, int input_size_h,
int input_size_w, bool clip_bbox, const float scale, int input_size_w, bool clip_bbox, const float scale,
const float bias) { const float bias, bool iou_aware,
const float iou_aware_factor) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
float box[4]; float box[4];
...@@ -193,11 +209,16 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize, ...@@ -193,11 +209,16 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
int img_height = imgsize[2 * i]; int img_height = imgsize[2 * i];
int img_width = imgsize[2 * i + 1]; int img_width = imgsize[2 * i + 1];
int obj_idx = int obj_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); iou_aware);
float conf = sigmoid(static_cast<float>(input[obj_idx])); float conf = sigmoid(static_cast<float>(input[obj_idx]));
int box_idx = if (iou_aware) {
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); int iou_idx = GetIoUIndex(i, j, k * w + l, an_num, an_stride, grid_num);
float iou = sigmoid<float>(input[iou_idx]);
conf = powf(conf, 1. - iou_aware_factor) * powf(iou, iou_aware_factor);
}
int box_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0,
iou_aware);
if (conf < conf_thresh) { if (conf < conf_thresh) {
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
...@@ -212,8 +233,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize, ...@@ -212,8 +233,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
box_idx = (i * box_num + j * grid_num + k * w + l) * 4; box_idx = (i * box_num + j * grid_num + k * w + l) * 4;
CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox); CalcDetectionBox<T>(boxes, box, box_idx, img_height, img_width, clip_bbox);
int label_idx = int label_idx = GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num,
GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); 5, iou_aware);
int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num;
CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf, CalcLabelScore<T>(scores, input, label_idx, score_idx, class_num, conf,
grid_num); grid_num);
...@@ -240,7 +261,8 @@ int YoloBoxPlugin::enqueue_impl(int batch_size, const void* const* inputs, ...@@ -240,7 +261,8 @@ int YoloBoxPlugin::enqueue_impl(int batch_size, const void* const* inputs,
reinterpret_cast<const int* const>(inputs[1]), reinterpret_cast<const int* const>(inputs[1]),
reinterpret_cast<T*>(outputs[0]), reinterpret_cast<T*>(outputs[1]), reinterpret_cast<T*>(outputs[0]), reinterpret_cast<T*>(outputs[1]),
conf_thresh_, anchors_device_, n, h, w, an_num, class_num_, box_num, conf_thresh_, anchors_device_, n, h, w, an_num, class_num_, box_num,
input_size_h, input_size_w, clip_bbox_, scale_x_y_, bias); input_size_h, input_size_w, clip_bbox_, scale_x_y_, bias, iou_aware_,
iou_aware_factor_);
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
...@@ -274,6 +296,8 @@ size_t YoloBoxPlugin::getSerializationSize() const TRT_NOEXCEPT { ...@@ -274,6 +296,8 @@ size_t YoloBoxPlugin::getSerializationSize() const TRT_NOEXCEPT {
serialize_size += SerializedSize(scale_x_y_); serialize_size += SerializedSize(scale_x_y_);
serialize_size += SerializedSize(input_h_); serialize_size += SerializedSize(input_h_);
serialize_size += SerializedSize(input_w_); serialize_size += SerializedSize(input_w_);
serialize_size += SerializedSize(iou_aware_);
serialize_size += SerializedSize(iou_aware_factor_);
return serialize_size; return serialize_size;
} }
...@@ -285,6 +309,8 @@ void YoloBoxPlugin::serialize(void* buffer) const TRT_NOEXCEPT { ...@@ -285,6 +309,8 @@ void YoloBoxPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, downsample_ratio_); SerializeValue(&buffer, downsample_ratio_);
SerializeValue(&buffer, clip_bbox_); SerializeValue(&buffer, clip_bbox_);
SerializeValue(&buffer, scale_x_y_); SerializeValue(&buffer, scale_x_y_);
SerializeValue(&buffer, iou_aware_);
SerializeValue(&buffer, iou_aware_factor_);
SerializeValue(&buffer, input_h_); SerializeValue(&buffer, input_h_);
SerializeValue(&buffer, input_w_); SerializeValue(&buffer, input_w_);
} }
...@@ -326,8 +352,8 @@ void YoloBoxPlugin::configurePlugin( ...@@ -326,8 +352,8 @@ void YoloBoxPlugin::configurePlugin(
nvinfer1::IPluginV2Ext* YoloBoxPlugin::clone() const TRT_NOEXCEPT { nvinfer1::IPluginV2Ext* YoloBoxPlugin::clone() const TRT_NOEXCEPT {
return new YoloBoxPlugin(data_type_, anchors_, class_num_, conf_thresh_, return new YoloBoxPlugin(data_type_, anchors_, class_num_, conf_thresh_,
downsample_ratio_, clip_bbox_, scale_x_y_, input_h_, downsample_ratio_, clip_bbox_, scale_x_y_,
input_w_); iou_aware_, iou_aware_factor_, input_h_, input_w_);
} }
YoloBoxPluginCreator::YoloBoxPluginCreator() {} YoloBoxPluginCreator::YoloBoxPluginCreator() {}
...@@ -367,6 +393,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin( ...@@ -367,6 +393,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
float scale_x_y = 1.; float scale_x_y = 1.;
int h = -1; int h = -1;
int w = -1; int w = -1;
bool iou_aware = false;
float iou_aware_factor = 0.5;
for (int i = 0; i < fc->nbFields; ++i) { for (int i = 0; i < fc->nbFields; ++i) {
const std::string field_name(fc->fields[i].name); const std::string field_name(fc->fields[i].name);
...@@ -386,6 +414,10 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin( ...@@ -386,6 +414,10 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
clip_bbox = *static_cast<const bool*>(fc->fields[i].data); clip_bbox = *static_cast<const bool*>(fc->fields[i].data);
} else if (field_name.compare("scale_x_y")) { } else if (field_name.compare("scale_x_y")) {
scale_x_y = *static_cast<const float*>(fc->fields[i].data); scale_x_y = *static_cast<const float*>(fc->fields[i].data);
} else if (field_name.compare("iou_aware")) {
iou_aware = *static_cast<const bool*>(fc->fields[i].data);
} else if (field_name.compare("iou_aware_factor")) {
iou_aware_factor = *static_cast<const float*>(fc->fields[i].data);
} else if (field_name.compare("h")) { } else if (field_name.compare("h")) {
h = *static_cast<const int*>(fc->fields[i].data); h = *static_cast<const int*>(fc->fields[i].data);
} else if (field_name.compare("w")) { } else if (field_name.compare("w")) {
...@@ -397,7 +429,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin( ...@@ -397,7 +429,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
return new YoloBoxPlugin( return new YoloBoxPlugin(
type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, anchors, type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, anchors,
class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, h, w); class_num, conf_thresh, downsample_ratio, clip_bbox, scale_x_y, iou_aware,
iou_aware_factor, h, w);
} }
nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::deserializePlugin( nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::deserializePlugin(
......
...@@ -31,6 +31,7 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext { ...@@ -31,6 +31,7 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
const std::vector<int>& anchors, const int class_num, const std::vector<int>& anchors, const int class_num,
const float conf_thresh, const int downsample_ratio, const float conf_thresh, const int downsample_ratio,
const bool clip_bbox, const float scale_x_y, const bool clip_bbox, const float scale_x_y,
const bool iou_aware, const float iou_aware_factor,
const int input_h, const int input_w); const int input_h, const int input_w);
YoloBoxPlugin(const void* data, size_t length); YoloBoxPlugin(const void* data, size_t length);
~YoloBoxPlugin() override; ~YoloBoxPlugin() override;
...@@ -89,6 +90,8 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext { ...@@ -89,6 +90,8 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
float scale_x_y_; float scale_x_y_;
int input_h_; int input_h_;
int input_w_; int input_w_;
bool iou_aware_;
float iou_aware_factor_;
std::string namespace_; std::string namespace_;
}; };
......
...@@ -116,5 +116,56 @@ class TRTYoloBoxFP16Test(InferencePassTest): ...@@ -116,5 +116,56 @@ class TRTYoloBoxFP16Test(InferencePassTest):
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TRTYoloBoxIoUAwareTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
image_shape = [self.bs, self.channel, self.height, self.width]
image = fluid.data(name='image', shape=image_shape, dtype='float32')
image_size = fluid.data(
name='image_size', shape=[self.bs, 2], dtype='int32')
boxes, scores = self.append_yolobox(image, image_size)
self.feeds = {
'image': np.random.random(image_shape).astype('float32'),
'image_size': np.random.randint(
32, 64, size=(self.bs, 2)).astype('int32'),
}
self.enable_trt = True
self.trt_parameters = TRTYoloBoxTest.TensorRTParam(
1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [scores, boxes]
def set_params(self):
self.bs = 4
self.channel = 258
self.height = 64
self.width = 64
self.class_num = 80
self.anchors = [10, 13, 16, 30, 33, 23]
self.conf_thresh = .1
self.downsample_ratio = 32
self.iou_aware = True
self.iou_aware_factor = 0.5
def append_yolobox(self, image, image_size):
return fluid.layers.yolo_box(
x=image,
img_size=image_size,
class_num=self.class_num,
anchors=self.anchors,
conf_thresh=self.conf_thresh,
downsample_ratio=self.downsample_ratio,
iou_aware=self.iou_aware,
iou_aware_factor=self.iou_aware_factor)
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, flatten=True)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册