未验证 提交 51574ea1 编写于 作者: C cnn 提交者: GitHub

[dev] reorganize code of s2anet_head (#3481)

* reorganize the code, and fix bug

* fix typo

* remve comment

* reorganize code

* Revert "reorganize code"

This reverts commit 4f928c3a46b264ba8815978b914bf81a46b66b62.

* set weight

* set default config

* fix code style
上级 4f96dc2f
......@@ -36,6 +36,8 @@ S2ANetHead:
align_conv_type: 'Conv' # AlignConv Conv
align_conv_size: 3
use_sigmoid_cls: True
reg_loss_weight: [ 1.0, 1.0, 1.0, 1.0, 1.1 ]
cls_loss_weight: [ 1.1, 1.05 ]
RBoxAssigner:
pos_iou_thr: 0.5
......@@ -52,4 +54,3 @@ S2ANetBBoxPostProcess:
score_threshold: 0.05
nms_threshold: 0.1
normalized: False
#background_label: -1
_BASE_: [
it _BASE_: [
'../datasets/dota.yml',
'../runtime.yml',
'_base_/s2anet_optimizer_1x.yml',
......@@ -6,3 +6,18 @@ _BASE_: [
'_base_/s2anet_reader.yml',
]
weights: output/s2anet_1x_dota/model_final
S2ANetHead:
anchor_strides: [8, 16, 32, 64, 128]
anchor_scales: [4]
anchor_ratios: [1.0]
anchor_assign: RBoxAssigner
stacked_convs: 2
feat_in: 256
feat_out: 256
num_classes: 15
align_conv_type: 'AlignConv' # AlignConv Conv
align_conv_size: 3
use_sigmoid_cls: True
reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1]
cls_loss_weight: [1.1, 1.05]
......@@ -19,3 +19,5 @@ S2ANetHead:
align_conv_type: 'Conv' # AlignConv Conv
align_conv_size: 3
use_sigmoid_cls: True
reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1]
cls_loss_weight: [1.1, 1.05]
......@@ -267,6 +267,150 @@ def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
return iou
def rect2rbox(bboxes):
"""
:param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
:return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
"""
bboxes = bboxes.reshape(-1, 4)
num_boxes = bboxes.shape[0]
x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
edges1 = np.abs(bboxes[:, 2] - bboxes[:, 0])
edges2 = np.abs(bboxes[:, 3] - bboxes[:, 1])
angles = np.zeros([num_boxes], dtype=bboxes.dtype)
inds = edges1 < edges2
rboxes = np.stack((x_ctr, y_ctr, edges1, edges2, angles), axis=1)
rboxes[inds, 2] = edges2[inds]
rboxes[inds, 3] = edges1[inds]
rboxes[inds, 4] = np.pi / 2.0
return rboxes
def delta2rbox(rrois,
deltas,
means=[0, 0, 0, 0, 0],
stds=[1, 1, 1, 1, 1],
wh_ratio_clip=1e-6):
"""
:param rrois: (cx, cy, w, h, theta)
:param deltas: (dx, dy, dw, dh, dtheta)
:param means:
:param stds:
:param wh_ratio_clip:
:return:
"""
means = paddle.to_tensor(means)
stds = paddle.to_tensor(stds)
deltas = paddle.reshape(deltas, [-1, deltas.shape[-1]])
denorm_deltas = deltas * stds + means
dx = denorm_deltas[:, 0]
dy = denorm_deltas[:, 1]
dw = denorm_deltas[:, 2]
dh = denorm_deltas[:, 3]
dangle = denorm_deltas[:, 4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
rroi_x = rrois[:, 0]
rroi_y = rrois[:, 1]
rroi_w = rrois[:, 2]
rroi_h = rrois[:, 3]
rroi_angle = rrois[:, 4]
gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
rroi_angle) + rroi_x
gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
rroi_angle) + rroi_y
gw = rroi_w * dw.exp()
gh = rroi_h * dh.exp()
ga = np.pi * dangle + rroi_angle
ga = (ga + np.pi / 4) % np.pi - np.pi / 4
ga = paddle.to_tensor(ga)
gw = paddle.to_tensor(gw, dtype='float32')
gh = paddle.to_tensor(gh, dtype='float32')
bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
return bboxes
def rbox2delta(proposals, gt, means=[0, 0, 0, 0, 0], stds=[1, 1, 1, 1, 1]):
"""
Args:
proposals:
gt:
means: 1x5
stds: 1x5
Returns:
"""
proposals = proposals.astype(np.float64)
PI = np.pi
gt_widths = gt[..., 2]
gt_heights = gt[..., 3]
gt_angle = gt[..., 4]
proposals_widths = proposals[..., 2]
proposals_heights = proposals[..., 3]
proposals_angle = proposals[..., 4]
coord = gt[..., 0:2] - proposals[..., 0:2]
dx = (np.cos(proposals[..., 4]) * coord[..., 0] + np.sin(proposals[..., 4])
* coord[..., 1]) / proposals_widths
dy = (-np.sin(proposals[..., 4]) * coord[..., 0] + np.cos(proposals[..., 4])
* coord[..., 1]) / proposals_heights
dw = np.log(gt_widths / proposals_widths)
dh = np.log(gt_heights / proposals_heights)
da = (gt_angle - proposals_angle)
da = (da + PI / 4) % PI - PI / 4
da /= PI
deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
means = np.array(means, dtype=deltas.dtype)
stds = np.array(stds, dtype=deltas.dtype)
deltas = (deltas - means) / stds
deltas = deltas.astype(np.float32)
return deltas
def bbox_decode(bbox_preds,
anchors,
means=[0, 0, 0, 0, 0],
stds=[1, 1, 1, 1, 1]):
"""decode bbox from deltas
Args:
bbox_preds: [N,H,W,5]
anchors: [H*W,5]
return:
bboxes: [N,H,W,5]
"""
means = paddle.to_tensor(means)
stds = paddle.to_tensor(stds)
num_imgs, H, W, _ = bbox_preds.shape
bboxes_list = []
for img_id in range(num_imgs):
bbox_pred = bbox_preds[img_id]
# bbox_pred.shape=[5,H,W]
bbox_delta = bbox_pred
anchors = paddle.to_tensor(anchors)
bboxes = delta2rbox(
anchors, bbox_delta, means, stds, wh_ratio_clip=1e-6)
bboxes = paddle.reshape(bboxes, [H, W, 5])
bboxes_list.append(bboxes)
return paddle.stack(bboxes_list, axis=0)
def poly2rbox(polys):
"""
poly:[x0,y0,x1,y1,x2,y2,x3,y3]
......
......@@ -17,21 +17,26 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register
from ppdet.modeling import ops
from ppdet.modeling import bbox_utils
from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
import numpy as np
class S2ANetAnchorGenerator(nn.Layer):
class S2ANetAnchorGenerator(object):
"""
AnchorGenerator by paddle
S2ANetAnchorGenerator by np
"""
def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None):
super(S2ANetAnchorGenerator, self).__init__()
def __init__(self,
base_size=8,
scales=1.0,
ratios=1.0,
scale_major=True,
ctr=None):
self.base_size = base_size
self.scales = paddle.to_tensor(scales)
self.ratios = paddle.to_tensor(ratios)
self.scales = scales
self.ratios = ratios
self.scale_major = scale_major
self.ctr = ctr
self.base_anchors = self.gen_base_anchors()
......@@ -49,7 +54,7 @@ class S2ANetAnchorGenerator(nn.Layer):
else:
x_ctr, y_ctr = self.ctr
h_ratios = paddle.sqrt(self.ratios)
h_ratios = np.sqrt(self.ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
......@@ -58,51 +63,53 @@ class S2ANetAnchorGenerator(nn.Layer):
ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
base_anchors = paddle.stack(
# yapf: disable
base_anchors = np.stack(
[
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
],
axis=-1)
base_anchors = paddle.round(base_anchors)
base_anchors = np.round(base_anchors)
# yapf: enable
return base_anchors
def _meshgrid(self, x, y, row_major=True):
yy, xx = paddle.meshgrid(x, y)
yy = yy.reshape([-1])
xx = xx.reshape([-1])
xx, yy = np.meshgrid(x, y)
xx = xx.reshape(-1)
yy = yy.reshape(-1)
if row_major:
return xx, yy
else:
return yy, xx
def forward(self, featmap_size, stride=16):
def grid_anchors(self, featmap_size, stride=16):
# featmap_size*stride project it to original area
base_anchors = self.base_anchors
feat_h = featmap_size[0]
feat_w = featmap_size[1]
shift_x = paddle.arange(0, feat_w, 1, 'int32') * stride
shift_y = paddle.arange(0, feat_h, 1, 'int32') * stride
feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w, 1, 'int32') * stride
shift_y = np.arange(0, feat_h, 1, 'int32') * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = paddle.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
all_anchors = base_anchors[:, :] + shifts[:, :]
all_anchors = all_anchors.reshape([feat_h * feat_w, 4])
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
return all_anchors
def valid_flags(self, featmap_size, valid_size):
feat_h, feat_w = featmap_size
valid_h, valid_w = valid_size
assert valid_h <= feat_h and valid_w <= feat_w
valid_x = paddle.zeros([feat_w], dtype='uint8')
valid_y = paddle.zeros([feat_h], dtype='uint8')
valid_x = np.zeros([feat_w], dtype='uint8')
valid_y = np.zeros([feat_h], dtype='uint8')
valid_x[:valid_w] = 1
valid_y[:valid_h] = 1
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy
valid = valid[:, None].expand(
[valid.size(0), self.num_base_anchors]).reshape([-1])
valid = valid.reshape([-1])
# valid = valid[:, None].expand(
# [valid.size(0), self.num_base_anchors]).reshape([-1])
return valid
......@@ -225,8 +232,8 @@ class S2ANetHead(nn.Layer):
anchor_strides=[8, 16, 32, 64, 128],
anchor_scales=[4],
anchor_ratios=[1.0],
target_means=0.0,
target_stds=1.0,
target_means=(.0, .0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0),
align_conv_type='AlignConv',
align_conv_size=3,
use_sigmoid_cls=True,
......@@ -263,8 +270,6 @@ class S2ANetHead(nn.Layer):
self.anchor_generators.append(
S2ANetAnchorGenerator(anchor_base, anchor_scales,
anchor_ratios))
self.anchor_generators = paddle.nn.LayerList(self.anchor_generators)
self.add_sublayer('s2anet_anchor_gen', self.anchor_generators)
self.fam_cls_convs = nn.Sequential()
self.fam_reg_convs = nn.Sequential()
......@@ -399,9 +404,9 @@ class S2ANetHead(nn.Layer):
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0)))
self.featmap_size_list = []
self.init_anchors_list = []
self.rbox_anchors_list = []
self.base_anchors = dict()
self.featmap_sizes = dict()
self.base_anchors = dict()
self.refine_anchor_list = []
def forward(self, feats):
......@@ -411,27 +416,13 @@ class S2ANetHead(nn.Layer):
odm_reg_branch_list = []
odm_cls_branch_list = []
fam_reg1_branch_list = []
self.featmap_size_list = []
self.init_anchors_list = []
self.rbox_anchors_list = []
self.featmap_sizes = dict()
self.base_anchors = dict()
self.refine_anchor_list = []
for i, feat in enumerate(feats):
# prepare anchor
featmap_size = paddle.shape(feat)[-2:]
self.featmap_size_list.append(featmap_size)
init_anchors = self.anchor_generators[i](featmap_size,
self.anchor_strides[i])
init_anchors = paddle.reshape(
init_anchors, [featmap_size[0] * featmap_size[1], 4])
self.init_anchors_list.append(init_anchors)
rbox_anchors = self.rect2rbox(init_anchors)
self.rbox_anchors_list.append(rbox_anchors)
fam_cls_feat = self.fam_cls_convs(feat)
fam_cls = self.fam_cls(fam_cls_feat)
# [N, CLS, H, W] --> [N, H, W, CLS]
fam_cls = fam_cls.transpose([0, 2, 3, 1])
......@@ -447,13 +438,21 @@ class S2ANetHead(nn.Layer):
fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5])
fam_reg_branch_list.append(fam_reg_reshape)
# refine anchors
fam_reg1 = fam_reg.clone()
fam_reg1.stop_gradient = True
rbox_anchors.stop_gradient = True
fam_reg1_branch_list.append(fam_reg1)
refine_anchor = self.bbox_decode(
fam_reg1, rbox_anchors, self.target_stds, self.target_means)
# prepare anchor
featmap_size = feat.shape[-2:]
self.featmap_sizes[i] = featmap_size
init_anchors = self.anchor_generators[i].grid_anchors(
featmap_size, self.anchor_strides[i])
init_anchors = bbox_utils.rect2rbox(init_anchors)
self.base_anchors[(i, featmap_size[0])] = init_anchors
#fam_reg1 = fam_reg
#fam_reg1.stop_gradient = True
refine_anchor = bbox_utils.bbox_decode(
fam_reg.detach(), init_anchors, self.target_means,
self.target_stds)
self.refine_anchor_list.append(refine_anchor)
if self.align_conv_type == 'AlignConv':
......@@ -493,87 +492,6 @@ class S2ANetHead(nn.Layer):
odm_cls_branch_list, odm_reg_branch_list)
return self.s2anet_head_out
def rect2rbox(self, bboxes):
"""
:param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
:return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
"""
num_boxes = paddle.shape(bboxes)[0]
x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
edges1 = paddle.abs(bboxes[:, 2] - bboxes[:, 0])
edges2 = paddle.abs(bboxes[:, 3] - bboxes[:, 1])
rbox_w = paddle.maximum(edges1, edges2)
rbox_h = paddle.minimum(edges1, edges2)
# set angle
inds = edges1 < edges2
inds = paddle.cast(inds, 'int32')
inds1 = inds * paddle.arange(0, num_boxes)
rboxes_angle = inds1 * np.pi / 2.0
rboxes = paddle.stack(
(x_ctr, y_ctr, rbox_w, rbox_h, rboxes_angle), axis=1)
return rboxes
# deltas to rbox
def delta2rbox(self, rrois, deltas, means, stds, wh_ratio_clip=1e-6):
"""
:param rrois: (cx, cy, w, h, theta)
:param deltas: (dx, dy, dw, dh, dtheta)
:param means: means of anchor
:param stds: stds of anchor
:param wh_ratio_clip: clip threshold of wh_ratio
:return:
"""
deltas = paddle.reshape(deltas, [-1, 5])
rrois = paddle.reshape(rrois, [-1, 5])
pd_means = paddle.ones(shape=[5]) * means
pd_stds = paddle.ones(shape=[5]) * stds
denorm_deltas = deltas * pd_stds + pd_means
dx = denorm_deltas[:, 0]
dy = denorm_deltas[:, 1]
dw = denorm_deltas[:, 2]
dh = denorm_deltas[:, 3]
dangle = denorm_deltas[:, 4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
rroi_x = rrois[:, 0]
rroi_y = rrois[:, 1]
rroi_w = rrois[:, 2]
rroi_h = rrois[:, 3]
rroi_angle = rrois[:, 4]
gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
rroi_angle) + rroi_x
gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
rroi_angle) + rroi_y
gw = rroi_w * dw.exp()
gh = rroi_h * dh.exp()
ga = np.pi * dangle + rroi_angle
ga = (ga + np.pi / 4) % np.pi - np.pi / 4
bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
return bboxes
def bbox_decode(self, bbox_preds, anchors, stds, means, wh_ratio_clip=1e-6):
"""decode bbox from deltas
Args:
bbox_preds: bbox_preds, shape=[N,H,W,5]
anchors: anchors, shape=[H,W,5]
return:
bboxes: return decoded bboxes, shape=[N*H*W,5]
"""
num_imgs, H, W, _ = bbox_preds.shape
bbox_delta = paddle.reshape(bbox_preds, [-1, 5])
bboxes = self.delta2rbox(anchors, bbox_delta, means, stds,
wh_ratio_clip)
return bboxes
def get_prediction(self, nms_pre):
refine_anchors = self.refine_anchor_list
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = self.s2anet_head_out
......@@ -584,7 +502,6 @@ class S2ANetHead(nn.Layer):
nms_pre,
cls_out_channels=self.cls_out_channels,
use_sigmoid_cls=self.use_sigmoid_cls)
return pred_scores, pred_bboxes
def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
......@@ -603,125 +520,170 @@ class S2ANetHead(nn.Layer):
return loss
def get_fam_loss(self, fam_target, s2anet_head_out):
(feat_labels, feat_label_weights, feat_bbox_targets, feat_bbox_weights,
pos_inds, neg_inds) = fam_target
fam_cls_score, fam_bbox_pred = s2anet_head_out
# step1: sample count
(labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds) = fam_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
fam_cls_losses = []
fam_bbox_losses = []
st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples)
# step2: calc cls loss
feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1)
fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
fam_cls_score1 = fam_cls_score
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True
num_total_samples = paddle.to_tensor(
num_total_samples, dtype='float32', stop_gradient=True)
fam_cls = F.sigmoid_focal_loss(
fam_cls_score1,
feat_labels_one_hot,
normalizer=num_total_samples,
reduction='none')
feat_label_weights = feat_label_weights.reshape(
feat_label_weights.shape[0], 1)
feat_label_weights = np.repeat(
feat_label_weights, self.cls_out_channels, axis=1)
feat_label_weights = paddle.to_tensor(
feat_label_weights, stop_gradient=True)
fam_cls = fam_cls * feat_label_weights
fam_cls_total = paddle.sum(fam_cls)
# step3: regression loss
feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32', stop_gradient=True)
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
loss_weight = paddle.to_tensor(
self.reg_loss_weight, dtype='float32', stop_gradient=True)
fam_bbox = paddle.multiply(fam_bbox, loss_weight)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
fam_bbox = fam_bbox * feat_bbox_weights
fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
for idx, feat_size in enumerate(featmap_sizes):
feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data
feat_labels = labels[st_idx:st_idx + feat_anchor_num]
feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
st_idx += feat_anchor_num
# step2: calc cls loss
feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1)
fam_cls_score = fam_cls_branch_list[idx]
fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
fam_cls_score1 = fam_cls_score
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = paddle.nn.functional.one_hot(
feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True
num_total_samples = paddle.to_tensor(
num_total_samples, dtype='float32', stop_gradient=True)
fam_cls = F.sigmoid_focal_loss(
fam_cls_score1,
feat_labels_one_hot,
normalizer=num_total_samples,
reduction='none')
feat_label_weights = feat_label_weights.reshape(
feat_label_weights.shape[0], 1)
feat_label_weights = np.repeat(
feat_label_weights, self.cls_out_channels, axis=1)
feat_label_weights = paddle.to_tensor(
feat_label_weights, stop_gradient=True)
fam_cls = fam_cls * feat_label_weights
fam_cls_total = paddle.sum(fam_cls)
fam_cls_losses.append(fam_cls_total)
# step3: regression loss
fam_bbox_pred = fam_reg_branch_list[idx]
feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32', stop_gradient=True)
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
fam_bbox_pred = fam_reg_branch_list[idx]
fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
loss_weight = paddle.to_tensor(
self.reg_loss_weight, dtype='float32', stop_gradient=True)
fam_bbox = paddle.multiply(fam_bbox, loss_weight)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
fam_bbox = fam_bbox * feat_bbox_weights
fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
fam_bbox_losses.append(fam_bbox_total)
fam_cls_loss = paddle.add_n(fam_cls_losses)
fam_cls_loss_weight = paddle.to_tensor(
self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
fam_cls_loss = fam_cls_total * fam_cls_loss_weight
fam_reg_loss = paddle.add_n(fam_bbox_total)
fam_cls_loss = fam_cls_loss * fam_cls_loss_weight
fam_reg_loss = paddle.add_n(fam_bbox_losses)
return fam_cls_loss, fam_reg_loss
def get_odm_loss(self, odm_target, s2anet_head_out):
(feat_labels, feat_label_weights, feat_bbox_targets, feat_bbox_weights,
pos_inds, neg_inds) = odm_target
odm_cls_score, odm_bbox_pred = s2anet_head_out
# step1: sample count
(labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds) = odm_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
odm_cls_losses = []
odm_bbox_losses = []
st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples)
# step2: calc cls loss
feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1)
odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
odm_cls_score1 = odm_cls_score
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True
num_total_samples = paddle.to_tensor(
num_total_samples, dtype='float32', stop_gradient=True)
odm_cls = F.sigmoid_focal_loss(
odm_cls_score1,
feat_labels_one_hot,
normalizer=num_total_samples,
reduction='none')
feat_label_weights = feat_label_weights.reshape(
feat_label_weights.shape[0], 1)
feat_label_weights = np.repeat(
feat_label_weights, self.cls_out_channels, axis=1)
feat_label_weights = paddle.to_tensor(
feat_label_weights, stop_gradient=True)
odm_cls = odm_cls * feat_label_weights
odm_cls_total = paddle.sum(odm_cls)
# step3: regression loss
feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32', stop_gradient=True)
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
loss_weight = paddle.to_tensor(
self.reg_loss_weight, dtype='float32', stop_gradient=True)
odm_bbox = paddle.multiply(odm_bbox, loss_weight)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
odm_bbox = odm_bbox * feat_bbox_weights
odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
for idx, feat_size in enumerate(featmap_sizes):
feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data
feat_labels = labels[st_idx:st_idx + feat_anchor_num]
feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
st_idx += feat_anchor_num
# step2: calc cls loss
feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1)
odm_cls_score = odm_cls_branch_list[idx]
odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
odm_cls_score1 = odm_cls_score
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = paddle.nn.functional.one_hot(
feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True
num_total_samples = paddle.to_tensor(
num_total_samples, dtype='float32', stop_gradient=True)
odm_cls = F.sigmoid_focal_loss(
odm_cls_score1,
feat_labels_one_hot,
normalizer=num_total_samples,
reduction='none')
feat_label_weights = feat_label_weights.reshape(
feat_label_weights.shape[0], 1)
feat_label_weights = np.repeat(
feat_label_weights, self.cls_out_channels, axis=1)
feat_label_weights = paddle.to_tensor(feat_label_weights)
feat_label_weights.stop_gradient = True
odm_cls = odm_cls * feat_label_weights
odm_cls_total = paddle.sum(odm_cls)
odm_cls_losses.append(odm_cls_total)
# # step3: regression loss
feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32')
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
feat_bbox_targets.stop_gradient = True
odm_bbox_pred = odm_reg_branch_list[idx]
odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
loss_weight = paddle.to_tensor(
self.reg_loss_weight, dtype='float32', stop_gradient=True)
odm_bbox = paddle.multiply(odm_bbox, loss_weight)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
odm_bbox = odm_bbox * feat_bbox_weights
odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
odm_bbox_losses.append(odm_bbox_total)
odm_cls_loss = paddle.add_n(odm_cls_losses)
odm_cls_loss_weight = paddle.to_tensor(
self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
odm_cls_loss = odm_cls_total * odm_cls_loss_weight
odm_reg_loss = paddle.add_n(odm_bbox_total)
self.cls_loss_weight[1], dtype='float32', stop_gradient=True)
odm_cls_loss = odm_cls_loss * odm_cls_loss_weight
odm_reg_loss = paddle.add_n(odm_bbox_losses)
return odm_cls_loss, odm_reg_loss
def get_loss(self, inputs):
......@@ -743,38 +705,46 @@ class S2ANetHead(nn.Layer):
is_crowd = inputs['is_crowd'][im_id].numpy()
gt_labels = gt_labels + 1
# featmap_sizes
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
anchors_list, valid_flag_list = self.get_init_anchors(featmap_sizes,
np_im_shape)
anchors_list_all = []
for ii, anchor in enumerate(anchors_list):
anchor = anchor.reshape(-1, 4)
anchor = bbox_utils.rect2rbox(anchor)
anchors_list_all.extend(anchor)
anchors_list_all = np.array(anchors_list_all)
# get im_feat
fam_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[0]]
fam_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[1]]
odm_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[2]]
odm_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[3]]
im_s2anet_head_out = (fam_cls_feats_list, fam_reg_feats_list,
odm_cls_feats_list, odm_reg_feats_list)
# FAM
for idx, rbox_anchors in enumerate(self.rbox_anchors_list):
rbox_anchors = rbox_anchors.numpy()
rbox_anchors = rbox_anchors.reshape(-1, 5)
im_fam_target = self.anchor_assign(rbox_anchors, gt_bboxes,
gt_labels, is_crowd)
# feat
fam_cls_feat = self.s2anet_head_out[0][idx][im_id]
fam_reg_feat = self.s2anet_head_out[1][idx][im_id]
im_s2anet_fam_feat = (fam_cls_feat, fam_reg_feat)
im_fam_target = self.anchor_assign(anchors_list_all, gt_bboxes,
gt_labels, is_crowd)
if im_fam_target is not None:
im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
im_fam_target, im_s2anet_fam_feat)
im_fam_target, im_s2anet_head_out)
fam_cls_loss_lst.append(im_fam_cls_loss)
fam_reg_loss_lst.append(im_fam_reg_loss)
# ODM
for idx, refine_anchors in enumerate(self.refine_anchor_list):
refine_anchors = refine_anchors.numpy()
refine_anchors = refine_anchors.reshape(-1, 5)
im_odm_target = self.anchor_assign(refine_anchors, gt_bboxes,
gt_labels, is_crowd)
odm_cls_feat = self.s2anet_head_out[2][idx][im_id]
odm_reg_feat = self.s2anet_head_out[3][idx][im_id]
refine_anchors_list, valid_flag_list = self.get_refine_anchors(
featmap_sizes, image_shape=np_im_shape)
refine_anchors_list = np.array(refine_anchors_list)
im_odm_target = self.anchor_assign(refine_anchors_list, gt_bboxes,
gt_labels, is_crowd)
im_s2anet_odm_feat = (odm_cls_feat, odm_reg_feat)
if im_odm_target is not None:
im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
im_odm_target, im_s2anet_odm_feat)
im_odm_target, im_s2anet_head_out)
odm_cls_loss_lst.append(im_odm_cls_loss)
odm_reg_loss_lst.append(im_odm_reg_loss)
fam_cls_loss = paddle.add_n(fam_cls_loss_lst)
fam_reg_loss = paddle.add_n(fam_reg_loss_lst)
odm_cls_loss = paddle.add_n(odm_cls_loss_lst)
......@@ -786,6 +756,65 @@ class S2ANetHead(nn.Layer):
'odm_reg_loss': odm_reg_loss
}
def get_init_anchors(self, featmap_sizes, image_shape):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
image_shape (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
anchor_list = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
anchor_list.append(anchors)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return anchor_list, valid_flag_list
def get_refine_anchors(self, featmap_sizes, image_shape):
num_levels = len(featmap_sizes)
refine_anchors_list = []
for i in range(num_levels):
refine_anchor = self.refine_anchor_list[i]
refine_anchor = paddle.squeeze(refine_anchor, axis=0)
refine_anchor = refine_anchor.numpy()
refine_anchor = np.reshape(refine_anchor,
[-1, refine_anchor.shape[-1]])
refine_anchors_list.extend(refine_anchor)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return refine_anchors_list, valid_flag_list
def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
cls_out_channels, use_sigmoid_cls):
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
......@@ -819,8 +848,10 @@ class S2ANetHead(nn.Layer):
bbox_pred = paddle.gather(bbox_pred, topk_inds)
scores = paddle.gather(scores, topk_inds)
bboxes = self.delta2rbox(anchors, bbox_pred, self.target_means,
self.target_stds)
target_means = (.0, .0, .0, .0, .0)
target_stds = (1.0, 1.0, 1.0, 1.0, 1.0)
bboxes = bbox_utils.delta2rbox(anchors, bbox_pred, target_means,
target_stds)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
......
......@@ -296,7 +296,7 @@ class RBoxAssigner(object):
anchors = anchors.reshape(-1, anchors.shape[-1])
assert anchors.ndim == 2
anchor_num = anchors.shape[0]
anchor_valid = np.ones((anchor_num), np.uint8)
anchor_valid = np.ones((anchor_num), np.int32)
anchor_inds = np.arange(anchor_num)
return anchor_inds
......@@ -371,9 +371,8 @@ class RBoxAssigner(object):
# calc rbox iou
anchors_xc_yc = anchors_xc_yc.astype(np.float32)
gt_bboxes_xc_yc = gt_bboxes_xc_yc.astype(np.float32)
anchors_xc_yc = paddle.to_tensor(anchors_xc_yc, place=paddle.CPUPlace())
gt_bboxes_xc_yc = paddle.to_tensor(
gt_bboxes_xc_yc, place=paddle.CPUPlace())
anchors_xc_yc = paddle.to_tensor(anchors_xc_yc)
gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc)
try:
from rbox_iou_ops import rbox_iou
......@@ -433,8 +432,7 @@ class RBoxAssigner(object):
ignore_iof_thr = self.ignore_iof_thr
anchor_num = anchors.shape[0]
anchors_inds = self.anchor_valid(anchors)
anchors = anchors[anchors_inds]
gt_bboxes = gt_bboxes
is_crowd_slice = is_crowd
not_crowd_inds = np.where(is_crowd_slice == 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册