提交 2f89e761 编写于 作者: S sunyanfang01

add fasterrcnn loss

上级 948032a7
...@@ -43,7 +43,8 @@ class FasterRCNN(BaseAPI): ...@@ -43,7 +43,8 @@ class FasterRCNN(BaseAPI):
backbone='ResNet50', backbone='ResNet50',
with_fpn=True, with_fpn=True,
aspect_ratios=[0.5, 1.0, 2.0], aspect_ratios=[0.5, 1.0, 2.0],
anchor_sizes=[32, 64, 128, 256, 512]): anchor_sizes=[32, 64, 128, 256, 512],
bbox_loss_type='SmoothL1Loss'):
self.init_params = locals() self.init_params = locals()
super(FasterRCNN, self).__init__('detector') super(FasterRCNN, self).__init__('detector')
backbones = [ backbones = [
...@@ -57,6 +58,7 @@ class FasterRCNN(BaseAPI): ...@@ -57,6 +58,7 @@ class FasterRCNN(BaseAPI):
self.with_fpn = with_fpn self.with_fpn = with_fpn
self.aspect_ratios = aspect_ratios self.aspect_ratios = aspect_ratios
self.anchor_sizes = anchor_sizes self.anchor_sizes = anchor_sizes
self.bbox_loss_type = bbox_loss_type
self.labels = None self.labels = None
self.fixed_input_shape = None self.fixed_input_shape = None
...@@ -72,6 +74,8 @@ class FasterRCNN(BaseAPI): ...@@ -72,6 +74,8 @@ class FasterRCNN(BaseAPI):
layers = 50 layers = 50
variant = 'd' variant = 'd'
norm_type = 'affine_channel' norm_type = 'affine_channel'
if self.bbox_loss_type != 'SmoothL1Loss':
norm_type = 'bn'
elif backbone_name == 'ResNet101': elif backbone_name == 'ResNet101':
layers = 101 layers = 101
variant = 'b' variant = 'b'
...@@ -118,7 +122,8 @@ class FasterRCNN(BaseAPI): ...@@ -118,7 +122,8 @@ class FasterRCNN(BaseAPI):
anchor_sizes=self.anchor_sizes, anchor_sizes=self.anchor_sizes,
train_pre_nms_top_n=train_pre_nms_top_n, train_pre_nms_top_n=train_pre_nms_top_n,
test_pre_nms_top_n=test_pre_nms_top_n, test_pre_nms_top_n=test_pre_nms_top_n,
fixed_input_shape=self.fixed_input_shape) fixed_input_shape=self.fixed_input_shape,
bbox_loss_type=self.bbox_loss_type)
inputs = model.generate_inputs() inputs = model.generate_inputs()
if mode == 'train': if mode == 'train':
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
...@@ -134,26 +139,49 @@ class FasterRCNN(BaseAPI): ...@@ -134,26 +139,49 @@ class FasterRCNN(BaseAPI):
outputs = model.build_net(inputs) outputs = model.build_net(inputs)
return inputs, outputs return inputs, outputs
# def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
# lr_decay_epochs, lr_decay_gamma,
# num_steps_each_epoch):
# if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
# raise Exception("warmup_steps should less than {}".format(
# lr_decay_epochs[0] * num_steps_each_epoch))
# boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
# values = [(lr_decay_gamma**i) * learning_rate
# for i in range(len(lr_decay_epochs) + 1)]
# lr_decay = fluid.layers.piecewise_decay(
# boundaries=boundaries, values=values)
# lr_warmup = fluid.layers.linear_lr_warmup(
# learning_rate=lr_decay,
# warmup_steps=warmup_steps,
# start_lr=warmup_start_lr,
# end_lr=learning_rate)
# optimizer = fluid.optimizer.Momentum(
# learning_rate=lr_warmup,
# momentum=0.9,
# regularization=fluid.regularizer.L2Decay(1e-04))
# return optimizer
def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr, def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
lr_decay_epochs, lr_decay_gamma, lr_decay_epochs, lr_decay_gamma,
num_steps_each_epoch): num_steps_each_epoch):
if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch: #if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
raise Exception("warmup_steps should less than {}".format( # raise Exception("warmup_steps should less than {}".format(
lr_decay_epochs[0] * num_steps_each_epoch)) # lr_decay_epochs[0] * num_steps_each_epoch))
boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs] boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
values = [(lr_decay_gamma**i) * learning_rate values = [(lr_decay_gamma**i) * learning_rate
for i in range(len(lr_decay_epochs) + 1)] for i in range(len(lr_decay_epochs) + 1)]
lr_decay = fluid.layers.piecewise_decay( lr_decay = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values) boundaries=boundaries, values=values)
lr_warmup = fluid.layers.linear_lr_warmup( #lr_warmup = fluid.layers.linear_lr_warmup(
learning_rate=lr_decay, # learning_rate=lr_decay,
warmup_steps=warmup_steps, # warmup_steps=warmup_steps,
start_lr=warmup_start_lr, # start_lr=warmup_start_lr,
end_lr=learning_rate) # end_lr=learning_rate)
optimizer = fluid.optimizer.Momentum( optimizer = fluid.optimizer.Momentum(
learning_rate=lr_warmup, #learning_rate=lr_warmup,
learning_rate=lr_decay,
momentum=0.9, momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-04)) regularization=fluid.regularizer.L2DecayRegularizer(1e-04))
return optimizer return optimizer
def train(self, def train(self,
......
...@@ -24,6 +24,7 @@ from paddle.fluid.initializer import Normal, Xavier ...@@ -24,6 +24,7 @@ from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA from paddle.fluid.initializer import MSRA
__all__ = ['BBoxHead', 'TwoFCHead'] __all__ = ['BBoxHead', 'TwoFCHead']
...@@ -82,7 +83,8 @@ class BBoxHead(object): ...@@ -82,7 +83,8 @@ class BBoxHead(object):
background_label=0, background_label=0,
#bbox_loss #bbox_loss
sigma=1.0, sigma=1.0,
num_classes=81): num_classes=81,
bbox_loss_type='SmoothL1Loss'):
super(BBoxHead, self).__init__() super(BBoxHead, self).__init__()
self.head = head self.head = head
self.prior_box_var = prior_box_var self.prior_box_var = prior_box_var
...@@ -99,6 +101,7 @@ class BBoxHead(object): ...@@ -99,6 +101,7 @@ class BBoxHead(object):
self.sigma = sigma self.sigma = sigma
self.num_classes = num_classes self.num_classes = num_classes
self.head_feat = None self.head_feat = None
self.bbox_loss_type = bbox_loss_type
def get_head_feat(self, input=None): def get_head_feat(self, input=None):
""" """
...@@ -126,6 +129,7 @@ class BBoxHead(object): ...@@ -126,6 +129,7 @@ class BBoxHead(object):
[N, num_anchors * 4, H, W]. [N, num_anchors * 4, H, W].
""" """
head_feat = self.get_head_feat(roi_feat) head_feat = self.get_head_feat(roi_feat)
# when ResNetC5 output a single feature map # when ResNetC5 output a single feature map
if not isinstance(self.head, TwoFCHead): if not isinstance(self.head, TwoFCHead):
head_feat = fluid.layers.pool2d( head_feat = fluid.layers.pool2d(
...@@ -173,18 +177,50 @@ class BBoxHead(object): ...@@ -173,18 +177,50 @@ class BBoxHead(object):
""" """
cls_score, bbox_pred = self._get_output(roi_feat) cls_score, bbox_pred = self._get_output(roi_feat)
labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64') labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64')
labels_int64.stop_gradient = True labels_int64.stop_gradient = True
loss_cls = fluid.layers.softmax_with_cross_entropy( loss_cls = fluid.layers.softmax_with_cross_entropy(
logits=cls_score, label=labels_int64, numeric_stable_mode=True) logits=cls_score, label=labels_int64, numeric_stable_mode=True)
loss_cls = fluid.layers.reduce_mean(loss_cls) loss_cls = fluid.layers.reduce_mean(loss_cls)
loss_bbox = fluid.layers.smooth_l1( if self.bbox_loss_type == 'CiouLoss':
x=bbox_pred, from .loss.diou_loss import DiouLoss
y=bbox_targets, loss_obj = DiouLoss(loss_weight=10.,
inside_weight=bbox_inside_weights, is_cls_agnostic=False,
outside_weight=bbox_outside_weights, num_classes=self.num_classes,
sigma=self.sigma) use_complete_iou_loss=True)
loss_bbox = loss_obj(
x=bbox_pred,
y=bbox_targets,
inside_weight=bbox_inside_weights,
outside_weight=bbox_outside_weights)
elif self.bbox_loss_type == 'DiouLoss':
from .loss.diou_loss import DiouLoss
loss_obj = DiouLoss(loss_weight=12.,
is_cls_agnostic=False,
num_classes=self.num_classes,
use_complete_iou_loss=False)
loss_bbox = loss_obj(
x=bbox_pred,
y=bbox_targets,
inside_weight=bbox_inside_weights,
outside_weight=bbox_outside_weights)
elif self.bbox_loss_type == 'GiouLoss':
from .loss.giou_loss import GiouLoss
loss_obj = GiouLoss(loss_weight=10.,
is_cls_agnostic=False,
num_classes=self.num_classes)
loss_bbox = loss_obj(
x=bbox_pred,
y=bbox_targets,
inside_weight=bbox_inside_weights,
outside_weight=bbox_outside_weights)
else:
loss_bbox = fluid.layers.smooth_l1(
x=bbox_pred,
y=bbox_targets,
inside_weight=bbox_inside_weights,
outside_weight=bbox_outside_weights,
sigma=self.sigma)
loss_bbox = fluid.layers.reduce_mean(loss_bbox) loss_bbox = fluid.layers.reduce_mean(loss_bbox)
return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox} return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
...@@ -229,14 +265,21 @@ class BBoxHead(object): ...@@ -229,14 +265,21 @@ class BBoxHead(object):
cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape) cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
if return_box_score: if return_box_score:
return {'bbox': cliped_box, 'score': cls_prob} return {'bbox': cliped_box, 'score': cls_prob}
pred_result = fluid.layers.multiclass_nms( if self.bbox_loss_type == 'CiouLoss':
bboxes=cliped_box, from .nms import MultiClassDiouNMS
scores=cls_prob, nms_obj = MultiClassDiouNMS(score_threshold=self.score_threshold,
score_threshold=self.score_threshold, nms_threshold=self.nms_threshold,
nms_top_k=self.nms_top_k, keep_top_k=self.keep_top_k)
keep_top_k=self.keep_top_k, pred_result = nms_obj(bboxes=cliped_box, scores=cls_prob)
nms_threshold=self.nms_threshold, else:
normalized=self.normalized, pred_result = fluid.layers.multiclass_nms(
nms_eta=self.nms_eta, bboxes=cliped_box,
background_label=self.background_label) scores=cls_prob,
score_threshold=self.score_threshold,
nms_top_k=self.nms_top_k,
keep_top_k=self.keep_top_k,
nms_threshold=self.nms_threshold,
normalized=self.normalized,
nms_eta=self.nms_eta,
background_label=self.background_label)
return {'bbox': pred_result} return {'bbox': pred_result}
...@@ -70,6 +70,7 @@ class FasterRCNN(object): ...@@ -70,6 +70,7 @@ class FasterRCNN(object):
keep_top_k=100, keep_top_k=100,
nms_threshold=0.5, nms_threshold=0.5,
score_threshold=0.05, score_threshold=0.05,
bbox_loss_type='SmoothL1Loss',
#bbox_assigner #bbox_assigner
batch_size_per_im=512, batch_size_per_im=512,
fg_fraction=.25, fg_fraction=.25,
...@@ -145,7 +146,8 @@ class FasterRCNN(object): ...@@ -145,7 +146,8 @@ class FasterRCNN(object):
keep_top_k=keep_top_k, keep_top_k=keep_top_k,
nms_threshold=nms_threshold, nms_threshold=nms_threshold,
score_threshold=score_threshold, score_threshold=score_threshold,
num_classes=num_classes) num_classes=num_classes,
bbox_loss_type=bbox_loss_type)
self.bbox_head = bbox_head self.bbox_head = bbox_head
self.batch_size_per_im = batch_size_per_im self.batch_size_per_im = batch_size_per_im
self.fg_fraction = fg_fraction self.fg_fraction = fg_fraction
...@@ -189,7 +191,6 @@ class FasterRCNN(object): ...@@ -189,7 +191,6 @@ class FasterRCNN(object):
bbox_reg_weights=self.bbox_reg_weights, bbox_reg_weights=self.bbox_reg_weights,
class_nums=self.num_classes, class_nums=self.num_classes,
use_random=self.rpn_head.use_random) use_random=self.rpn_head.use_random)
rois = outputs[0] rois = outputs[0]
labels_int32 = outputs[1] labels_int32 = outputs[1]
bbox_targets = outputs[2] bbox_targets = outputs[2]
...@@ -211,10 +212,12 @@ class FasterRCNN(object): ...@@ -211,10 +212,12 @@ class FasterRCNN(object):
else: else:
roi_feat = self.roi_extractor(body_feats, rois, spatial_scale) roi_feat = self.roi_extractor(body_feats, rois, spatial_scale)
if self.mode == 'train': if self.mode == 'train':
loss = self.bbox_head.get_loss(roi_feat, labels_int32, loss = self.bbox_head.get_loss(roi_feat, labels_int32,
bbox_targets, bbox_inside_weights, bbox_targets, bbox_inside_weights,
bbox_outside_weights) bbox_outside_weights)
loss.update(rpn_loss) loss.update(rpn_loss)
total_loss = fluid.layers.sum(list(loss.values())) total_loss = fluid.layers.sum(list(loss.values()))
loss.update({'loss': total_loss}) loss.update({'loss': total_loss})
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle import fluid
from .giou_loss import GiouLoss
class DiouLoss(GiouLoss):
"""
Distance-IoU Loss, see https://arxiv.org/abs/1911.08287
Args:
loss_weight (float): diou loss weight, default as 10 in faster-rcnn
is_cls_agnostic (bool): flag of class-agnostic
num_classes (int): class num
use_complete_iou_loss (bool): whether to use complete iou loss
"""
def __init__(self,
loss_weight=10.,
is_cls_agnostic=False,
num_classes=81,
use_complete_iou_loss=True):
super(DiouLoss, self).__init__(
loss_weight=loss_weight,
is_cls_agnostic=is_cls_agnostic,
num_classes=num_classes)
self.use_complete_iou_loss = use_complete_iou_loss
def __call__(self,
x,
y,
inside_weight=None,
outside_weight=None,
bbox_reg_weight=[0.1, 0.1, 0.2, 0.2]):
eps = 1.e-10
x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
cxg = (x1g + x2g) / 2
cyg = (y1g + y2g) / 2
wg = x2g - x1g
hg = y2g - y1g
x2 = fluid.layers.elementwise_max(x1, x2)
y2 = fluid.layers.elementwise_max(y1, y2)
# A and B
xkis1 = fluid.layers.elementwise_max(x1, x1g)
ykis1 = fluid.layers.elementwise_max(y1, y1g)
xkis2 = fluid.layers.elementwise_min(x2, x2g)
ykis2 = fluid.layers.elementwise_min(y2, y2g)
# A or B
xc1 = fluid.layers.elementwise_min(x1, x1g)
yc1 = fluid.layers.elementwise_min(y1, y1g)
xc2 = fluid.layers.elementwise_max(x2, x2g)
yc2 = fluid.layers.elementwise_max(y2, y2g)
intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
intsctk = intsctk * fluid.layers.greater_than(
xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
) - intsctk + eps
iouk = intsctk / unionk
# DIOU term
dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
diou_term = (dist_intersection + eps) / (dist_union + eps)
# CIOU term
ciou_term = 0
if self.use_complete_iou_loss:
ar_gt = wg / hg
ar_pred = w / h
arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
ar_loss = 4. / np.pi / np.pi * arctan * arctan
alpha = ar_loss / (1 - iouk + ar_loss + eps)
alpha.stop_gradient = True
ciou_term = alpha * ar_loss
iou_weights = 1
if inside_weight is not None and outside_weight is not None:
inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
outside_weight = fluid.layers.reshape(outside_weight, shape=(-1, 4))
inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)
iou_weights = inside_weight * outside_weight
class_weight = 2 if self.is_cls_agnostic else self.num_classes
diou = fluid.layers.reduce_mean(
(1 - iouk + ciou_term + diou_term) * iou_weights) * class_weight
return diou * self.loss_weight
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle import fluid
class GiouLoss(object):
'''
Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630
Args:
loss_weight (float): diou loss weight, default as 10 in faster-rcnn
is_cls_agnostic (bool): flag of class-agnostic
num_classes (int): class num
'''
def __init__(self, loss_weight=10., is_cls_agnostic=False, num_classes=81):
super(GiouLoss, self).__init__()
self.loss_weight = loss_weight
self.is_cls_agnostic = is_cls_agnostic
self.num_classes = num_classes
# deltas: NxMx4
def bbox_transform(self, deltas, weights):
wx, wy, ww, wh = weights
deltas = fluid.layers.reshape(deltas, shape=(0, -1, 4))
dx = fluid.layers.slice(deltas, axes=[2], starts=[0], ends=[1]) * wx
dy = fluid.layers.slice(deltas, axes=[2], starts=[1], ends=[2]) * wy
dw = fluid.layers.slice(deltas, axes=[2], starts=[2], ends=[3]) * ww
dh = fluid.layers.slice(deltas, axes=[2], starts=[3], ends=[4]) * wh
dw = fluid.layers.clip(dw, -1.e10, np.log(1000. / 16))
dh = fluid.layers.clip(dh, -1.e10, np.log(1000. / 16))
pred_ctr_x = dx
pred_ctr_y = dy
pred_w = fluid.layers.exp(dw)
pred_h = fluid.layers.exp(dh)
x1 = pred_ctr_x - 0.5 * pred_w
y1 = pred_ctr_y - 0.5 * pred_h
x2 = pred_ctr_x + 0.5 * pred_w
y2 = pred_ctr_y + 0.5 * pred_h
x1 = fluid.layers.reshape(x1, shape=(-1, ))
y1 = fluid.layers.reshape(y1, shape=(-1, ))
x2 = fluid.layers.reshape(x2, shape=(-1, ))
y2 = fluid.layers.reshape(y2, shape=(-1, ))
return x1, y1, x2, y2
def __call__(self,
x,
y,
inside_weight=None,
outside_weight=None,
bbox_reg_weight=[0.1, 0.1, 0.2, 0.2]):
eps = 1.e-10
x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)
x2 = fluid.layers.elementwise_max(x1, x2)
y2 = fluid.layers.elementwise_max(y1, y2)
xkis1 = fluid.layers.elementwise_max(x1, x1g)
ykis1 = fluid.layers.elementwise_max(y1, y1g)
xkis2 = fluid.layers.elementwise_min(x2, x2g)
ykis2 = fluid.layers.elementwise_min(y2, y2g)
xc1 = fluid.layers.elementwise_min(x1, x1g)
yc1 = fluid.layers.elementwise_min(y1, y1g)
xc2 = fluid.layers.elementwise_max(x2, x2g)
yc2 = fluid.layers.elementwise_max(y2, y2g)
intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
intsctk = intsctk * fluid.layers.greater_than(
xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
) - intsctk + eps
iouk = intsctk / unionk
area_c = (xc2 - xc1) * (yc2 - yc1) + eps
miouk = iouk - ((area_c - unionk) / area_c)
iou_weights = 1
if inside_weight is not None and outside_weight is not None:
inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
outside_weight = fluid.layers.reshape(outside_weight, shape=(-1, 4))
inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)
iou_weights = inside_weight * outside_weight
class_weight = 2 if self.is_cls_agnostic else self.num_classes
iouk = fluid.layers.reduce_mean((1 - iouk) * iou_weights) * class_weight
miouk = fluid.layers.reduce_mean(
(1 - miouk) * iou_weights) * class_weight
return miouk * self.loss_weight
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import fluid
class MultiClassDiouNMS(object):
def __init__(
self,
score_threshold=0.05,
keep_top_k=100,
nms_threshold=0.5,
normalized=False,
background_label=0, ):
super(MultiClassDiouNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.keep_top_k = keep_top_k
self.normalized = normalized
self.background_label = background_label
def __call__(self, bboxes, scores):
def create_tmp_var(program, name, dtype, shape, lod_level):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
def _calc_diou_term(dets1, dets2):
eps = 1.e-10
eta = 0 if self.normalized else 1
x1, y1, x2, y2 = dets1[0], dets1[1], dets1[2], dets1[3]
x1g, y1g, x2g, y2g = dets2[0], dets2[1], dets2[2], dets2[3]
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
w = x2 - x1 + eta
h = y2 - y1 + eta
cxg = (x1g + x2g) / 2
cyg = (y1g + y2g) / 2
wg = x2g - x1g + eta
hg = y2g - y1g + eta
x2 = np.maximum(x1, x2)
y2 = np.maximum(y1, y2)
# A or B
xc1 = np.minimum(x1, x1g)
yc1 = np.minimum(y1, y1g)
xc2 = np.maximum(x2, x2g)
yc2 = np.maximum(y2, y2g)
# DIOU term
dist_intersection = (cx - cxg)**2 + (cy - cyg)**2
dist_union = (xc2 - xc1)**2 + (yc2 - yc1)**2
diou_term = (dist_intersection + eps) / (dist_union + eps)
return diou_term
def _diou_nms_for_cls(dets, thres):
"""_diou_nms_for_cls"""
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
eta = 0 if self.normalized else 1
areas = (x2 - x1 + eta) * (y2 - y1 + eta)
dt_num = dets.shape[0]
order = np.array(range(dt_num))
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + eta)
h = np.maximum(0.0, yy2 - yy1 + eta)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
diou_term = _calc_diou_term([x1[i], y1[i], x2[i], y2[i]], [
x1[order[1:]], y1[order[1:]], x2[order[1:]], y2[order[1:]]
])
inds = np.where(ovr - diou_term <= thres)[0]
order = order[inds + 1]
dets_final = dets[keep]
return dets_final
def _diou_nms(bboxes, scores):
bboxes = np.array(bboxes)
scores = np.array(scores)
class_nums = scores.shape[-1]
score_threshold = self.score_threshold
nms_threshold = self.nms_threshold
keep_top_k = self.keep_top_k
cls_boxes = [[] for _ in range(class_nums)]
cls_ids = [[] for _ in range(class_nums)]
start_idx = 1 if self.background_label == 0 else 0
for j in range(start_idx, class_nums):
inds = np.where(scores[:, j] >= score_threshold)[0]
scores_j = scores[inds, j]
rois_j = bboxes[inds, j, :]
dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
np.float32, copy=False)
cls_rank = np.argsort(-dets_j[:, 0])
dets_j = dets_j[cls_rank]
cls_boxes[j] = _diou_nms_for_cls(dets_j, thres=nms_threshold)
cls_ids[j] = np.array([j] * cls_boxes[j].shape[0]).reshape(-1,
1)
cls_boxes = np.vstack(cls_boxes[start_idx:])
cls_ids = np.vstack(cls_ids[start_idx:])
pred_result = np.hstack([cls_ids, cls_boxes]).astype(np.float32)
# Limit to max_per_image detections **over all classes**
image_scores = cls_boxes[:, 0]
if len(image_scores) > keep_top_k:
image_thresh = np.sort(image_scores)[-keep_top_k]
keep = np.where(cls_boxes[:, 0] >= image_thresh)[0]
pred_result = pred_result[keep, :]
res = fluid.LoDTensor()
res.set_lod([[0, pred_result.shape[0]]])
if pred_result.shape[0] == 0:
pred_result = np.array([[1]], dtype=np.float32)
res.set(pred_result, fluid.CPUPlace())
return res
pred_result = create_tmp_var(
fluid.default_main_program(),
name='diou_nms_pred_result',
dtype='float32',
shape=[-1, 6],
lod_level=0)
fluid.layers.py_func(
func=_diou_nms, x=[bboxes, scores], out=pred_result)
return pred_result
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册