From 2d21d90783d67a4e32951654fe7c9f06018a2325 Mon Sep 17 00:00:00 2001 From: cnn Date: Fri, 30 Jul 2021 20:54:31 +0800 Subject: [PATCH] [dev] rbox update2 (#3828) * set lr for 4 card as default, and update --- configs/dota/_base_/s2anet_optimizer_1x.yml | 2 +- configs/dota/s2anet_1x_spine.yml | 7 +- configs/dota/s2anet_alignconv_2x_dota.yml | 2 +- ppdet/engine/export_utils.py | 5 - ppdet/modeling/heads/s2anet_head.py | 105 ++++++++++---------- ppdet/modeling/post_process.py | 2 +- 6 files changed, 60 insertions(+), 63 deletions(-) diff --git a/configs/dota/_base_/s2anet_optimizer_1x.yml b/configs/dota/_base_/s2anet_optimizer_1x.yml index 65f794dc3..91069a3bb 100644 --- a/configs/dota/_base_/s2anet_optimizer_1x.yml +++ b/configs/dota/_base_/s2anet_optimizer_1x.yml @@ -1,7 +1,7 @@ epoch: 12 LearningRate: - base_lr: 0.01 + base_lr: 0.005 schedulers: - !PiecewiseDecay gamma: 0.1 diff --git a/configs/dota/s2anet_1x_spine.yml b/configs/dota/s2anet_1x_spine.yml index 017ee28e5..440ac143a 100644 --- a/configs/dota/s2anet_1x_spine.yml +++ b/configs/dota/s2anet_1x_spine.yml @@ -8,9 +8,9 @@ _BASE_: [ weights: output/s2anet_1x_spine/model_final -# for 8 card +# for 4 card LearningRate: - base_lr: 0.01 + base_lr: 0.005 S2ANetHead: anchor_strides: [8, 16, 32, 64, 128] @@ -26,5 +26,4 @@ S2ANetHead: use_sigmoid_cls: True reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.05] cls_loss_weight: [1.05, 1.0] - reg_loss_type: gwd - use_paddle_anchor: False + reg_loss_type: 'l1' diff --git a/configs/dota/s2anet_alignconv_2x_dota.yml b/configs/dota/s2anet_alignconv_2x_dota.yml index 7355a0122..48e2d2b38 100644 --- a/configs/dota/s2anet_alignconv_2x_dota.yml +++ b/configs/dota/s2anet_alignconv_2x_dota.yml @@ -23,4 +23,4 @@ S2ANetHead: use_sigmoid_cls: True reg_loss_weight: [1.0, 1.0, 1.0, 1.0, 1.05] cls_loss_weight: [1.05, 1.0] - #reg_loss_type: 'l1' # 'l1' 'gwd' + reg_loss_type: 'l1' diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index b3c26970f..236779742 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -139,10 +139,5 @@ def _dump_infer_config(config, path, image_shape, model): infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader( reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape) - if infer_arch == 'S2ANet': - # TODO: move background to num_classes - if infer_cfg['label_list'][0] != 'background': - infer_cfg['label_list'].insert(0, 'background') - yaml.dump(infer_cfg, open(path, 'w')) logger.info("Export inference config file to {}".format(os.path.join(path))) diff --git a/ppdet/modeling/heads/s2anet_head.py b/ppdet/modeling/heads/s2anet_head.py index d578bbef4..9545fcc85 100644 --- a/ppdet/modeling/heads/s2anet_head.py +++ b/ppdet/modeling/heads/s2anet_head.py @@ -69,7 +69,7 @@ class S2ANetAnchorGenerator(nn.Layer): return base_anchors def _meshgrid(self, x, y, row_major=True): - yy, xx = paddle.meshgrid(x, y) + yy, xx = paddle.meshgrid(y, x) yy = yy.reshape([-1]) xx = xx.reshape([-1]) if row_major: @@ -264,7 +264,7 @@ class S2ANetHead(nn.Layer): for anchor_base in self.anchor_base_sizes: self.anchor_generators.append( S2ANetAnchorGenerator(anchor_base, anchor_scales, - anchor_ratios)) + anchor_ratios)) self.anchor_generators = nn.LayerList(self.anchor_generators) self.fam_cls_convs = nn.Sequential() @@ -551,33 +551,35 @@ class S2ANetHead(nn.Layer): fam_cls_score1 = fam_cls_score feat_labels = paddle.to_tensor(feat_labels) - feat_labels_one_hot = paddle.nn.functional.one_hot( - feat_labels, self.cls_out_channels + 1) - feat_labels_one_hot = feat_labels_one_hot[:, 1:] - feat_labels_one_hot.stop_gradient = True - - num_total_samples = paddle.to_tensor( - num_total_samples, dtype='float32', stop_gradient=True) - - fam_cls = F.sigmoid_focal_loss( - fam_cls_score1, - feat_labels_one_hot, - normalizer=num_total_samples, - reduction='none') - - feat_label_weights = feat_label_weights.reshape( - feat_label_weights.shape[0], 1) - feat_label_weights = np.repeat( - feat_label_weights, self.cls_out_channels, axis=1) - feat_label_weights = paddle.to_tensor( - feat_label_weights, stop_gradient=True) - - fam_cls = fam_cls * feat_label_weights - fam_cls_total = paddle.sum(fam_cls) + if (feat_labels >= 0).astype(paddle.int32).sum() > 0: + feat_labels_one_hot = paddle.nn.functional.one_hot( + feat_labels, self.cls_out_channels + 1) + feat_labels_one_hot = feat_labels_one_hot[:, 1:] + feat_labels_one_hot.stop_gradient = True + + num_total_samples = paddle.to_tensor( + num_total_samples, dtype='float32', stop_gradient=True) + + fam_cls = F.sigmoid_focal_loss( + fam_cls_score1, + feat_labels_one_hot, + normalizer=num_total_samples, + reduction='none') + + feat_label_weights = feat_label_weights.reshape( + feat_label_weights.shape[0], 1) + feat_label_weights = np.repeat( + feat_label_weights, self.cls_out_channels, axis=1) + feat_label_weights = paddle.to_tensor( + feat_label_weights, stop_gradient=True) + + fam_cls = fam_cls * feat_label_weights + fam_cls_total = paddle.sum(fam_cls) + else: + fam_cls_total = paddle.zeros([0], dtype=fam_cls_score1.dtype) fam_cls_losses.append(fam_cls_total) # step3: regression loss - fam_bbox_pred = fam_reg_branch_list[idx] feat_bbox_targets = paddle.to_tensor( feat_bbox_targets, dtype='float32', stop_gradient=True) feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5]) @@ -585,8 +587,6 @@ class S2ANetHead(nn.Layer): fam_bbox_pred = fam_reg_branch_list[idx] fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0) fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5]) - fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) - fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) loss_weight = paddle.to_tensor( self.reg_loss_weight, dtype='float32', stop_gradient=True) @@ -673,28 +673,31 @@ class S2ANetHead(nn.Layer): odm_cls_score1 = odm_cls_score feat_labels = paddle.to_tensor(feat_labels) - feat_labels_one_hot = paddle.nn.functional.one_hot( - feat_labels, self.cls_out_channels + 1) - feat_labels_one_hot = feat_labels_one_hot[:, 1:] - feat_labels_one_hot.stop_gradient = True - - num_total_samples = paddle.to_tensor( - num_total_samples, dtype='float32', stop_gradient=True) - odm_cls = F.sigmoid_focal_loss( - odm_cls_score1, - feat_labels_one_hot, - normalizer=num_total_samples, - reduction='none') - - feat_label_weights = feat_label_weights.reshape( - feat_label_weights.shape[0], 1) - feat_label_weights = np.repeat( - feat_label_weights, self.cls_out_channels, axis=1) - feat_label_weights = paddle.to_tensor(feat_label_weights) - feat_label_weights.stop_gradient = True - - odm_cls = odm_cls * feat_label_weights - odm_cls_total = paddle.sum(odm_cls) + if (feat_labels >= 0).astype(paddle.int32).sum() > 0: + feat_labels_one_hot = paddle.nn.functional.one_hot( + feat_labels, self.cls_out_channels + 1) + feat_labels_one_hot = feat_labels_one_hot[:, 1:] + feat_labels_one_hot.stop_gradient = True + + num_total_samples = paddle.to_tensor( + num_total_samples, dtype='float32', stop_gradient=True) + odm_cls = F.sigmoid_focal_loss( + odm_cls_score1, + feat_labels_one_hot, + normalizer=num_total_samples, + reduction='none') + + feat_label_weights = feat_label_weights.reshape( + feat_label_weights.shape[0], 1) + feat_label_weights = np.repeat( + feat_label_weights, self.cls_out_channels, axis=1) + feat_label_weights = paddle.to_tensor(feat_label_weights) + feat_label_weights.stop_gradient = True + + odm_cls = odm_cls * feat_label_weights + odm_cls_total = paddle.sum(odm_cls) + else: + odm_cls_total = paddle.zeros([0], dtype=odm_cls_score1.dtype) odm_cls_losses.append(odm_cls_total) # # step3: regression loss @@ -846,7 +849,7 @@ class S2ANetHead(nn.Layer): bbox_pred = paddle.reshape(bbox_pred, [-1, 5]) anchors = paddle.reshape(anchors, [-1, 5]) - if nms_pre > 0 and scores.shape[0] > nms_pre: + if scores.shape[0] > nms_pre: # Get maximum scores for foreground classes. if use_sigmoid_cls: max_scores = paddle.max(scores, axis=1) diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 95f51a9a8..d9bda6bd7 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -230,7 +230,7 @@ class S2ANetBBoxPostProcess(nn.Layer): def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0, nms=None): super(S2ANetBBoxPostProcess, self).__init__() self.num_classes = num_classes - self.nms_pre = nms_pre + self.nms_pre = paddle.to_tensor(nms_pre) self.min_bbox_size = min_bbox_size self.nms = nms self.origin_shape_list = [] -- GitLab