diff --git a/configs/dota/README.md b/configs/dota/README.md index 934e597f79056653db231298930384c99b2f0408..e8a44dac329cf2c95a35c1fdc65c79bedc45dba7 100644 --- a/configs/dota/README.md +++ b/configs/dota/README.md @@ -129,9 +129,10 @@ python3.7 tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=./weights ### S2ANet模型 -| 模型 | GPU个数 | Conv类型 | mAP | 模型下载 | 配置文件 | -|:-----------:|:-------:|:----------:|:--------:| :----------:| :---------: | -| S2ANet | 8 | Conv | 71.42 | [model](https://paddledet.bj.bcebos.com/models/s2anet_conv_1x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/s2anet_conv_1x_dota.yml) | +| 模型 | Conv类型 | mAP | 模型下载 | 配置文件 | +|:-----------:|:----------:|:--------:| :----------:| :---------: | +| S2ANet | Conv | 71.42 | [model](https://paddledet.bj.bcebos.com/models/s2anet_conv_1x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/s2anet_conv_1x_dota.yml) | +| S2ANet | AlignConv | 74.0 | [model](https://paddledet.bj.bcebos.com/models/s2anet_alignconv_2x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/s2anet_alignconv_2x_dota.yml) | **注意:**这里使用`multiclass_nms`,与原作者使用nms略有不同,精度相比原始论文中高0.15 (71.27-->71.42)。 diff --git a/configs/dota/_base_/s2anet.yml b/configs/dota/_base_/s2anet.yml index b798e5fc9d336af0559ac5388cebf2fbfef058a5..fb6064224ab068b60bada406872d6af4238b3808 100644 --- a/configs/dota/_base_/s2anet.yml +++ b/configs/dota/_base_/s2anet.yml @@ -53,4 +53,3 @@ S2ANetBBoxPostProcess: score_threshold: 0.05 nms_threshold: 0.1 normalized: False - #background_label: -1 diff --git a/configs/dota/_base_/s2anet_optimizer_2x.yml b/configs/dota/_base_/s2anet_optimizer_2x.yml new file mode 100644 index 0000000000000000000000000000000000000000..c96027a6146807b3365731892fad8d312b0cdac0 --- /dev/null +++ b/configs/dota/_base_/s2anet_optimizer_2x.yml @@ -0,0 +1,20 @@ +epoch: 24 + +LearningRate: + 
base_lr: 0.005 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [14, 20] + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + clip_grad_by_norm: 35 diff --git a/configs/dota/s2anet_1x_spine.yml b/configs/dota/s2anet_1x_spine.yml index 6e24d727241f2fd87755bf8f8a53eceac9fc1576..017ee28e51f12bdd31c3a0b87222a9cc3c3b55a1 100644 --- a/configs/dota/s2anet_1x_spine.yml +++ b/configs/dota/s2anet_1x_spine.yml @@ -24,5 +24,7 @@ S2ANetHead: align_conv_type: 'AlignConv' # AlignConv Conv align_conv_size: 3 use_sigmoid_cls: True - reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.05] - cls_loss_weight: [1.05, 1.0] + reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.05] + cls_loss_weight: [1.05, 1.0] + reg_loss_type: gwd + use_paddle_anchor: False diff --git a/configs/dota/s2anet_1x_dota.yml b/configs/dota/s2anet_alignconv_2x_dota.yml similarity index 54% rename from configs/dota/s2anet_1x_dota.yml rename to configs/dota/s2anet_alignconv_2x_dota.yml index 37959a0e60c0b5ce1b32527142f1fed9f384ae01..7355a0122aa5a2d131740d0ee655c1226e1c40ae 100644 --- a/configs/dota/s2anet_1x_dota.yml +++ b/configs/dota/s2anet_alignconv_2x_dota.yml @@ -1,11 +1,13 @@ -it _BASE_: [ +_BASE_: [ '../datasets/dota.yml', '../runtime.yml', - '_base_/s2anet_optimizer_1x.yml', + '_base_/s2anet_optimizer_2x.yml', '_base_/s2anet.yml', '_base_/s2anet_reader.yml', ] -weights: output/s2anet_1x_dota/model_final +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams + +weights: output/s2anet_alignconv_2x_dota/model_final S2ANetHead: anchor_strides: [8, 16, 32, 64, 128] @@ -19,5 +21,6 @@ S2ANetHead: align_conv_type: 'AlignConv' # AlignConv Conv align_conv_size: 3 use_sigmoid_cls: True - reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.1] - cls_loss_weight: [1.1, 1.05] + reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.05] + cls_loss_weight: [1.05, 1.0] + 
#reg_loss_type: 'l1' # 'l1' 'gwd' diff --git a/configs/dota/s2anet_conv_1x_dota.yml b/configs/dota/s2anet_conv_1x_dota.yml index 2a192ecf96cab529c8b83ac0b54b255b86432ece..a81c76f3f4f6a81961aff802db7c8442df79293f 100644 --- a/configs/dota/s2anet_conv_1x_dota.yml +++ b/configs/dota/s2anet_conv_1x_dota.yml @@ -7,6 +7,13 @@ _BASE_: [ ] weights: output/s2anet_1x_dota/model_final +ResNet: + depth: 50 + variant: b + norm_type: bn + return_idx: [1,2,3] + num_stages: 4 + S2ANetHead: anchor_strides: [8, 16, 32, 64, 128] anchor_scales: [4] diff --git a/ppdet/modeling/heads/s2anet_head.py b/ppdet/modeling/heads/s2anet_head.py index 67e8fb0ec32d1955614be4f797b1702980035f56..d578bbef4c02c12b9d2ac6dd5d6a5263621f772d 100644 --- a/ppdet/modeling/heads/s2anet_head.py +++ b/ppdet/modeling/heads/s2anet_head.py @@ -102,8 +102,7 @@ class S2ANetAnchorGenerator(nn.Layer): valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) valid = valid_xx & valid_yy valid = paddle.reshape(valid, [-1, 1]) - valid = paddle.expand(valid, - [-1, self.num_base_anchors]).reshape([-1]) + valid = paddle.expand(valid, [-1, self.num_base_anchors]).reshape([-1]) return valid @@ -179,9 +178,12 @@ class AlignConv(nn.Layer): offset_x = x_anchor - x_conv offset_y = y_anchor - y_conv offset = paddle.stack([offset_y, offset_x], axis=-1) - offset = paddle.reshape(offset, [feat_h * feat_w, self.kernel_size * self.kernel_size * 2]) + offset = paddle.reshape( + offset, [feat_h * feat_w, self.kernel_size * self.kernel_size * 2]) offset = paddle.transpose(offset, [1, 0]) - offset = paddle.reshape(offset, [1, self.kernel_size * self.kernel_size * 2, feat_h, feat_w]) + offset = paddle.reshape( + offset, + [1, self.kernel_size * self.kernel_size * 2, feat_h, feat_w]) return offset def forward(self, x, refine_anchors, featmap_size, stride): @@ -260,8 +262,8 @@ class S2ANetHead(nn.Layer): # anchor self.anchor_generators = [] for anchor_base in self.anchor_base_sizes: - self.anchor_generators.append( - 
S2ANetAnchorGenerator(anchor_base, anchor_scales, + self.anchor_generators.append( + S2ANetAnchorGenerator(anchor_base, anchor_scales, anchor_ratios)) self.anchor_generators = nn.LayerList(self.anchor_generators) @@ -440,8 +442,7 @@ class S2ANetHead(nn.Layer): init_anchors = paddle.to_tensor(init_anchors, dtype='float32') NA = featmap_size[0] * featmap_size[1] - init_anchors = paddle.reshape( - init_anchors, [NA, 4]) + init_anchors = paddle.reshape(init_anchors, [NA, 4]) init_anchors = self.rect2rbox(init_anchors) self.base_anchors_list.append(init_anchors) @@ -474,18 +475,19 @@ class S2ANetHead(nn.Layer): # [N, CLS, H, W] --> [N, H, W, CLS] odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1]) odm_cls_score_shape = odm_cls_score.shape - odm_cls_score_reshape = paddle.reshape( - odm_cls_score, - [odm_cls_score_shape[0], odm_cls_score_shape[1] * odm_cls_score_shape[2], self.cls_out_channels]) + odm_cls_score_reshape = paddle.reshape(odm_cls_score, [ + odm_cls_score_shape[0], odm_cls_score_shape[1] * + odm_cls_score_shape[2], self.cls_out_channels + ]) odm_cls_branch_list.append(odm_cls_score_reshape) odm_bbox_pred = self.odm_reg(odm_reg_feat) # [N, 5, H, W] --> [N, H, W, 5] odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1]) - odm_bbox_pred_reshape = paddle.reshape( - odm_bbox_pred, [-1, 5]) - odm_bbox_pred_reshape = paddle.unsqueeze(odm_bbox_pred_reshape, axis=0) + odm_bbox_pred_reshape = paddle.reshape(odm_bbox_pred, [-1, 5]) + odm_bbox_pred_reshape = paddle.unsqueeze( + odm_bbox_pred_reshape, axis=0) odm_reg_branch_list.append(odm_bbox_pred_reshape) self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list, @@ -499,12 +501,8 @@ class S2ANetHead(nn.Layer): odm_cls_branch_list = self.s2anet_head_out[2] odm_reg_branch_list = self.s2anet_head_out[3] pred_scores, pred_bboxes = self.get_bboxes( - odm_cls_branch_list, - odm_reg_branch_list, - refine_anchors, - nms_pre, - self.cls_out_channels, - self.use_sigmoid_cls) + odm_cls_branch_list, odm_reg_branch_list, 
refine_anchors, nms_pre, + self.cls_out_channels, self.use_sigmoid_cls) return pred_scores, pred_bboxes def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0): @@ -523,8 +521,8 @@ class S2ANetHead(nn.Layer): return loss def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='gwd'): - (labels, label_weights, bbox_targets, bbox_weights, pos_inds, - neg_inds) = fam_target + (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes, + pos_inds, neg_inds) = fam_target fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out fam_cls_losses = [] @@ -543,7 +541,6 @@ class S2ANetHead(nn.Layer): feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :] feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :] - st_idx += feat_anchor_num # step2: calc cls loss feat_labels = feat_labels.reshape(-1) @@ -590,18 +587,53 @@ class S2ANetHead(nn.Layer): fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5]) fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) - # iou_factor + fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) + loss_weight = paddle.to_tensor( + self.reg_loss_weight, dtype='float32', stop_gradient=True) + fam_bbox = paddle.multiply(fam_bbox, loss_weight) + feat_bbox_weights = paddle.to_tensor( + feat_bbox_weights, stop_gradient=True) + if reg_loss_type == 'l1': - fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) - loss_weight = paddle.to_tensor( - self.reg_loss_weight, dtype='float32', stop_gradient=True) - fam_bbox = paddle.multiply(fam_bbox, loss_weight) - feat_bbox_weights = paddle.to_tensor( - feat_bbox_weights, stop_gradient=True) fam_bbox = fam_bbox * feat_bbox_weights fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples + elif reg_loss_type == 'iou' or reg_loss_type == 'gwd': + fam_bbox = paddle.sum(fam_bbox, axis=-1) + feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) + try: + from rbox_iou_ops import rbox_iou + except 
Exception as e: + print("import custom_ops error, try install rbox_iou_ops " \ + "following ppdet/ext_op/README.md", e) + sys.stdout.flush() + sys.exit(-1) + # calc iou + fam_bbox_decode = self.delta2rbox(self.base_anchors_list[idx], + fam_bbox_pred) + bbox_gt_bboxes = paddle.to_tensor( + bbox_gt_bboxes, + dtype=fam_bbox_decode.dtype, + place=fam_bbox_decode.place) + bbox_gt_bboxes.stop_gradient = True + iou = rbox_iou(fam_bbox_decode, bbox_gt_bboxes) + iou = paddle.diag(iou) + + if reg_loss_type == 'iou': + EPS = paddle.to_tensor( + 1e-8, dtype='float32', stop_gradient=True) + iou_factor = -1.0 * paddle.log(iou + EPS) / (fam_bbox + EPS) + iou_factor.stop_gradient = True + #fam_bbox = fam_bbox * iou_factor + elif reg_loss_type == 'gwd': + bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx + + feat_anchor_num, :] + fam_bbox_total = self.gwd_loss(fam_bbox_decode, + bbox_gt_bboxes_level) + fam_bbox_total = fam_bbox_total * feat_bbox_weights + fam_bbox_total = paddle.sum(fam_bbox_total) fam_bbox_losses.append(fam_bbox_total) + st_idx += feat_anchor_num fam_cls_loss = paddle.add_n(fam_cls_losses) fam_cls_loss_weight = paddle.to_tensor( @@ -611,8 +643,8 @@ class S2ANetHead(nn.Layer): return fam_cls_loss, fam_reg_loss def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='gwd'): - (labels, label_weights, bbox_targets, bbox_weights, pos_inds, - neg_inds) = odm_target + (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes, + pos_inds, neg_inds) = odm_target fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out odm_cls_losses = [] @@ -621,7 +653,7 @@ class S2ANetHead(nn.Layer): num_total_samples = len(pos_inds) + len( neg_inds) if self.sampling else len(pos_inds) num_total_samples = max(1, num_total_samples) - + for idx, feat_size in enumerate(self.featmap_sizes_list): feat_anchor_num = feat_size[0] * feat_size[1] @@ -631,7 +663,6 @@ class S2ANetHead(nn.Layer): feat_bbox_targets = 
bbox_targets[st_idx:st_idx + feat_anchor_num, :] feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :] - st_idx += feat_anchor_num # step2: calc cls loss feat_labels = feat_labels.reshape(-1) @@ -676,19 +707,53 @@ class S2ANetHead(nn.Layer): odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0) odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5]) odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets) - - # iou_factor odm not use_iou + + loss_weight = paddle.to_tensor( + self.reg_loss_weight, dtype='float32', stop_gradient=True) + odm_bbox = paddle.multiply(odm_bbox, loss_weight) + feat_bbox_weights = paddle.to_tensor( + feat_bbox_weights, stop_gradient=True) + if reg_loss_type == 'l1': - odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets) - loss_weight = paddle.to_tensor( - self.reg_loss_weight, dtype='float32', stop_gradient=True) - odm_bbox = paddle.multiply(odm_bbox, loss_weight) - feat_bbox_weights = paddle.to_tensor( - feat_bbox_weights, stop_gradient=True) odm_bbox = odm_bbox * feat_bbox_weights odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples - + elif reg_loss_type == 'iou' or reg_loss_type == 'gwd': + odm_bbox = paddle.sum(odm_bbox, axis=-1) + feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) + try: + from rbox_iou_ops import rbox_iou + except Exception as e: + print("import custom_ops error, try install rbox_iou_ops " \ + "following ppdet/ext_op/README.md", e) + sys.stdout.flush() + sys.exit(-1) + # calc iou + odm_bbox_decode = self.delta2rbox(self.refine_anchor_list[idx], + odm_bbox_pred) + bbox_gt_bboxes = paddle.to_tensor( + bbox_gt_bboxes, + dtype=odm_bbox_decode.dtype, + place=odm_bbox_decode.place) + bbox_gt_bboxes.stop_gradient = True + iou = rbox_iou(odm_bbox_decode, bbox_gt_bboxes) + iou = paddle.diag(iou) + + if reg_loss_type == 'iou': + EPS = paddle.to_tensor( + 1e-8, dtype='float32', stop_gradient=True) + iou_factor = -1.0 * paddle.log(iou + EPS) / (odm_bbox + EPS) + 
iou_factor.stop_gradient = True + # odm_bbox = odm_bbox * iou_factor + elif reg_loss_type == 'gwd': + bbox_gt_bboxes_level = bbox_gt_bboxes[st_idx:st_idx + + feat_anchor_num, :] + odm_bbox_total = self.gwd_loss(odm_bbox_decode, + bbox_gt_bboxes_level) + odm_bbox_total = odm_bbox_total * feat_bbox_weights + odm_bbox_total = paddle.sum(odm_bbox_total) + odm_bbox_losses.append(odm_bbox_total) + st_idx += feat_anchor_num odm_cls_loss = paddle.add_n(odm_cls_losses) odm_cls_loss_weight = paddle.to_tensor( @@ -737,11 +802,12 @@ class S2ANetHead(nn.Layer): fam_reg_loss_lst.append(im_fam_reg_loss) # ODM - np_refine_anchors_list = paddle.concat(self.refine_anchor_list).numpy() + np_refine_anchors_list = paddle.concat( + self.refine_anchor_list).numpy() np_refine_anchors_list = np.concatenate(np_refine_anchors_list) np_refine_anchors_list = np_refine_anchors_list.reshape(-1, 5) - im_odm_target = self.anchor_assign(np_refine_anchors_list, gt_bboxes, - gt_labels, is_crowd) + im_odm_target = self.anchor_assign(np_refine_anchors_list, + gt_bboxes, gt_labels, is_crowd) if im_odm_target is not None: im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss( @@ -841,7 +907,8 @@ class S2ANetHead(nn.Layer): deltas = paddle.reshape(deltas, [-1, 5]) rrois = paddle.reshape(rrois, [-1, 5]) # fix dy2st bug denorm_deltas = deltas * self.stds + self.means - denorm_deltas = paddle.add(paddle.multiply(deltas, self.stds), self.means) + denorm_deltas = paddle.add( + paddle.multiply(deltas, self.stds), self.means) dx = denorm_deltas[:, 0] dy = denorm_deltas[:, 1] @@ -872,9 +939,7 @@ class S2ANetHead(nn.Layer): bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1) return bboxes - def bbox_decode(self, - bbox_preds, - anchors): + def bbox_decode(self, bbox_preds, anchors): """decode bbox from deltas Args: bbox_preds: [N,H,W,5] @@ -886,3 +951,109 @@ class S2ANetHead(nn.Layer): bbox_delta = paddle.reshape(bbox_preds, [-1, 5]) bboxes = self.delta2rbox(anchors, bbox_delta) return bboxes + + def trace(self, 
A): + tr = paddle.diagonal(A, axis1=-2, axis2=-1) + tr = paddle.sum(tr, axis=-1) + return tr + + def sqrt_newton_schulz_autograd(self, A, numIters): + A_shape = A.shape + batchSize = A_shape[0] + dim = A_shape[1] + + normA = A * A + normA = paddle.sum(normA, axis=1) + normA = paddle.sum(normA, axis=1) + normA = paddle.sqrt(normA) + normA1 = normA.reshape([batchSize, 1, 1]) + Y = paddle.divide(A, paddle.expand_as(normA1, A)) + I = paddle.eye(dim, dim).reshape([1, dim, dim]) + l0 = [] + for i in range(batchSize): + l0.append(I) + I = paddle.concat(l0, axis=0) + I.stop_gradient = False + Z = paddle.eye(dim, dim).reshape([1, dim, dim]) + l1 = [] + for i in range(batchSize): + l1.append(Z) + Z = paddle.concat(l1, axis=0) + Z.stop_gradient = False + + for i in range(numIters): + T = 0.5 * (3.0 * I - Z.bmm(Y)) + Y = Y.bmm(T) + Z = T.bmm(Z) + sA = Y * paddle.sqrt(normA1).reshape([batchSize, 1, 1]) + sA = paddle.expand_as(sA, A) + return sA + + def wasserstein_distance_sigma(self, sigma1, sigma2): + wasserstein_distance_item2 = paddle.matmul( + sigma1, sigma1) + paddle.matmul( + sigma2, sigma2) - 2 * self.sqrt_newton_schulz_autograd( + paddle.matmul( + paddle.matmul(sigma1, paddle.matmul(sigma2, sigma2)), + sigma1), 10) + wasserstein_distance_item2 = self.trace(wasserstein_distance_item2) + + return wasserstein_distance_item2 + + def xywhr2xyrs(self, xywhr): + xywhr = paddle.reshape(xywhr, [-1, 5]) + xy = xywhr[:, :2] + wh = paddle.clip(xywhr[:, 2:4], min=1e-7, max=1e7) + r = xywhr[:, 4] + cos_r = paddle.cos(r) + sin_r = paddle.sin(r) + R = paddle.stack( + (cos_r, -sin_r, sin_r, cos_r), axis=-1).reshape([-1, 2, 2]) + S = 0.5 * paddle.nn.functional.diag_embed(wh) + return xy, R, S + + def gwd_loss(self, + pred, + target, + fun='log', + tau=1.0, + alpha=1.0, + normalize=False): + + xy_p, R_p, S_p = self.xywhr2xyrs(pred) + xy_t, R_t, S_t = self.xywhr2xyrs(target) + + xy_distance = (xy_p - xy_t).square().sum(axis=-1) + + Sigma_p = R_p.matmul(S_p.square()).matmul(R_p.transpose([0, 2, 
1])) + Sigma_t = R_t.matmul(S_t.square()).matmul(R_t.transpose([0, 2, 1])) + + whr_distance = paddle.diagonal( + S_p, axis1=-2, axis2=-1).square().sum(axis=-1) + + whr_distance = whr_distance + paddle.diagonal( + S_t, axis1=-2, axis2=-1).square().sum(axis=-1) + _t = Sigma_p.matmul(Sigma_t) + + _t_tr = paddle.diagonal(_t, axis1=-2, axis2=-1).sum(axis=-1) + _t_det_sqrt = paddle.diagonal(S_p, axis1=-2, axis2=-1).prod(axis=-1) + _t_det_sqrt = _t_det_sqrt * paddle.diagonal( + S_t, axis1=-2, axis2=-1).prod(axis=-1) + whr_distance = whr_distance + (-2) * ( + (_t_tr + 2 * _t_det_sqrt).clip(0).sqrt()) + + distance = (xy_distance + alpha * alpha * whr_distance).clip(0) + + if normalize: + wh_p = pred[..., 2:4].clip(min=1e-7, max=1e7) + wh_t = target[..., 2:4].clip(min=1e-7, max=1e7) + scale = ((wh_p.log() + wh_t.log()).sum(axis=-1) / 4).exp() + distance = distance / scale + + if fun == 'log': + distance = paddle.log1p(distance) + + if tau >= 1.0: + return 1 - 1 / (tau + distance) + + return distance diff --git a/ppdet/modeling/proposal_generator/target_layer.py b/ppdet/modeling/proposal_generator/target_layer.py index cc9880a446db6556b68b0f61728975dafaaf6131..f20ff2c7d695ae858ffd62fe74c184a1182c630c 100644 --- a/ppdet/modeling/proposal_generator/target_layer.py +++ b/ppdet/modeling/proposal_generator/target_layer.py @@ -451,16 +451,17 @@ class RBoxAssigner(object): anchors_num = anchors.shape[0] bbox_targets = np.zeros_like(anchors) bbox_weights = np.zeros_like(anchors) + bbox_gt_bboxes = np.zeros_like(anchors) pos_labels = np.ones(anchors_num, dtype=np.int32) * -1 pos_labels_weights = np.zeros(anchors_num, dtype=np.float32) pos_sampled_anchors = anchors[pos_inds] - #print('ancho target pos_inds', pos_inds, len(pos_inds)) pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]] if len(pos_inds) > 0: pos_bbox_targets = self.rbox2delta(pos_sampled_anchors, pos_sampled_gt_boxes) bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_gt_bboxes[pos_inds, :] = 
pos_sampled_gt_boxes bbox_weights[pos_inds, :] = 1.0 pos_labels[pos_inds] = labels[pos_inds] @@ -469,4 +470,4 @@ class RBoxAssigner(object): if len(neg_inds) > 0: pos_labels_weights[neg_inds] = 1.0 return (pos_labels, pos_labels_weights, bbox_targets, bbox_weights, - pos_inds, neg_inds) + bbox_gt_bboxes, pos_inds, neg_inds)