未验证 提交 759b2dfc 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

Add DIOU nms (#160)

* add diou nms.
* fix typo in diou nms and soft nms (lod_leval->lod_level).
上级 1e6090d8
......@@ -45,4 +45,4 @@
| :---------------------- | :------------- | :---: | :---: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-FPN | Faster | GIOU | 10 | 2 | 1x | 22.94 | 39.4 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_giou_loss_1x.tar) |
| ResNet50-vd-FPN | Faster | DIOU | 12 | 2 | 1x | 22.94 | 39.2 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_diou_loss_1x.tar) |
| ResNet50-vd-FPN | Faster | CIOU | 12 | 2 | 1x | 22.95 | 39.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_ciou_loss_1x.tar) |
| ResNet50-vd-FPN | Faster | CIOU | 12 | 2 | 1x | 22.95 | 39.6 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_ciou_loss_1x.tar) |
......@@ -75,12 +75,14 @@ BBoxAssigner:
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
nms: MultiClassDiouNMS
bbox_loss: DiouLoss
MultiClassDiouNMS:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
DiouLoss:
loss_weight: 10.0
is_cls_agnostic: false
......
......@@ -90,19 +90,21 @@ class DiouLoss(GiouLoss):
) - intsctk + eps
iouk = intsctk / unionk
# DIOU term
dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
diou_term = (dist_intersection + eps) / (dist_union + eps)
# CIOU term
ciou_term = 0
if self.use_complete_iou_loss:
dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg
)
dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
d = (dist_intersection + eps) / (dist_union + eps)
ar_gt = wg / hg
ar_pred = w / h
arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
ar_loss = 4. / np.pi / np.pi * arctan * arctan
alpha = ar_loss / (1 - iouk + ar_loss + eps)
alpha.stop_gradient = True
ciou_term = d + alpha * ar_loss
ciou_term = alpha * ar_loss
iou_weights = 1
if inside_weight is not None and outside_weight is not None:
......@@ -116,6 +118,6 @@ class DiouLoss(GiouLoss):
class_weight = 2 if self.is_cls_agnostic else self.num_classes
diou = fluid.layers.reduce_mean(
(1 - iouk + ciou_term) * iou_weights) * class_weight
(1 - iouk + ciou_term + diou_term) * iou_weights) * class_weight
return diou * self.loss_weight
......@@ -24,7 +24,7 @@ __all__ = [
'AnchorGenerator', 'RPNTargetAssign', 'GenerateProposals', 'MultiClassNMS',
'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool', 'MultiBoxHead',
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
'MultiClassSoftNMS'
'MultiClassSoftNMS', 'MultiClassDiouNMS'
]
......@@ -226,9 +226,9 @@ class MultiClassSoftNMS(object):
self.background_label = background_label
def __call__(self, bboxes, scores):
def create_tmp_var(program, name, dtype, shape, lod_leval):
def create_tmp_var(program, name, dtype, shape, lod_level):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_leval=lod_leval)
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
def _soft_nms_for_cls(dets, sigma, thres):
"""soft_nms_for_cls"""
......@@ -313,12 +313,159 @@ class MultiClassSoftNMS(object):
name='softnms_pred_result',
dtype='float32',
shape=[6],
lod_leval=1)
lod_level=1)
fluid.layers.py_func(
func=_soft_nms, x=[bboxes, scores], out=pred_result)
return pred_result
@register
@serializable
class MultiClassDiouNMS(object):
def __init__(
self,
score_threshold=0.05,
keep_top_k=100,
nms_threshold=0.5,
normalized=False,
background_label=0, ):
super(MultiClassDiouNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.keep_top_k = keep_top_k
self.normalized = normalized
self.background_label = background_label
def __call__(self, bboxes, scores):
def create_tmp_var(program, name, dtype, shape, lod_level):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
def _calc_diou_term(dets1, dets2):
eps = 1.e-10
eta = 0 if self.normalized else 1
x1, y1, x2, y2 = dets1[0], dets1[1], dets1[2], dets1[3]
x1g, y1g, x2g, y2g = dets2[0], dets2[1], dets2[2], dets2[3]
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1 + eta
h = y2 - y1 + eta
cxg = (x1g + x2g) / 2
cyg = (y1g + y2g) / 2
wg = x2g - x1g + eta
hg = y2g - y1g + eta
x2 = np.maximum(x1, x2)
y2 = np.maximum(y1, y2)
# A or B
xc1 = np.minimum(x1, x1g)
yc1 = np.minimum(y1, y1g)
xc2 = np.maximum(x2, x2g)
yc2 = np.maximum(y2, y2g)
# DIOU term
dist_intersection = (cx - cxg)**2 + (cy - cyg)**2
dist_union = (xc2 - xc1)**2 + (yc2 - yc1)**2
diou_term = (dist_intersection + eps) / (dist_union + eps)
return diou_term
def _diou_nms_for_cls(dets, thres):
"""_diou_nms_for_cls"""
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
eta = 0 if self.normalized else 1
areas = (x2 - x1 + eta) * (y2 - y1 + eta)
dt_num = dets.shape[0]
order = np.array(range(dt_num))
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + eta)
h = np.maximum(0.0, yy2 - yy1 + eta)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
diou_term = _calc_diou_term([x1[i], y1[i], x2[i], y2[i]], [
x1[order[1:]], y1[order[1:]], x2[order[1:]], y2[order[1:]]
])
inds = np.where(ovr - diou_term <= thres)[0]
order = order[inds + 1]
dets_final = dets[keep]
return dets_final
def _diou_nms(bboxes, scores):
bboxes = np.array(bboxes)
scores = np.array(scores)
class_nums = scores.shape[-1]
score_threshold = self.score_threshold
nms_threshold = self.nms_threshold
keep_top_k = self.keep_top_k
cls_boxes = [[] for _ in range(class_nums)]
cls_ids = [[] for _ in range(class_nums)]
start_idx = 1 if self.background_label == 0 else 0
for j in range(start_idx, class_nums):
inds = np.where(scores[:, j] >= score_threshold)[0]
scores_j = scores[inds, j]
rois_j = bboxes[inds, j, :]
dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
np.float32, copy=False)
cls_rank = np.argsort(-dets_j[:, 0])
dets_j = dets_j[cls_rank]
cls_boxes[j] = _diou_nms_for_cls(dets_j, thres=nms_threshold)
cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
1)
cls_boxes = np.vstack(cls_boxes[start_idx:])
cls_ids = np.vstack(cls_ids[start_idx:])
pred_result = np.hstack([cls_ids, cls_boxes])
# Limit to max_per_image detections **over all classes**
image_scores = cls_boxes[:, 0]
if len(image_scores) > keep_top_k:
image_thresh = np.sort(image_scores)[-keep_top_k]
keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
pred_result = pred_result[keep, :]
res = fluid.LoDTensor()
res.set_lod([[0, pred_result.shape[0]]])
if pred_result.shape[0] == 0:
pred_result = np.array([[1]], dtype=np.float32)
res.set(pred_result, fluid.CPUPlace())
return res
pred_result = create_tmp_var(
fluid.default_main_program(),
name='diou_nms_pred_result',
dtype='float32',
shape=[6],
lod_level=1)
fluid.layers.py_func(
func=_diou_nms, x=[bboxes, scores], out=pred_result)
return pred_result
@register
class BBoxAssigner(object):
__op__ = fluid.layers.generate_proposal_labels
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册