未验证 提交 e3029fd8 编写于 作者: C cnn 提交者: GitHub

[dev] clean code of s2anet (#3733) (#3778)

* clean code of s2anet

* fix for paddle==2.1.0

* remove debug info

* update comment
上级 621d2bdd
architecture: S2ANet architecture: S2ANet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams
weights: output/s2anet_r50_fpn_1x_dota/model_final.pdparams weights: output/s2anet_r50_fpn_1x_dota/model_final.pdparams
...@@ -12,6 +12,7 @@ S2ANet: ...@@ -12,6 +12,7 @@ S2ANet:
ResNet: ResNet:
depth: 50 depth: 50
variant: d
norm_type: bn norm_type: bn
return_idx: [1,2,3] return_idx: [1,2,3]
num_stages: 4 num_stages: 4
...@@ -33,11 +34,9 @@ S2ANetHead: ...@@ -33,11 +34,9 @@ S2ANetHead:
feat_in: 256 feat_in: 256
feat_out: 256 feat_out: 256
num_classes: 15 num_classes: 15
align_conv_type: 'Conv' # AlignConv Conv align_conv_type: 'AlignConv' # AlignConv Conv
align_conv_size: 3 align_conv_size: 3
use_sigmoid_cls: True use_sigmoid_cls: True
reg_loss_weight: [ 1.0, 1.0, 1.0, 1.0, 1.1 ]
cls_loss_weight: [ 1.1, 1.05 ]
RBoxAssigner: RBoxAssigner:
pos_iou_thr: 0.5 pos_iou_thr: 0.5
...@@ -54,3 +53,4 @@ S2ANetBBoxPostProcess: ...@@ -54,3 +53,4 @@ S2ANetBBoxPostProcess:
score_threshold: 0.05 score_threshold: 0.05
nms_threshold: 0.1 nms_threshold: 0.1
normalized: False normalized: False
#background_label: -1
...@@ -23,20 +23,16 @@ from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner ...@@ -23,20 +23,16 @@ from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
import numpy as np import numpy as np
class S2ANetAnchorGenerator(object): class S2ANetAnchorGenerator(nn.Layer):
""" """
S2ANetAnchorGenerator by np AnchorGenerator by paddle
""" """
def __init__(self, def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
base_size=8, super(S2ANetAnchorGenerator, self).__init__()
scales=1.0,
ratios=1.0,
scale_major=True,
ctr=None):
self.base_size = base_size self.base_size = base_size
self.scales = scales self.scales = paddle.to_tensor(scales)
self.ratios = ratios self.ratios = paddle.to_tensor(ratios)
self.scale_major = scale_major self.scale_major = scale_major
self.ctr = ctr self.ctr = ctr
self.base_anchors = self.gen_base_anchors() self.base_anchors = self.gen_base_anchors()
...@@ -54,7 +50,7 @@ class S2ANetAnchorGenerator(object): ...@@ -54,7 +50,7 @@ class S2ANetAnchorGenerator(object):
else: else:
x_ctr, y_ctr = self.ctr x_ctr, y_ctr = self.ctr
h_ratios = np.sqrt(self.ratios) h_ratios = paddle.sqrt(self.ratios)
w_ratios = 1 / h_ratios w_ratios = 1 / h_ratios
if self.scale_major: if self.scale_major:
ws = (w * w_ratios[:] * self.scales[:]).reshape([-1]) ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
...@@ -63,53 +59,51 @@ class S2ANetAnchorGenerator(object): ...@@ -63,53 +59,51 @@ class S2ANetAnchorGenerator(object):
ws = (w * self.scales[:] * w_ratios[:]).reshape([-1]) ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
hs = (h * self.scales[:] * h_ratios[:]).reshape([-1]) hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
# yapf: disable base_anchors = paddle.stack(
base_anchors = np.stack(
[ [
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
], ],
axis=-1) axis=-1)
base_anchors = np.round(base_anchors) base_anchors = paddle.round(base_anchors)
# yapf: enable
return base_anchors return base_anchors
def _meshgrid(self, x, y, row_major=True): def _meshgrid(self, x, y, row_major=True):
xx, yy = np.meshgrid(x, y) yy, xx = paddle.meshgrid(x, y)
xx = xx.reshape(-1) yy = yy.reshape([-1])
yy = yy.reshape(-1) xx = xx.reshape([-1])
if row_major: if row_major:
return xx, yy return xx, yy
else: else:
return yy, xx return yy, xx
def grid_anchors(self, featmap_size, stride=16): def forward(self, featmap_size, stride=16):
# featmap_size*stride project it to original area # featmap_size*stride project it to original area
base_anchors = self.base_anchors
feat_h, feat_w = featmap_size feat_h = featmap_size[0]
shift_x = np.arange(0, feat_w, 1, 'int32') * stride feat_w = featmap_size[1]
shift_y = np.arange(0, feat_h, 1, 'int32') * stride shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :] all_anchors = self.base_anchors[:, :] + shifts[:, :]
all_anchors = all_anchors.reshape([feat_h * feat_w, 4])
return all_anchors return all_anchors
def valid_flags(self, featmap_size, valid_size): def valid_flags(self, featmap_size, valid_size):
feat_h, feat_w = featmap_size feat_h, feat_w = featmap_size
valid_h, valid_w = valid_size valid_h, valid_w = valid_size
assert valid_h <= feat_h and valid_w <= feat_w assert valid_h <= feat_h and valid_w <= feat_w
valid_x = np.zeros([feat_w], dtype='uint8') valid_x = paddle.zeros([feat_w], dtype='int32')
valid_y = np.zeros([feat_h], dtype='uint8') valid_y = paddle.zeros([feat_h], dtype='int32')
valid_x[:valid_w] = 1 valid_x[:valid_w] = 1
valid_y[:valid_h] = 1 valid_y[:valid_h] = 1
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy valid = valid_xx & valid_yy
valid = valid.reshape([-1]) valid = paddle.reshape(valid, [-1, 1])
valid = paddle.expand(valid,
# valid = valid[:, None].expand( [-1, self.num_base_anchors]).reshape([-1])
# [valid.size(0), self.num_base_anchors]).reshape([-1])
return valid return valid
...@@ -138,7 +132,8 @@ class AlignConv(nn.Layer): ...@@ -138,7 +132,8 @@ class AlignConv(nn.Layer):
""" """
anchors = paddle.reshape(anchors, [-1, 5]) # (NA,5) anchors = paddle.reshape(anchors, [-1, 5]) # (NA,5)
dtype = anchors.dtype dtype = anchors.dtype
feat_h, feat_w = featmap_size feat_h = featmap_size[0]
feat_w = featmap_size[1]
pad = (self.kernel_size - 1) // 2 pad = (self.kernel_size - 1) // 2
idx = paddle.arange(-pad, pad + 1, dtype=dtype) idx = paddle.arange(-pad, pad + 1, dtype=dtype)
...@@ -164,11 +159,11 @@ class AlignConv(nn.Layer): ...@@ -164,11 +159,11 @@ class AlignConv(nn.Layer):
h = anchors[:, 3] h = anchors[:, 3]
a = anchors[:, 4] a = anchors[:, 4]
x_ctr = paddle.reshape(x_ctr, [x_ctr.shape[0], 1]) x_ctr = paddle.reshape(x_ctr, [-1, 1])
y_ctr = paddle.reshape(y_ctr, [y_ctr.shape[0], 1]) y_ctr = paddle.reshape(y_ctr, [-1, 1])
w = paddle.reshape(w, [w.shape[0], 1]) w = paddle.reshape(w, [-1, 1])
h = paddle.reshape(h, [h.shape[0], 1]) h = paddle.reshape(h, [-1, 1])
a = paddle.reshape(a, [a.shape[0], 1]) a = paddle.reshape(a, [-1, 1])
x_ctr = x_ctr / stride x_ctr = x_ctr / stride
y_ctr = y_ctr / stride y_ctr = y_ctr / stride
...@@ -183,20 +178,13 @@ class AlignConv(nn.Layer): ...@@ -183,20 +178,13 @@ class AlignConv(nn.Layer):
# get offset filed # get offset filed
offset_x = x_anchor - x_conv offset_x = x_anchor - x_conv
offset_y = y_anchor - y_conv offset_y = y_anchor - y_conv
# x, y in anchors is opposite in image coordinates,
# so we stack them with y, x other than x, y
offset = paddle.stack([offset_y, offset_x], axis=-1) offset = paddle.stack([offset_y, offset_x], axis=-1)
# NA,ks*ks*2 offset = paddle.reshape(offset, [feat_h * feat_w, self.kernel_size * self.kernel_size * 2])
# [NA, ks, ks, 2] --> [NA, ks*ks*2]
offset = paddle.reshape(offset, [offset.shape[0], -1])
# [NA, ks*ks*2] --> [ks*ks*2, NA]
offset = paddle.transpose(offset, [1, 0]) offset = paddle.transpose(offset, [1, 0])
# [NA, ks*ks*2] --> [1, ks*ks*2, H, W] offset = paddle.reshape(offset, [1, self.kernel_size * self.kernel_size * 2, feat_h, feat_w])
offset = paddle.reshape(offset, [1, -1, feat_h, feat_w])
return offset return offset
def forward(self, x, refine_anchors, stride): def forward(self, x, refine_anchors, featmap_size, stride):
featmap_size = (x.shape[2], x.shape[3])
offset = self.get_offset(refine_anchors, featmap_size, stride) offset = self.get_offset(refine_anchors, featmap_size, stride)
x = F.relu(self.align_conv(x, offset)) x = F.relu(self.align_conv(x, offset))
return x return x
...@@ -232,14 +220,15 @@ class S2ANetHead(nn.Layer): ...@@ -232,14 +220,15 @@ class S2ANetHead(nn.Layer):
anchor_strides=[8, 16, 32, 64, 128], anchor_strides=[8, 16, 32, 64, 128],
anchor_scales=[4], anchor_scales=[4],
anchor_ratios=[1.0], anchor_ratios=[1.0],
target_means=(.0, .0, .0, .0, .0), target_means=0.0,
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0), target_stds=1.0,
align_conv_type='AlignConv', align_conv_type='AlignConv',
align_conv_size=3, align_conv_size=3,
use_sigmoid_cls=True, use_sigmoid_cls=True,
anchor_assign=RBoxAssigner().__dict__, anchor_assign=RBoxAssigner().__dict__,
reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.0], reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.0],
cls_loss_weight=[1.0, 1.0]): cls_loss_weight=[1.0, 1.0],
reg_loss_type='l1'):
super(S2ANetHead, self).__init__() super(S2ANetHead, self).__init__()
self.stacked_convs = stacked_convs self.stacked_convs = stacked_convs
self.feat_in = feat_in self.feat_in = feat_in
...@@ -248,9 +237,10 @@ class S2ANetHead(nn.Layer): ...@@ -248,9 +237,10 @@ class S2ANetHead(nn.Layer):
self.anchor_scales = anchor_scales self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides self.anchor_strides = anchor_strides
self.anchor_strides = paddle.to_tensor(anchor_strides)
self.anchor_base_sizes = list(anchor_strides) self.anchor_base_sizes = list(anchor_strides)
self.target_means = target_means self.means = paddle.ones(shape=[5]) * target_means
self.target_stds = target_stds self.stds = paddle.ones(shape=[5]) * target_stds
assert align_conv_type in ['AlignConv', 'Conv', 'DCN'] assert align_conv_type in ['AlignConv', 'Conv', 'DCN']
self.align_conv_type = align_conv_type self.align_conv_type = align_conv_type
self.align_conv_size = align_conv_size self.align_conv_size = align_conv_size
...@@ -261,16 +251,20 @@ class S2ANetHead(nn.Layer): ...@@ -261,16 +251,20 @@ class S2ANetHead(nn.Layer):
self.anchor_assign = anchor_assign self.anchor_assign = anchor_assign
self.reg_loss_weight = reg_loss_weight self.reg_loss_weight = reg_loss_weight
self.cls_loss_weight = cls_loss_weight self.cls_loss_weight = cls_loss_weight
self.alpha = 1.0
self.beta = 1.0
self.reg_loss_type = reg_loss_type
self.s2anet_head_out = None self.s2anet_head_out = None
# anchor # anchor
self.anchor_generators = [] self.anchor_generators = []
for anchor_base in self.anchor_base_sizes: for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append( self.anchor_generators.append(
S2ANetAnchorGenerator(anchor_base, anchor_scales, S2ANetAnchorGenerator(anchor_base, anchor_scales,
anchor_ratios)) anchor_ratios))
self.anchor_generators = nn.LayerList(self.anchor_generators)
self.fam_cls_convs = nn.Sequential() self.fam_cls_convs = nn.Sequential()
self.fam_reg_convs = nn.Sequential() self.fam_reg_convs = nn.Sequential()
...@@ -404,9 +398,8 @@ class S2ANetHead(nn.Layer): ...@@ -404,9 +398,8 @@ class S2ANetHead(nn.Layer):
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0))) bias_attr=ParamAttr(initializer=Constant(0)))
self.base_anchors = dict() self.featmap_sizes = []
self.featmap_sizes = dict() self.base_anchors_list = []
self.base_anchors = dict()
self.refine_anchor_list = [] self.refine_anchor_list = []
def forward(self, feats): def forward(self, feats):
...@@ -416,11 +409,12 @@ class S2ANetHead(nn.Layer): ...@@ -416,11 +409,12 @@ class S2ANetHead(nn.Layer):
odm_reg_branch_list = [] odm_reg_branch_list = []
odm_cls_branch_list = [] odm_cls_branch_list = []
self.featmap_sizes = dict() self.featmap_sizes_list = []
self.base_anchors = dict() self.base_anchors_list = []
self.refine_anchor_list = [] self.refine_anchor_list = []
for i, feat in enumerate(feats): for feat_idx in range(len(feats)):
feat = feats[feat_idx]
fam_cls_feat = self.fam_cls_convs(feat) fam_cls_feat = self.fam_cls_convs(feat)
fam_cls = self.fam_cls(fam_cls_feat) fam_cls = self.fam_cls(fam_cls_feat)
...@@ -439,26 +433,30 @@ class S2ANetHead(nn.Layer): ...@@ -439,26 +433,30 @@ class S2ANetHead(nn.Layer):
fam_reg_branch_list.append(fam_reg_reshape) fam_reg_branch_list.append(fam_reg_reshape)
# prepare anchor # prepare anchor
featmap_size = feat.shape[-2:] featmap_size = (paddle.shape(feat)[2], paddle.shape(feat)[3])
self.featmap_sizes[i] = featmap_size self.featmap_sizes_list.append(featmap_size)
init_anchors = self.anchor_generators[i].grid_anchors( init_anchors = self.anchor_generators[feat_idx](
featmap_size, self.anchor_strides[i]) featmap_size, self.anchor_strides[feat_idx])
init_anchors = bbox_utils.rect2rbox(init_anchors) init_anchors = paddle.to_tensor(init_anchors, dtype='float32')
self.base_anchors[(i, featmap_size[0])] = init_anchors NA = featmap_size[0] * featmap_size[1]
init_anchors = paddle.reshape(
#fam_reg1 = fam_reg init_anchors, [NA, 4])
#fam_reg1.stop_gradient = True init_anchors = self.rect2rbox(init_anchors)
refine_anchor = bbox_utils.bbox_decode( self.base_anchors_list.append(init_anchors)
fam_reg.detach(), init_anchors, self.target_means,
self.target_stds) fam_reg1 = fam_reg
fam_reg1.stop_gradient = True
refine_anchor = self.bbox_decode(fam_reg1, init_anchors)
#refine_anchor = self.bbox_decode(fam_reg.detach(), init_anchors)
self.refine_anchor_list.append(refine_anchor) self.refine_anchor_list.append(refine_anchor)
if self.align_conv_type == 'AlignConv': if self.align_conv_type == 'AlignConv':
align_feat = self.align_conv(feat, align_feat = self.align_conv(feat,
refine_anchor.clone(), refine_anchor.clone(),
self.anchor_strides[i]) featmap_size,
self.anchor_strides[feat_idx])
elif self.align_conv_type == 'DCN': elif self.align_conv_type == 'DCN':
align_offset = self.align_conv_offset(feat) align_offset = self.align_conv_offset(feat)
align_feat = self.align_conv(feat, align_offset) align_feat = self.align_conv(feat, align_offset)
...@@ -475,9 +473,10 @@ class S2ANetHead(nn.Layer): ...@@ -475,9 +473,10 @@ class S2ANetHead(nn.Layer):
odm_cls_score = self.odm_cls(odm_cls_feat) odm_cls_score = self.odm_cls(odm_cls_feat)
# [N, CLS, H, W] --> [N, H, W, CLS] # [N, CLS, H, W] --> [N, H, W, CLS]
odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1]) odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1])
odm_cls_score_shape = odm_cls_score.shape
odm_cls_score_reshape = paddle.reshape( odm_cls_score_reshape = paddle.reshape(
odm_cls_score, odm_cls_score,
[odm_cls_score.shape[0], -1, self.cls_out_channels]) [odm_cls_score_shape[0], odm_cls_score_shape[1] * odm_cls_score_shape[2], self.cls_out_channels])
odm_cls_branch_list.append(odm_cls_score_reshape) odm_cls_branch_list.append(odm_cls_score_reshape)
...@@ -485,23 +484,27 @@ class S2ANetHead(nn.Layer): ...@@ -485,23 +484,27 @@ class S2ANetHead(nn.Layer):
# [N, 5, H, W] --> [N, H, W, 5] # [N, 5, H, W] --> [N, H, W, 5]
odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1]) odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1])
odm_bbox_pred_reshape = paddle.reshape( odm_bbox_pred_reshape = paddle.reshape(
odm_bbox_pred, [odm_bbox_pred.shape[0], -1, 5]) odm_bbox_pred, [-1, 5])
odm_bbox_pred_reshape = paddle.unsqueeze(odm_bbox_pred_reshape, axis=0)
odm_reg_branch_list.append(odm_bbox_pred_reshape) odm_reg_branch_list.append(odm_bbox_pred_reshape)
self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list, self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list,
odm_cls_branch_list, odm_reg_branch_list) odm_cls_branch_list, odm_reg_branch_list)
return self.s2anet_head_out return self.s2anet_head_out
def get_prediction(self, nms_pre): def get_prediction(self, nms_pre=2000):
refine_anchors = self.refine_anchor_list refine_anchors = self.refine_anchor_list
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = self.s2anet_head_out fam_cls_branch_list = self.s2anet_head_out[0]
fam_reg_branch_list = self.s2anet_head_out[1]
odm_cls_branch_list = self.s2anet_head_out[2]
odm_reg_branch_list = self.s2anet_head_out[3]
pred_scores, pred_bboxes = self.get_bboxes( pred_scores, pred_bboxes = self.get_bboxes(
odm_cls_branch_list, odm_cls_branch_list,
odm_reg_branch_list, odm_reg_branch_list,
refine_anchors, refine_anchors,
nms_pre, nms_pre,
cls_out_channels=self.cls_out_channels, self.cls_out_channels,
use_sigmoid_cls=self.use_sigmoid_cls) self.use_sigmoid_cls)
return pred_scores, pred_bboxes return pred_scores, pred_bboxes
def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0): def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
...@@ -519,7 +522,7 @@ class S2ANetHead(nn.Layer): ...@@ -519,7 +522,7 @@ class S2ANetHead(nn.Layer):
diff - 0.5 * delta) diff - 0.5 * delta)
return loss return loss
def get_fam_loss(self, fam_target, s2anet_head_out): def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='gwd'):
(labels, label_weights, bbox_targets, bbox_weights, pos_inds, (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds) = fam_target neg_inds) = fam_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
...@@ -527,12 +530,11 @@ class S2ANetHead(nn.Layer): ...@@ -527,12 +530,11 @@ class S2ANetHead(nn.Layer):
fam_cls_losses = [] fam_cls_losses = []
fam_bbox_losses = [] fam_bbox_losses = []
st_idx = 0 st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len( num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds) neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples) num_total_samples = max(1, num_total_samples)
for idx, feat_size in enumerate(featmap_sizes): for idx, feat_size in enumerate(self.featmap_sizes_list):
feat_anchor_num = feat_size[0] * feat_size[1] feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data # step1: get data
...@@ -587,13 +589,17 @@ class S2ANetHead(nn.Layer): ...@@ -587,13 +589,17 @@ class S2ANetHead(nn.Layer):
fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0) fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5]) fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
loss_weight = paddle.to_tensor(
self.reg_loss_weight, dtype='float32', stop_gradient=True) # iou_factor
fam_bbox = paddle.multiply(fam_bbox, loss_weight) if reg_loss_type == 'l1':
feat_bbox_weights = paddle.to_tensor( fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
feat_bbox_weights, stop_gradient=True) loss_weight = paddle.to_tensor(
fam_bbox = fam_bbox * feat_bbox_weights self.reg_loss_weight, dtype='float32', stop_gradient=True)
fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples fam_bbox = paddle.multiply(fam_bbox, loss_weight)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
fam_bbox = fam_bbox * feat_bbox_weights
fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
fam_bbox_losses.append(fam_bbox_total) fam_bbox_losses.append(fam_bbox_total)
...@@ -604,7 +610,7 @@ class S2ANetHead(nn.Layer): ...@@ -604,7 +610,7 @@ class S2ANetHead(nn.Layer):
fam_reg_loss = paddle.add_n(fam_bbox_losses) fam_reg_loss = paddle.add_n(fam_bbox_losses)
return fam_cls_loss, fam_reg_loss return fam_cls_loss, fam_reg_loss
def get_odm_loss(self, odm_target, s2anet_head_out): def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'):
(labels, label_weights, bbox_targets, bbox_weights, pos_inds, (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds) = odm_target neg_inds) = odm_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
...@@ -612,11 +618,11 @@ class S2ANetHead(nn.Layer): ...@@ -612,11 +618,11 @@ class S2ANetHead(nn.Layer):
odm_cls_losses = [] odm_cls_losses = []
odm_bbox_losses = [] odm_bbox_losses = []
st_idx = 0 st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len( num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds) neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples) num_total_samples = max(1, num_total_samples)
for idx, feat_size in enumerate(featmap_sizes):
for idx, feat_size in enumerate(self.featmap_sizes_list):
feat_anchor_num = feat_size[0] * feat_size[1] feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data # step1: get data
...@@ -670,13 +676,18 @@ class S2ANetHead(nn.Layer): ...@@ -670,13 +676,18 @@ class S2ANetHead(nn.Layer):
odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0) odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5]) odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets) odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
loss_weight = paddle.to_tensor(
self.reg_loss_weight, dtype='float32', stop_gradient=True) # iou_factor odm not use_iou
odm_bbox = paddle.multiply(odm_bbox, loss_weight) if reg_loss_type == 'l1':
feat_bbox_weights = paddle.to_tensor( odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
feat_bbox_weights, stop_gradient=True) loss_weight = paddle.to_tensor(
odm_bbox = odm_bbox * feat_bbox_weights self.reg_loss_weight, dtype='float32', stop_gradient=True)
odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples odm_bbox = paddle.multiply(odm_bbox, loss_weight)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
odm_bbox = odm_bbox * feat_bbox_weights
odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
odm_bbox_losses.append(odm_bbox_total) odm_bbox_losses.append(odm_bbox_total)
odm_cls_loss = paddle.add_n(odm_cls_losses) odm_cls_loss = paddle.add_n(odm_cls_losses)
...@@ -706,15 +717,7 @@ class S2ANetHead(nn.Layer): ...@@ -706,15 +717,7 @@ class S2ANetHead(nn.Layer):
gt_labels = gt_labels + 1 gt_labels = gt_labels + 1
# featmap_sizes # featmap_sizes
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes] anchors_list_all = np.concatenate(self.base_anchors_list)
anchors_list, valid_flag_list = self.get_init_anchors(featmap_sizes,
np_im_shape)
anchors_list_all = []
for ii, anchor in enumerate(anchors_list):
anchor = anchor.reshape(-1, 4)
anchor = bbox_utils.rect2rbox(anchor)
anchors_list_all.extend(anchor)
anchors_list_all = np.array(anchors_list_all)
# get im_feat # get im_feat
fam_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[0]] fam_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[0]]
...@@ -729,20 +732,20 @@ class S2ANetHead(nn.Layer): ...@@ -729,20 +732,20 @@ class S2ANetHead(nn.Layer):
gt_labels, is_crowd) gt_labels, is_crowd)
if im_fam_target is not None: if im_fam_target is not None:
im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss( im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
im_fam_target, im_s2anet_head_out) im_fam_target, im_s2anet_head_out, self.reg_loss_type)
fam_cls_loss_lst.append(im_fam_cls_loss) fam_cls_loss_lst.append(im_fam_cls_loss)
fam_reg_loss_lst.append(im_fam_reg_loss) fam_reg_loss_lst.append(im_fam_reg_loss)
# ODM # ODM
refine_anchors_list, valid_flag_list = self.get_refine_anchors( np_refine_anchors_list = paddle.concat(self.refine_anchor_list).numpy()
featmap_sizes, image_shape=np_im_shape) np_refine_anchors_list = np.concatenate(np_refine_anchors_list)
refine_anchors_list = np.array(refine_anchors_list) np_refine_anchors_list = np_refine_anchors_list.reshape(-1, 5)
im_odm_target = self.anchor_assign(refine_anchors_list, gt_bboxes, im_odm_target = self.anchor_assign(np_refine_anchors_list, gt_bboxes,
gt_labels, is_crowd) gt_labels, is_crowd)
if im_odm_target is not None: if im_odm_target is not None:
im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss( im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
im_odm_target, im_s2anet_head_out) im_odm_target, im_s2anet_head_out, self.reg_loss_type)
odm_cls_loss_lst.append(im_odm_cls_loss) odm_cls_loss_lst.append(im_odm_cls_loss)
odm_reg_loss_lst.append(im_odm_reg_loss) odm_reg_loss_lst.append(im_odm_reg_loss)
fam_cls_loss = paddle.add_n(fam_cls_loss_lst) fam_cls_loss = paddle.add_n(fam_cls_loss_lst)
...@@ -756,65 +759,6 @@ class S2ANetHead(nn.Layer): ...@@ -756,65 +759,6 @@ class S2ANetHead(nn.Layer):
'odm_reg_loss': odm_reg_loss 'odm_reg_loss': odm_reg_loss
} }
def get_init_anchors(self, featmap_sizes, image_shape):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
image_shape (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
anchor_list = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
anchor_list.append(anchors)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return anchor_list, valid_flag_list
def get_refine_anchors(self, featmap_sizes, image_shape):
num_levels = len(featmap_sizes)
refine_anchors_list = []
for i in range(num_levels):
refine_anchor = self.refine_anchor_list[i]
refine_anchor = paddle.squeeze(refine_anchor, axis=0)
refine_anchor = refine_anchor.numpy()
refine_anchor = np.reshape(refine_anchor,
[-1, refine_anchor.shape[-1]])
refine_anchors_list.extend(refine_anchor)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return refine_anchors_list, valid_flag_list
def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre, def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
cls_out_channels, use_sigmoid_cls): cls_out_channels, use_sigmoid_cls):
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
...@@ -848,10 +792,8 @@ class S2ANetHead(nn.Layer): ...@@ -848,10 +792,8 @@ class S2ANetHead(nn.Layer):
bbox_pred = paddle.gather(bbox_pred, topk_inds) bbox_pred = paddle.gather(bbox_pred, topk_inds)
scores = paddle.gather(scores, topk_inds) scores = paddle.gather(scores, topk_inds)
target_means = (.0, .0, .0, .0, .0) bbox_delta = paddle.reshape(bbox_pred, [-1, 5])
target_stds = (1.0, 1.0, 1.0, 1.0, 1.0) bboxes = self.delta2rbox(anchors, bbox_delta)
bboxes = bbox_utils.delta2rbox(anchors, bbox_pred, target_means,
target_stds)
mlvl_bboxes.append(bboxes) mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores) mlvl_scores.append(scores)
...@@ -861,3 +803,86 @@ class S2ANetHead(nn.Layer): ...@@ -861,3 +803,86 @@ class S2ANetHead(nn.Layer):
mlvl_scores = paddle.concat(mlvl_scores) mlvl_scores = paddle.concat(mlvl_scores)
return mlvl_scores, mlvl_bboxes return mlvl_scores, mlvl_bboxes
def rect2rbox(self, bboxes):
"""
:param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
:return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
"""
bboxes = paddle.reshape(bboxes, [-1, 4])
num_boxes = paddle.shape(bboxes)[0]
x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0])
edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1])
rbox_w = paddle.maximum(edges1, edges2)
rbox_h = paddle.minimum(edges1, edges2)
# set angle
inds = edges1 < edges2
inds = paddle.cast(inds, 'int32')
rboxes_angle = inds * np.pi / 2.0
rboxes = paddle.stack(
(x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1)
return rboxes
# deltas to rbox
def delta2rbox(self, rrois, deltas, wh_ratio_clip=1e-6):
"""
:param rrois: (cx, cy, w, h, theta)
:param deltas: (dx, dy, dw, dh, dtheta)
:param means: means of anchor
:param stds: stds of anchor
:param wh_ratio_clip: clip threshold of wh_ratio
:return:
"""
deltas = paddle.reshape(deltas, [-1, 5])
rrois = paddle.reshape(rrois, [-1, 5])
# fix dy2st bug denorm_deltas = deltas * self.stds + self.means
denorm_deltas = paddle.add(paddle.multiply(deltas, self.stds), self.means)
dx = denorm_deltas[:, 0]
dy = denorm_deltas[:, 1]
dw = denorm_deltas[:, 2]
dh = denorm_deltas[:, 3]
dangle = denorm_deltas[:, 4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
rroi_x = rrois[:, 0]
rroi_y = rrois[:, 1]
rroi_w = rrois[:, 2]
rroi_h = rrois[:, 3]
rroi_angle = rrois[:, 4]
gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
rroi_angle) + rroi_x
gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
rroi_angle) + rroi_y
gw = rroi_w * dw.exp()
gh = rroi_h * dh.exp()
ga = np.pi * dangle + rroi_angle
ga = (ga + np.pi / 4) % np.pi - np.pi / 4
ga = paddle.to_tensor(ga)
gw = paddle.to_tensor(gw, dtype='float32')
gh = paddle.to_tensor(gh, dtype='float32')
bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
return bboxes
def bbox_decode(self,
bbox_preds,
anchors):
"""decode bbox from deltas
Args:
bbox_preds: [N,H,W,5]
anchors: [H*W,5]
return:
bboxes: [N,H,W,5]
"""
num_imgs, H, W, _ = bbox_preds.shape
bbox_delta = paddle.reshape(bbox_preds, [-1, 5])
bboxes = self.delta2rbox(anchors, bbox_delta)
return bboxes
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册