diff --git a/configs/solov2/README.md b/configs/solov2/README.md index 81d6cddfae34c1555b52ee5aef177c6d9ecc5ae8..b53f5b336eadcef22adf2ff7dc7b8e6143aae569 100644 --- a/configs/solov2/README.md +++ b/configs/solov2/README.md @@ -21,6 +21,7 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo | SOLOv2 (Paper) | X101-DCN-FPN | True | 3x | 42.4 | 5.9 | V100 | - | - | | SOLOv2 | R50-FPN | False | 1x | 35.5 | 21.9 | V100 | [model](https://paddledet.bj.bcebos.com/models/solov2_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_fpn_1x_coco.yml) | | SOLOv2 | R50-FPN | True | 3x | 38.0 | 21.9 | V100 | [model](https://paddledet.bj.bcebos.com/models/solov2_r50_fpn_3x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_fpn_3x_coco.yml) | +| SOLOv2 | R101vd-FPN | True | 3x | 42.7 | 12.1 | V100 | [model](https://paddledet.bj.bcebos.com/models/solov2_r101_vd_fpn_3x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r101_vd_fpn_3x_coco.yml) | **Notes:** diff --git a/configs/solov2/_base_/solov2_r50_fpn.yml b/configs/solov2/_base_/solov2_r50_fpn.yml index 53ec3b288fbb1f2e39d7d4f504fb069e26faeaa8..93a6892698a879c6ff60e731f617e6d0649072a9 100644 --- a/configs/solov2/_base_/solov2_r50_fpn.yml +++ b/configs/solov2/_base_/solov2_r50_fpn.yml @@ -9,7 +9,6 @@ SOLOv2: ResNet: depth: 50 - norm_type: bn freeze_at: 0 return_idx: [0,1,2,3] num_stages: 4 diff --git a/configs/solov2/solov2_r101_vd_fpn_3x_coco.yml b/configs/solov2/solov2_r101_vd_fpn_3x_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..db29c9ad19edd5396562e1b9e3f8400ae1a3367c --- /dev/null +++ b/configs/solov2/solov2_r101_vd_fpn_3x_coco.yml @@ -0,0 +1,66 @@ +_BASE_: [ + '../datasets/coco_instance.yml', + '../runtime.yml', + '_base_/solov2_r50_fpn.yml', + '_base_/optimizer_1x.yml', + '_base_/solov2_reader.yml', +] +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_pretrained.pdparams +weights: output/solov2_r101_vd_fpn_3x_coco/model_final +epoch: 36 +use_ema: true +ema_decay: 0.9998 + +ResNet: + depth: 101 + variant: d + freeze_at: 0 + return_idx: [0,1,2,3] + dcn_v2_stages: [1,2,3] + num_stages: 4 + +SOLOv2Head: + seg_feat_channels: 512 + stacked_convs: 4 + num_grids: [40, 36, 24, 16, 12] + kernel_out_channels: 256 + solov2_loss: SOLOv2Loss + mask_nms: MaskMatrixNMS + dcn_v2_stages: [0, 1, 2, 3] + +SOLOv2MaskHead: + mid_channels: 128 + out_channels: 256 + start_level: 0 + end_level: 3 + use_dcn_in_tower: True + + +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [24, 33] + - !LinearWarmup + start_factor: 0. + steps: 2000 + +TrainReader: + sample_transforms: + - Decode: {} + - Poly2Mask: {} + - RandomResize: {interp: 1, + target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], + keep_ratio: True} + - RandomFlip: {} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + - Gt2Solov2Target: {num_grids: [40, 36, 24, 16, 12], + scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]], + coord_sigma: 0.2} + batch_size: 2 + shuffle: true + drop_last: true diff --git a/ppdet/modeling/heads/solov2_head.py b/ppdet/modeling/heads/solov2_head.py index ac5ebe7df79426045e2b4d8057b54c78968ef89f..8338e53b460eca53a7cc6ec871d8671a516b7499 100644 --- a/ppdet/modeling/heads/solov2_head.py +++ b/ppdet/modeling/heads/solov2_head.py @@ -43,6 +43,7 @@ class SOLOv2MaskHead(nn.Layer): end_level (int): The position where the input ends. use_dcn_in_tower (bool): Whether to use dcn in tower or not. """ + __shared__ = ['norm_type'] def __init__(self, in_channels=256, @@ -50,7 +51,8 @@ class SOLOv2MaskHead(nn.Layer): out_channels=256, start_level=0, end_level=3, - use_dcn_in_tower=False): + use_dcn_in_tower=False, + norm_type='gn'): super(SOLOv2MaskHead, self).__init__() assert start_level >= 0 and end_level >= start_level self.in_channels = in_channels @@ -58,24 +60,22 @@ class SOLOv2MaskHead(nn.Layer): self.mid_channels = mid_channels self.use_dcn_in_tower = use_dcn_in_tower self.range_level = end_level - start_level + 1 - # TODO: add DeformConvNorm - conv_type = [ConvNormLayer] - self.conv_func = conv_type[0] - if self.use_dcn_in_tower: - self.conv_func = conv_type[1] + self.use_dcn = True if self.use_dcn_in_tower else False self.convs_all_levels = [] + self.norm_type = norm_type for i in range(start_level, end_level + 1): conv_feat_name = 'mask_feat_head.convs_all_levels.{}'.format(i) conv_pre_feat = nn.Sequential() if i == start_level: conv_pre_feat.add_sublayer( conv_feat_name + '.conv' + str(i), - self.conv_func( + ConvNormLayer( ch_in=self.in_channels, ch_out=self.mid_channels, filter_size=3, stride=1, - norm_type='gn')) + use_dcn=self.use_dcn, + norm_type=self.norm_type)) self.add_sublayer('conv_pre_feat' + str(i), conv_pre_feat) self.convs_all_levels.append(conv_pre_feat) else: @@ -87,12 +87,13 @@ class SOLOv2MaskHead(nn.Layer): ch_in = self.mid_channels conv_pre_feat.add_sublayer( conv_feat_name + '.conv' + str(j), - self.conv_func( + ConvNormLayer( ch_in=ch_in, ch_out=self.mid_channels, filter_size=3, stride=1, - norm_type='gn')) + use_dcn=self.use_dcn, + norm_type=self.norm_type)) conv_pre_feat.add_sublayer( conv_feat_name + '.conv' + str(j) + 'act', nn.ReLU()) conv_pre_feat.add_sublayer( @@ -105,12 +106,13 @@ class SOLOv2MaskHead(nn.Layer): conv_pred_name = 'mask_feat_head.conv_pred.0' self.conv_pred = self.add_sublayer( conv_pred_name, - self.conv_func( + ConvNormLayer( ch_in=self.mid_channels, ch_out=self.out_channels, filter_size=1, stride=1, - norm_type='gn')) + use_dcn=self.use_dcn, + norm_type=self.norm_type)) def forward(self, inputs): """ @@ -165,7 +167,7 @@ class SOLOv2Head(nn.Layer): mask_nms (object): MaskMatrixNMS instance. """ __inject__ = ['solov2_loss', 'mask_nms'] - __shared__ = ['num_classes'] + __shared__ = ['norm_type', 'num_classes'] def __init__(self, num_classes=80, @@ -179,7 +181,8 @@ class SOLOv2Head(nn.Layer): solov2_loss=None, score_threshold=0.1, mask_threshold=0.5, - mask_nms=None): + mask_nms=None, + norm_type='gn'): super(SOLOv2Head, self).__init__() self.num_classes = num_classes self.in_channels = in_channels @@ -194,33 +197,33 @@ class SOLOv2Head(nn.Layer): self.mask_nms = mask_nms self.score_threshold = score_threshold self.mask_threshold = mask_threshold + self.norm_type = norm_type - conv_type = [ConvNormLayer] - self.conv_func = conv_type[0] self.kernel_pred_convs = [] self.cate_pred_convs = [] for i in range(self.stacked_convs): - if i in self.dcn_v2_stages: - self.conv_func = conv_type[1] + use_dcn = True if i in self.dcn_v2_stages else False ch_in = self.in_channels + 2 if i == 0 else self.seg_feat_channels kernel_conv = self.add_sublayer( 'bbox_head.kernel_convs.' + str(i), - self.conv_func( + ConvNormLayer( ch_in=ch_in, ch_out=self.seg_feat_channels, filter_size=3, stride=1, - norm_type='gn')) + use_dcn=use_dcn, + norm_type=self.norm_type)) self.kernel_pred_convs.append(kernel_conv) ch_in = self.in_channels if i == 0 else self.seg_feat_channels cate_conv = self.add_sublayer( 'bbox_head.cate_convs.' + str(i), - self.conv_func( + ConvNormLayer( ch_in=ch_in, ch_out=self.seg_feat_channels, filter_size=3, stride=1, - norm_type='gn')) + use_dcn=use_dcn, + norm_type=self.norm_type)) self.cate_pred_convs.append(cate_conv) self.solo_kernel = self.add_sublayer(