提交 c67e75b6 编写于 作者: D Dmitry Kurtaev

Refactor NMS procedure at RegionLayer

上级 047ad4ff
......@@ -482,7 +482,7 @@ namespace cv {
}
else if (layer_type == "region")
{
float thresh = 0.001; // in the original Darknet is equal to the detection threshold set by the user
float thresh = getParam<float>(layer_params, "thresh", 0.001);
int coords = getParam<int>(layer_params, "coords", 4);
int classes = getParam<int>(layer_params, "classes", -1);
int num_of_anchors = getParam<int>(layer_params, "num", -1);
......
......@@ -43,7 +43,7 @@
#include "../precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include <opencv2/dnn/all_layers.hpp>
#include <iostream>
#include "nms.inl.hpp"
#include "opencl_kernels_dnn.hpp"
namespace cv
......@@ -173,8 +173,7 @@ public:
if (nmsThreshold > 0) {
Mat mat = outBlob.getMat(ACCESS_WRITE);
float *dstData = mat.ptr<float>();
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
}
}
......@@ -263,128 +262,48 @@ public:
}
if (nmsThreshold > 0) {
do_nms_sort(dstData, rows*cols*anchors, nmsThreshold);
//do_nms(dstData, rows*cols*anchors, nmsThreshold);
do_nms_sort(dstData, rows*cols*anchors, thresh, nmsThreshold);
}
}
}
struct box {
float x, y, w, h;
float *probs;
};
float overlap(float x1, float w1, float x2, float w2)
{
float l1 = x1 - w1 / 2;
float l2 = x2 - w2 / 2;
float left = l1 > l2 ? l1 : l2;
float r1 = x1 + w1 / 2;
float r2 = x2 + w2 / 2;
float right = r1 < r2 ? r1 : r2;
return right - left;
}
float box_intersection(box a, box b)
{
float w = overlap(a.x, a.w, b.x, b.w);
float h = overlap(a.y, a.h, b.y, b.h);
if (w < 0 || h < 0) return 0;
float area = w*h;
return area;
}
float box_union(box a, box b)
static inline float rectOverlap(const Rect2f& a, const Rect2f& b)
{
float i = box_intersection(a, b);
float u = a.w*a.h + b.w*b.h - i;
return u;
return 1.0f - jaccardDistance(a, b);
}
float box_iou(box a, box b)
void do_nms_sort(float *detections, int total, float score_thresh, float nms_thresh)
{
return box_intersection(a, b) / box_union(a, b);
}
struct sortable_bbox {
int index;
float *probs;
};
struct nms_comparator {
int k;
nms_comparator(int _k) : k(_k) {}
bool operator ()(sortable_bbox v1, sortable_bbox v2) {
return v2.probs[k] < v1.probs[k];
}
};
void do_nms_sort(float *detections, int total, float nms_thresh)
{
std::vector<box> boxes(total);
for (int i = 0; i < total; ++i) {
box &b = boxes[i];
int box_index = i * (classes + coords + 1);
b.x = detections[box_index + 0];
b.y = detections[box_index + 1];
b.w = detections[box_index + 2];
b.h = detections[box_index + 3];
int class_index = i * (classes + 5) + 5;
b.probs = (detections + class_index);
}
std::vector<sortable_bbox> s(total);
for (int i = 0; i < total; ++i) {
s[i].index = i;
int class_index = i * (classes + 5) + 5;
s[i].probs = (detections + class_index);
}
std::vector<Rect2f> boxes(total);
std::vector<float> scores(total);
for (int k = 0; k < classes; ++k) {
std::stable_sort(s.begin(), s.end(), nms_comparator(k));
for (int i = 0; i < total; ++i) {
if (boxes[s[i].index].probs[k] == 0) continue;
box a = boxes[s[i].index];
for (int j = i + 1; j < total; ++j) {
box b = boxes[s[j].index];
if (box_iou(a, b) > nms_thresh) {
boxes[s[j].index].probs[k] = 0;
}
}
}
}
}
void do_nms(float *detections, int total, float nms_thresh)
{
std::vector<box> boxes(total);
for (int i = 0; i < total; ++i) {
box &b = boxes[i];
for (int i = 0; i < total; ++i)
{
Rect2f &b = boxes[i];
int box_index = i * (classes + coords + 1);
b.x = detections[box_index + 0];
b.y = detections[box_index + 1];
b.w = detections[box_index + 2];
b.h = detections[box_index + 3];
int class_index = i * (classes + 5) + 5;
b.probs = (detections + class_index);
b.width = detections[box_index + 2];
b.height = detections[box_index + 3];
b.x = detections[box_index + 0] - b.width / 2;
b.y = detections[box_index + 1] - b.height / 2;
}
for (int i = 0; i < total; ++i) {
bool any = false;
for (int k = 0; k < classes; ++k) any = any || (boxes[i].probs[k] > 0);
if (!any) {
continue;
std::vector<int> indices;
for (int k = 0; k < classes; ++k)
{
for (int i = 0; i < total; ++i)
{
int box_index = i * (classes + coords + 1);
int class_index = box_index + 5;
scores[i] = detections[class_index + k];
detections[class_index + k] = 0;
}
for (int j = i + 1; j < total; ++j) {
if (box_iou(boxes[i], boxes[j]) > nms_thresh) {
for (int k = 0; k < classes; ++k) {
if (boxes[i].probs[k] < boxes[j].probs[k]) boxes[i].probs[k] = 0;
else boxes[j].probs[k] = 0;
}
}
NMSFast_(boxes, scores, score_thresh, nms_thresh, 1, 0, indices, rectOverlap);
for (int i = 0, n = indices.size(); i < n; ++i)
{
int box_index = indices[i] * (classes + coords + 1);
int class_index = box_index + 5;
detections[class_index + k] = scores[indices[i]];
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册