未验证 提交 c7c0568f 编写于 作者: W wangguanzhong 提交者: GitHub

add ags module (#2885)

上级 a06a6258
......@@ -200,6 +200,9 @@ class TTFHead(nn.Layer):
lite_head(bool): whether use lite version. False by default.
norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional.
bn by default
ags_module(bool): whether use AGS module to reweight location feature.
false by default.
"""
__shared__ = ['num_classes', 'down_ratio', 'norm_type']
......@@ -218,7 +221,8 @@ class TTFHead(nn.Layer):
down_ratio=4,
dcn_head=False,
lite_head=False,
norm_type='bn'):
norm_type='bn',
ags_module=False):
super(TTFHead, self).__init__()
self.in_channels = in_channels
self.hm_head = HMHead(in_channels, hm_head_planes, num_classes,
......@@ -230,6 +234,7 @@ class TTFHead(nn.Layer):
self.wh_offset_base = wh_offset_base
self.down_ratio = down_ratio
self.ags_module = ags_module
@classmethod
def from_config(cls, cfg, input_shape):
......@@ -253,6 +258,12 @@ class TTFHead(nn.Layer):
target = paddle.gather_nd(target, index)
return pred, target, weight
def filter_loc_by_weight(self, score, weight):
index = paddle.nonzero(weight > 0)
index.stop_gradient = True
score = paddle.gather_nd(score, index)
return score
def get_loss(self, pred_hm, pred_wh, target_hm, box_target, target_weight):
pred_hm = paddle.clip(F.sigmoid(pred_hm), 1e-4, 1 - 1e-4)
hm_loss = self.hm_loss(pred_hm, target_hm)
......@@ -274,10 +285,24 @@ class TTFHead(nn.Layer):
boxes = paddle.transpose(box_target, [0, 2, 3, 1])
boxes.stop_gradient = True
if self.ags_module:
pred_hm_max = paddle.max(pred_hm, axis=1, keepdim=True)
pred_hm_max_softmax = F.softmax(pred_hm_max, axis=1)
pred_hm_max_softmax = paddle.transpose(pred_hm_max_softmax,
[0, 2, 3, 1])
pred_hm_max_softmax = self.filter_loc_by_weight(pred_hm_max_softmax,
mask)
else:
pred_hm_max_softmax = None
pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes,
mask)
mask.stop_gradient = True
wh_loss = self.wh_loss(pred_boxes, boxes, iou_weight=mask.unsqueeze(1))
wh_loss = self.wh_loss(
pred_boxes,
boxes,
iou_weight=mask.unsqueeze(1),
loc_reweight=pred_hm_max_softmax)
wh_loss = wh_loss / avg_factor
ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss}
......
......@@ -110,7 +110,7 @@ class GIoULoss(object):
return iou, overlap, union
def __call__(self, pbox, gbox, iou_weight=1.):
def __call__(self, pbox, gbox, iou_weight=1., loc_reweight=None):
x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1)
x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1)
box1 = [x1, y1, x2, y2]
......@@ -123,7 +123,13 @@ class GIoULoss(object):
area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps
miou = iou - ((area_c - union) / area_c)
giou = 1 - miou
if loc_reweight is not None:
loc_reweight = paddle.reshape(loc_reweight, shape=(-1, 1))
loc_thresh = 0.9
giou = 1 - (1 - loc_thresh
) * miou - loc_thresh * miou * loc_reweight
else:
giou = 1 - miou
if self.reduction == 'none':
loss = giou
elif self.reduction == 'sum':
......
......@@ -67,6 +67,8 @@ class TTFHead(object):
keep_prob(float): keep_prob parameter for drop_block. 0.9 by default.
fusion_method (string): Method to fusion upsample and lateral branch.
'add' and 'concat' are optional, add by default
ags_module(bool): whether use AGS module to reweight location feature.
false by default.
"""
__inject__ = ['wh_loss']
......@@ -93,7 +95,8 @@ class TTFHead(object):
drop_block=False,
block_size=3,
keep_prob=0.9,
fusion_method='add'):
fusion_method='add',
ags_module=False):
super(TTFHead, self).__init__()
self.head_conv = head_conv
self.num_classes = num_classes
......@@ -119,6 +122,7 @@ class TTFHead(object):
self.block_size = block_size
self.keep_prob = keep_prob
self.fusion_method = fusion_method
self.ags_module = ags_module
def shortcut(self, x, out_c, layer_num, kernel_size=3, padding=1,
name=None):
......@@ -359,6 +363,12 @@ class TTFHead(object):
target = fluid.layers.gather_nd(target, index)
return pred, target, weight
def filter_loc_by_weight(self, score, weight):
index = fluid.layers.where(weight > 0)
index.stop_gradient = True
score = fluid.layers.gather_nd(score, index)
return score
def get_loss(self, pred_hm, pred_wh, target_hm, box_target, target_weight):
try:
pred_hm = paddle.clip(fluid.layers.sigmoid(pred_hm), 1e-4, 1 - 1e-4)
......@@ -387,11 +397,25 @@ class TTFHead(object):
boxes = fluid.layers.transpose(box_target, [0, 2, 3, 1])
boxes.stop_gradient = True
if self.ags_module:
pred_hm_max = fluid.layers.reduce_max(pred_hm, dim=1, keep_dim=True)
pred_hm_max_softmax = fluid.layers.softmax(pred_hm_max, axis=1)
pred_hm_max_softmax = fluid.layers.transpose(pred_hm_max_softmax,
[0, 2, 3, 1])
pred_hm_max_softmax = self.filter_loc_by_weight(pred_hm_max_softmax,
mask)
else:
pred_hm_max_softmax = None
pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes,
mask)
mask.stop_gradient = True
wh_loss = self.wh_loss(
pred_boxes, boxes, outside_weight=mask, use_transform=False)
pred_boxes,
boxes,
loc_reweight=pred_hm_max_softmax,
outside_weight=mask,
use_transform=False)
wh_loss = wh_loss / avg_factor
ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss}
......
......@@ -89,6 +89,7 @@ class GiouLoss(object):
inside_weight=None,
outside_weight=None,
bbox_reg_weight=[0.1, 0.1, 0.2, 0.2],
loc_reweight=None,
use_transform=True):
eps = 1.e-10
if use_transform:
......@@ -134,11 +135,19 @@ class GiouLoss(object):
elif outside_weight is not None:
iou_weights = outside_weight
if loc_reweight is not None:
loc_reweight = fluid.layers.reshape(loc_reweight, shape=(-1, 1))
loc_thresh = 0.9
giou = 1 - (1 - loc_thresh
) * miouk - loc_thresh * miouk * loc_reweight
else:
giou = 1 - miouk
if self.do_average:
miouk = fluid.layers.reduce_mean((1 - miouk) * iou_weights)
miouk = fluid.layers.reduce_mean(giou * iou_weights)
else:
iou_distance = fluid.layers.elementwise_mul(
1 - miouk, iou_weights, axis=0)
giou, iou_weights, axis=0)
miouk = fluid.layers.reduce_sum(iou_distance)
if self.use_class_weight:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册