提交 88ee56d0 编写于 作者: J jerrywgz

enhance nms for mask rcnn

上级 3f815e07
......@@ -93,5 +93,25 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
}
}
template <class T>
void SliceOneClass(const platform::DeviceContext& ctx,
const framework::Tensor& items, const int class_id,
framework::Tensor* one_class_item) {
T* item_data = one_class_item->mutable_data<T>(ctx.GetPlace());
const T* items_data = items.data<T>();
const int64_t num_item = items.dims()[0];
const int class_num = items.dims()[1];
int item_size = 1;
if (items.dims().size() == 3) {
item_size = items.dims()[2];
}
for (int i = 0; i < num_item; ++i) {
for (int j = 0; j < item_size; ++j) {
item_data[i * item_size + j] =
items_data[i * class_num * item_size + class_id * item_size + j];
}
}
}
} // namespace operators
} // namespace paddle
......@@ -19,7 +19,7 @@ import copy
from op_test import OpTest
def iou(box_a, box_b):
def iou(box_a, box_b, normalized):
"""Apply intersection-over-union overlap between box_a and box_b
"""
xmin_a = min(box_a[0], box_a[2])
......@@ -32,8 +32,10 @@ def iou(box_a, box_b):
xmax_b = max(box_b[0], box_b[2])
ymax_b = max(box_b[1], box_b[3])
area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a)
area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b)
area_a = (ymax_a - ymin_a + (normalized == False)) * \
(xmax_a - xmin_a + (normalized == False))
area_b = (ymax_b - ymin_b + (normalized == False)) * \
(xmax_b - xmin_b + (normalized == False))
if area_a <= 0 and area_b <= 0:
return 0.0
......@@ -42,17 +44,21 @@ def iou(box_a, box_b):
xb = min(xmax_a, xmax_b)
yb = min(ymax_a, ymax_b)
inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0)
box_a_area = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
box_b_area = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
inter_area = max(xb - xa + (normalized == False), 0.0) * \
max(yb - ya + (normalized == False), 0.0)
iou_ratio = inter_area / (area_a + area_b - inter_area)
return iou_ratio
def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
def nms(boxes,
scores,
score_threshold,
nms_threshold,
top_k=200,
normalized=True,
eta=1.0):
"""Apply non-maximum suppression at test time to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
......@@ -87,7 +93,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
for k in range(len(selected_indices)):
if keep:
kept_idx = selected_indices[k]
overlap = iou(boxes[idx], boxes[kept_idx])
overlap = iou(boxes[idx], boxes[kept_idx], normalized)
keep = True if overlap <= adaptive_threshold else False
else:
break
......@@ -99,16 +105,24 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
nms_top_k, keep_top_k):
class_num = scores.shape[0]
priorbox_num = scores.shape[1]
nms_top_k, keep_top_k, normalized, shared):
if shared:
class_num = scores.shape[0]
priorbox_num = scores.shape[1]
else:
box_num = scores.shape[0]
class_num = scores.shape[1]
selected_indices = {}
num_det = 0
for c in range(class_num):
if c == background: continue
indices = nms(boxes, scores[c], score_threshold, nms_threshold,
nms_top_k)
if shared:
indices = nms(boxes, scores[c], score_threshold, nms_threshold,
nms_top_k, normalized)
else:
indices = nms(boxes[:, c, :], scores[:, c], score_threshold,
nms_threshold, nms_top_k, normalized)
selected_indices[c] = indices
num_det += len(indices)
......@@ -116,7 +130,10 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
score_index = []
for c, indices in selected_indices.items():
for idx in indices:
score_index.append((scores[c][idx], c, idx))
if shared:
score_index.append((scores[c][idx], c, idx))
else:
score_index.append((scores[idx][c], c, idx))
sorted_score_index = sorted(
score_index, key=lambda tup: tup[0], reverse=True)
......@@ -127,24 +144,74 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
selected_indices[c] = []
for s, c, idx in sorted_score_index:
selected_indices[c].append(idx)
if not shared:
for labels in selected_indices:
selected_indices[labels].sort()
num_det = keep_top_k
return selected_indices, num_det
def batched_multiclass_nms(boxes, scores, background, score_threshold,
nms_threshold, nms_top_k, keep_top_k):
def lod_multiclass_nms(boxes, scores, background, score_threshold,
nms_threshold, nms_top_k, keep_top_k, box_lod,
normalized):
det_outs = []
lod = []
head = 0
for n in range(len(box_lod[0])):
box = boxes[head:head + box_lod[0][n]]
score = scores[head:head + box_lod[0][n]]
head = head + box_lod[0][n]
nmsed_outs, nmsed_num = multiclass_nms(
box,
score,
background,
score_threshold,
nms_threshold,
nms_top_k,
keep_top_k,
normalized,
shared=False)
if nmsed_num == 0:
lod.append(1)
continue
lod.append(nmsed_num)
for c, indices in nmsed_outs.items():
for idx in indices:
xmin, ymin, xmax, ymax = box[idx, c, :]
det_outs.append([c, score[idx][c], xmin, ymin, xmax, ymax])
return det_outs, lod
def batched_multiclass_nms(boxes,
scores,
background,
score_threshold,
nms_threshold,
nms_top_k,
keep_top_k,
normalized=True):
batch_size = scores.shape[0]
det_outs = []
lod = []
for n in range(batch_size):
nmsed_outs, nmsed_num = multiclass_nms(boxes[n], scores[n], background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
lod.append(nmsed_num)
if nmsed_num == 0: continue
nmsed_outs, nmsed_num = multiclass_nms(
boxes[n],
scores[n],
background,
score_threshold,
nms_threshold,
nms_top_k,
keep_top_k,
normalized,
shared=True)
if nmsed_num == 0:
lod.append(1)
continue
lod.append(nmsed_num)
tmp_det_out = []
for c, indices in nmsed_outs.items():
for idx in indices:
......@@ -168,7 +235,6 @@ class TestMulticlassNMSOp(OpTest):
M = 1200
C = 21
BOX_SIZE = 4
background = 0
nms_threshold = 0.3
nms_top_k = 400
......@@ -193,6 +259,7 @@ class TestMulticlassNMSOp(OpTest):
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
print('python lod: ', lod)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
......@@ -206,6 +273,7 @@ class TestMulticlassNMSOp(OpTest):
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
'normalized': True,
}
def test_check_output(self):
......@@ -219,13 +287,70 @@ class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
self.score_threshold = 2.0
class TestMulticlassNMSLoDInput(OpTest):
def set_argument(self):
self.score_threshold = 0.01
def setUp(self):
self.set_argument()
M = 1200
C = 21
BOX_SIZE = 4
box_lod = [[1200]]
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
score_threshold = self.score_threshold
normalized = False
scores = np.random.random((M, C)).astype('float32')
def softmax(x):
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
scores = np.apply_along_axis(softmax, 1, scores)
boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
boxes[:, :, 0] = boxes[:, :, 0] * 10
boxes[:, :, 1] = boxes[:, :, 1] * 10
boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10
nmsed_outs, lod = lod_multiclass_nms(
boxes, scores, background, score_threshold, nms_threshold,
nms_top_k, keep_top_k, box_lod, normalized)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'multiclass_nms'
self.inputs = {
'BBoxes': (boxes, box_lod),
'Scores': (scores, box_lod),
}
self.outputs = {'Out': (nmsed_outs, [lod])}
self.attrs = {
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
'normalized': normalized,
}
def test_check_output(self):
self.check_output()
class TestIOU(unittest.TestCase):
def test_iou(self):
box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')
expt_output = np.array([2.0 / 16.0]).astype('float32')
calc_output = np.array([iou(box1, box2)]).astype('float32')
calc_output = np.array([iou(box1, box2, True)]).astype('float32')
self.assertTrue(np.allclose(calc_output, expt_output))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册