// binding.cc - pybind11 bindings for Soft-NMS (committed by Yang Zhang).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>

namespace nms {

/// Suppression strategy passed to soft_nms().
///
/// The numeric values are fixed (LINEAR = 0, GAUSSIAN = 1, HARD = 2); the
/// binding module in this file exposes them to Python as `NMSMethod`.
/// NOTE(review): presumably these are the linear-decay, Gaussian-decay and
/// classic hard-suppression variants of Soft-NMS - confirm against the
/// implementation of soft_nms().
enum class Method : uint32_t { LINEAR = 0, GAUSSIAN, HARD };

/// Runs Soft-NMS over `count` boxes (only the declaration appears here; the
/// definition lives elsewhere).
///
/// @param boxes     Row-major buffer of `count` rows x 5 floats. The binding
///                  returns this buffer to Python afterwards, so it is
///                  expected to be rewritten in place.
///                  NOTE(review): column layout (presumably x1, y1, x2, y2,
///                  score) is defined by the implementation - confirm there.
/// @param index     Output buffer with room for `count` int32 entries.
/// @param count     Number of boxes (rows) in `boxes`.
/// @param method    Score-decay strategy, see Method.
/// @param Nt        Presumably the IoU overlap threshold - confirm.
/// @param sigma     Presumably the Gaussian decay width - confirm.
/// @param threshold Presumably the minimum score for a box to survive - confirm.
/// @return Number of boxes kept. The binding treats the first N rows of
///         `boxes` and the first N entries of `index` as the result.
size_t soft_nms(float* boxes, int32_t* index, size_t count, Method method,
                float Nt, float sigma, float threshold);

}  // namespace nms

14
namespace binding {
15 16 17 18 19
namespace py = pybind11;
using namespace pybind11::literals;

py::tuple py_soft_nms(py::array_t<float, py::array::c_style> boxes,
                      nms::Method method = nms::Method::GAUSSIAN,
20 21
                      float Nt = 0.3, float sigma = 0.5,
                      float threshold = 0.001) {
22 23 24 25 26 27 28 29 30 31
  assert(boxes.ndim() == 2 && "Input should be 2-D NumPy array");
  assert(boxes.shape()[1] == 5 && "Input should have size [N,5]");

  auto count = boxes.size() / 5;
  auto i = new int32_t[count];
  auto b = new float[boxes.size()];
  std::copy(boxes.data(), boxes.data() + boxes.size(), b);

  auto N = nms::soft_nms(b, i, count, method, Nt, sigma, threshold);

32 33
  std::vector<size_t> shape5 = {N, 5};
  std::vector<size_t> shape1 = {N};
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
  std::vector<ssize_t> strides5 = {sizeof(float) * 5, sizeof(float)};
  std::vector<ssize_t> strides1 = {sizeof(float)};

  auto cap_b =
      py::capsule(b, [](void* v) { delete[] reinterpret_cast<float*>(v); });
  auto cap_i =
      py::capsule(i, [](void* v) { delete[] reinterpret_cast<int32_t*>(v); });

  auto pyb = py::array(py::dtype("float32"), shape5, strides5, b, cap_b);
  auto pyi = py::array(py::dtype("int32"), shape1, strides1, i, cap_i);
  return py::make_tuple(pyb, pyi);
}

PYBIND11_MODULE(nms, m) {
  m.doc() = "SoftNMS for object detection.";

  py::enum_<nms::Method>(m, "NMSMethod")
51 52 53 54
      .value("LINEAR", nms::Method::LINEAR)
      .value("GAUSSIAN", nms::Method::GAUSSIAN)
      .value("HARD", nms::Method::HARD)
      .export_values();
55
  m.def("soft_nms", &py_soft_nms, "boxes"_a.noconvert(),
56 57
        "method"_a = nms::Method::GAUSSIAN, "Nt"_a = 0.3, "sigma"_a = 0.5,
        "threshold"_a = 0.001);
58 59
}
} /* namespace binding */