提交 d077c569 编写于 作者: Y Yang Zhang

Add native implementation for soft nms

and refactor a bit
上级 17f7cbfa
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
set(CMAKE_CXX_STANDARD 11)
set(PYBIND11_CPP_STANDARD -std=c++11)
include(FetchContent)
FetchContent_Declare(
pybind11
GIT_REPOSITORY "https://github.com/pybind/pybind11.git")
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
FetchContent_Populate(pybind11)
add_subdirectory(${pybind11_SOURCE_DIR})
endif()
set(nms_SOURCES src/binding.cc src/nms.cc)
add_library(nms MODULE ${nms_SOURCES})
set_target_properties(nms PROPERTIES
CXX_VISIBILITY_PRESET "hidden"
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
OUTPUT_NAME "nms"
)
target_compile_options(nms BEFORE PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-g -O2 -fopenmp -mavx2>)
target_link_libraries(nms PRIVATE -fopenmp)
target_link_libraries(nms PRIVATE pybind11::module)
execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "import sysconfig; print(sysconfig.get_config_vars()['INCLUDEPY'])" OUTPUT_VARIABLE PYTHON_INC_PATHS OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "import sysconfig; print(sysconfig.get_config_vars()['BLDLIBRARY'])" OUTPUT_VARIABLE PYTHON_LIB_FLAGS OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${PYTHON_EXECUTABLE} "-c" "import sysconfig; print(sysconfig.get_config_vars().get('EXT_SUFFIX', 'so'))" OUTPUT_VARIABLE EXT_SUFFIX OUTPUT_STRIP_TRAILING_WHITESPACE)
install(TARGETS nms DESTINATION ${CMAKE_HOME_DIRECTORY}/ppdet/modeling)
......@@ -266,64 +266,82 @@ class MultiClassNMS(object):
@register
@serializable
class MultiClassSoftNMS(object):
def __init__(
self,
score_threshold=0.01,
keep_top_k=300,
softnms_sigma=0.5,
normalized=False,
background_label=0, ):
def __init__(self,
score_threshold=0.01,
keep_top_k=300,
softnms_sigma=0.5,
normalized=False,
background_label=0,
Nt=0.3,
method='gaussian',
use_cpp=False):
super(MultiClassSoftNMS, self).__init__()
self.score_threshold = score_threshold
self.keep_top_k = keep_top_k
self.softnms_sigma = softnms_sigma
self.normalized = normalized
self.background_label = background_label
self.Nt = Nt
methods = ['gaussian', 'linear', 'hard']
assert method in methods, "supported methods are " + str(methods)
self.method = method
self.use_cpp = use_cpp
if not use_cpp:
return
try:
from ppdet.modeling.nms import NMSMethod, soft_nms
self.method = getattr(NMSMethod, method.upper())
self.soft_nms_native = soft_nms
except ImportError:
print("nms native extension not found, please build it with"
+ "`cmake -B build . ; cmake --build build -t install`")
def per_class_soft_nms(self, dets):
"""soft_nms_for_cls"""
if self.use_cpp:
keep, _ = self.soft_nms_native(dets, self.method, Nt=self.Nt,
sigma=self.softnms_sigma,
threshold=self.score_threshold)
return keep
dets_final = []
while len(dets) > 0:
maxpos = np.argmax(dets[:, 0])
dets_final.append(dets[maxpos].copy())
ts, tx1, ty1, tx2, ty2 = dets[maxpos]
scores = dets[:, 0]
# force remove bbox at maxpos
scores[maxpos] = -1
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
eta = 0 if self.normalized else 1
areas = (x2 - x1 + eta) * (y2 - y1 + eta)
xx1 = np.maximum(tx1, x1)
yy1 = np.maximum(ty1, y1)
xx2 = np.minimum(tx2, x2)
yy2 = np.minimum(ty2, y2)
w = np.maximum(0.0, xx2 - xx1 + eta)
h = np.maximum(0.0, yy2 - yy1 + eta)
inter = w * h
ovr = inter / (areas + areas[maxpos] - inter)
weight = np.exp(-(ovr * ovr) / self.softnms_sigma)
scores = scores * weight
idx_keep = np.where(scores >= self.score_threshold)
dets[:, 0] = scores
dets = dets[idx_keep]
dets_final = np.array(dets_final).reshape(-1, 5)
return dets_final
def __call__(self, bboxes, scores):
def create_tmp_var(program, name, dtype, shape, lod_level):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
def _soft_nms_for_cls(dets, sigma, thres):
"""soft_nms_for_cls"""
dets_final = []
while len(dets) > 0:
maxpos = np.argmax(dets[:, 0])
dets_final.append(dets[maxpos].copy())
ts, tx1, ty1, tx2, ty2 = dets[maxpos]
scores = dets[:, 0]
# force remove bbox at maxpos
scores[maxpos] = -1
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
eta = 0 if self.normalized else 1
areas = (x2 - x1 + eta) * (y2 - y1 + eta)
xx1 = np.maximum(tx1, x1)
yy1 = np.maximum(ty1, y1)
xx2 = np.minimum(tx2, x2)
yy2 = np.minimum(ty2, y2)
w = np.maximum(0.0, xx2 - xx1 + eta)
h = np.maximum(0.0, yy2 - yy1 + eta)
inter = w * h
ovr = inter / (areas + areas[maxpos] - inter)
weight = np.exp(-(ovr * ovr) / sigma)
scores = scores * weight
idx_keep = np.where(scores >= thres)
dets[:, 0] = scores
dets = dets[idx_keep]
dets_final = np.array(dets_final).reshape(-1, 5)
return dets_final
def _soft_nms(bboxes, scores):
bboxes = np.array(bboxes)
scores = np.array(scores)
class_nums = scores.shape[-1]
softnms_thres = self.score_threshold
softnms_sigma = self.softnms_sigma
keep_top_k = self.keep_top_k
cls_boxes = [[] for _ in range(class_nums)]
......@@ -339,10 +357,9 @@ class MultiClassSoftNMS(object):
cls_rank = np.argsort(-dets_j[:, 0])
dets_j = dets_j[cls_rank]
cls_boxes[j] = _soft_nms_for_cls(
dets_j, sigma=softnms_sigma, thres=softnms_thres)
cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
1)
cls_boxes[j] = self.per_class_soft_nms(dets_j)
cls_ids[j] = np.array(
[j] * cls_boxes[j].shape[0]).reshape(-1, 1)
cls_boxes = np.vstack(cls_boxes[start_idx:])
cls_ids = np.vstack(cls_ids[start_idx:])
......@@ -360,17 +377,13 @@ class MultiClassSoftNMS(object):
if pred_result.shape[0] == 0:
pred_result = np.array([[1]], dtype=np.float32)
res.set(pred_result, fluid.CPUPlace())
return res
pred_result = create_tmp_var(
fluid.default_main_program(),
name='softnms_pred_result',
dtype='float32',
shape=[6],
lod_level=1)
fluid.layers.py_func(
func=_soft_nms, x=[bboxes, scores], out=pred_result)
pred_result = fluid.default_main_program().current_block().create_var(
name='softnms_pred_result', dtype='float32', shape=[6],
lod_leval=1)
fluid.layers.py_func(func=_soft_nms,
x=[bboxes, scores], out=pred_result)
return pred_result
......
#include <bits/stdc++.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
namespace nms
{
enum class Method : uint32_t
{
LINEAR = 0,
GAUSSIAN,
HARD
};
size_t soft_nms(float* boxes,
int32_t* index,
size_t count,
Method method,
float Nt,
float sigma,
float threshold);
} // namespace nms
namespace binding
{
namespace py = pybind11;
using namespace pybind11::literals;
py::tuple py_soft_nms(py::array_t<float, py::array::c_style> boxes,
nms::Method method = nms::Method::GAUSSIAN,
float Nt = 0.3,
float sigma = 0.5,
float threshold = 0.001)
{
assert(boxes.ndim() == 2 && "Input should be 2-D NumPy array");
assert(boxes.shape()[1] == 5 && "Input should have size [N,5]");
auto count = boxes.size() / 5;
auto i = new int32_t[count];
auto b = new float[boxes.size()];
std::copy(boxes.data(), boxes.data() + boxes.size(), b);
auto N = nms::soft_nms(b, i, count, method, Nt, sigma, threshold);
std::vector<size_t> shape5 = {N, 5};
std::vector<size_t> shape1 = {N};
std::vector<ssize_t> strides5 = {sizeof(float) * 5, sizeof(float)};
std::vector<ssize_t> strides1 = {sizeof(float)};
auto cap_b =
py::capsule(b, [](void* v) { delete[] reinterpret_cast<float*>(v); });
auto cap_i =
py::capsule(i, [](void* v) { delete[] reinterpret_cast<int32_t*>(v); });
auto pyb = py::array(py::dtype("float32"), shape5, strides5, b, cap_b);
auto pyi = py::array(py::dtype("int32"), shape1, strides1, i, cap_i);
return py::make_tuple(pyb, pyi);
}
PYBIND11_MODULE(nms, m) {
m.doc() = "SoftNMS for object detection.";
py::enum_<nms::Method>(m, "NMSMethod")
.value("LINEAR", nms::Method::LINEAR)
.value("GAUSSIAN", nms::Method::GAUSSIAN)
.value("HARD", nms::Method::HARD)
.export_values();
m.def("soft_nms", &py_soft_nms, "boxes"_a.noconvert(),
"method"_a = nms::Method::GAUSSIAN,
"Nt"_a = 0.3, "sigma"_a = 0.5, "threshold"_a = 0.001);
}
} /* namespace binding */
#include <bits/stdc++.h>
namespace nms
{
struct proposal
{
float score, x1, y1, x2, y2;
};
inline static bool cmp(const proposal& a, const proposal& b)
{
return a.score < b.score;
}
inline static float iou(const proposal&, const proposal&) __attribute__((always_inline));
static float iou(const proposal& a, const proposal& b)
{
auto overlap = 0.f;
float iw = std::min(b.x2, a.x2) - std::max(b.x1, a.x1) + 1;
if (iw > 0) {
float ih = std::min(b.y2, a.y2) - std::max(b.y1, a.y1) + 1;
if (ih > 0) {
float ab = (b.x2 - b.x1 + 1) * (b.y2 - b.y1 + 1);
float aa = (a.x2 - a.x1 + 1) * (a.y2 - a.y1 + 1);
float inter = iw * ih;
overlap = inter / (aa + ab - inter);
}
}
return overlap;
}
enum class Method : uint32_t
{
LINEAR = 0,
GAUSSIAN,
HARD
};
size_t soft_nms(float* boxes,
int32_t* index,
size_t count,
Method method,
float Nt,
float sigma,
float threshold)
{
std::iota(index, index + count, 0); // np.arange()
auto p = reinterpret_cast<proposal*>(boxes);
auto N = count;
for (size_t i = 0; i < N; ++i) {
auto max = std::max_element(p + i, p + N, cmp);
std::swap(p[i], *max);
std::swap(index[i], index[max - p]);
auto j = i + 1;
auto weight = 0.f;
while (j < N) {
auto ov = iou(p[i], p[j]);
switch (method) {
case Method::LINEAR:
weight = ov > Nt ? 1.f - ov : 1.f;
break;
case Method::GAUSSIAN:
weight = std::exp(-(ov * ov) / sigma);
break;
case Method::HARD:
weight = ov > Nt ? 0.f : 1.f;
break;
}
p[j].score *= weight;
if (p[j].score < threshold) {
N--;
std::swap(p[j], p[N]);
std::swap(index[j], index[N]);
j--;
}
j++;
}
};
return N;
}
} /* namespace nms */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册