nms.cc 1.8 KB
Newer Older
Y
Yang Zhang 已提交
1 2
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
3

4 5
namespace nms {

/// One detection, laid out exactly as it appears in the caller's flat float
/// buffer: five consecutive floats {score, x1, y1, x2, y2}.
struct proposal {
  float score, x1, y1, x2, y2;
};

// soft_nms() reinterprets a float array as an array of proposals, so the
// struct must be exactly five tightly packed floats.
static_assert(sizeof(proposal) == 5 * sizeof(float),
              "proposal must be five tightly packed floats");

/// Orders proposals by ascending score (comparator for std::max_element).
inline static bool cmp(const proposal& a, const proposal& b) {
  return a.score < b.score;
}

/// Intersection-over-union of two boxes.
///
/// Widths and heights are computed as (max - min + 1). Returns 0 when the
/// boxes do not overlap in either dimension.
static inline __attribute__((always_inline)) float iou(const proposal& a,
                                                       const proposal& b) {
  const float iw = std::min(b.x2, a.x2) - std::max(b.x1, a.x1) + 1;
  if (iw <= 0) {
    return 0.f;
  }
  const float ih = std::min(b.y2, a.y2) - std::max(b.y1, a.y1) + 1;
  if (ih <= 0) {
    return 0.f;
  }
  const float area_b = (b.x2 - b.x1 + 1) * (b.y2 - b.y1 + 1);
  const float area_a = (a.x2 - a.x1 + 1) * (a.y2 - a.y1 + 1);
  const float inter = iw * ih;
  return inter / (area_a + area_b - inter);
}

/// Score re-weighting rule applied to a box that overlaps a selected box.
enum class Method : uint32_t { LINEAR = 0, GAUSSIAN, HARD };

/// Soft non-maximum suppression, performed in place.
///
/// @param boxes     `count` records of five floats {score, x1, y1, x2, y2}.
///                  Records are reordered and their scores decayed in place.
/// @param index     Output array of `count` entries. It is first filled with
///                  0..count-1 and then permuted together with `boxes`, so
///                  index[k] is the original position of the record now
///                  stored in slot k.
/// @param count     Number of records in `boxes` / entries in `index`.
/// @param method    How the score of an overlapping box is decayed:
///                  LINEAR   -> score *= (1 - iou) when iou > Nt,
///                  GAUSSIAN -> score *= exp(-iou^2 / sigma),
///                  HARD     -> score  = 0       when iou > Nt.
///                  Any other value leaves scores unchanged.
/// @param Nt        IoU threshold used by LINEAR and HARD.
/// @param sigma     Divisor in the GAUSSIAN exponent (not checked for zero).
/// @param threshold A box whose decayed score falls below this is discarded.
/// @return Number of surviving boxes N. The survivors occupy slots [0, N) of
///         `boxes` and `index`, in the order they were selected; discarded
///         records are left in the slots from N onward.
size_t soft_nms(float* boxes, int32_t* index, size_t count, Method method,
                float Nt, float sigma, float threshold) {
  if (count == 0) {
    return 0;  // nothing to do; also avoids touching possibly-null buffers
  }

  std::iota(index, index + count, 0);  // index[k] = k, like np.arange(count)
  auto* p = reinterpret_cast<proposal*>(boxes);

  size_t N = count;  // records in [0, N) are still alive
  for (size_t i = 0; i < N; ++i) {
    // Selection step: bring the best remaining record into slot i.
    proposal* best = std::max_element(p + i, p + N, cmp);
    std::swap(index[i], index[best - p]);
    std::swap(p[i], *best);

    // Decay every remaining record according to its overlap with p[i].
    size_t j = i + 1;
    while (j < N) {
      const float ov = iou(p[i], p[j]);

      // Reset per pair so an unrecognised method cannot reuse a stale weight.
      float weight = 1.f;
      switch (method) {
        case Method::LINEAR:
          weight = ov > Nt ? 1.f - ov : 1.f;
          break;
        case Method::GAUSSIAN:
          weight = std::exp(-(ov * ov) / sigma);
          break;
        case Method::HARD:
          weight = ov > Nt ? 0.f : 1.f;
          break;
        default:
          break;
      }

      p[j].score *= weight;
      if (p[j].score < threshold) {
        // Drop record j by swapping it with the last live record. Slot j now
        // holds a record that has not been examined yet, so do not advance.
        --N;
        std::swap(p[j], p[N]);
        std::swap(index[j], index[N]);
      } else {
        ++j;
      }
    }
  }

  return N;
}

} /* namespace nms */