diff --git a/configs/solov2/README.md b/configs/solov2/README.md
index b53f5b336eadcef22adf2ff7dc7b8e6143aae569..037b2f96ce9f04087eed386abe58390857d59c4d 100644
--- a/configs/solov2/README.md
+++ b/configs/solov2/README.md
@@ -27,6 +27,34 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo
 
 - SOLOv2 is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`.
 
+## Enhanced model
+| Backbone             | Input size | Lr schd | V100 FP32 (FPS) | Mask AP (val) | Download | Configs |
+| :------------------: | :--------: | :-----: | :-------------: | :-----------: | :------: | :-----: |
+| Light-R50-VD-DCN-FPN | 512        | 3x      | 38.6            | 39.0          | [model](https://paddledet.bj.bcebos.com/models/solov2_r50_enhance_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_enhance_coco.yml) |
+
+**Optimization methods used in the enhanced model:**
+- Better backbone network: ResNet50vd-DCN
+- A better pretrained model obtained with knowledge distillation (SSLD)
+- [Exponential Moving Average](https://www.investopedia.com/terms/e/ema.asp)
+- Synchronized Batch Normalization
+- Multi-scale training
+- More data augmentation methods
+- DropBlock (see the sketch below)
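+
+A rough, editorial sketch (not part of the PaddleDetection API) of how the `DropBlock` layer used by this model picks its drop density `gamma`; it mirrors the formula in `ppdet/modeling/layers.py`, with `block_size=3` and `keep_prob=0.9` as configured in `SOLOv2Head`:
+
+```python
+def dropblock_gamma(keep_prob, block_size, feat_h, feat_w):
+    # Seed probability for dropped block_size x block_size patches, chosen so
+    # that roughly (1 - keep_prob) of all activations end up being dropped.
+    gamma = (1. - keep_prob) / (block_size**2)
+    for s in (feat_h, feat_w):  # correct for blocks that would cross the border
+        gamma *= s / (s - block_size + 1)
+    return gamma
+
+print(dropblock_gamma(0.9, 3, 64, 64))  # ~0.0118 on a 64x64 feature map
+```
+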
 ## Citations
 ```
 @article{wang2020solov2,
diff --git a/configs/solov2/_base_/solov2_light_reader.yml b/configs/solov2/_base_/solov2_light_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..901049c13d35251558c9235058cad80d8e5ea1be
--- /dev/null
+++ b/configs/solov2/_base_/solov2_light_reader.yml
@@ -0,0 +1,47 @@
+worker_num: 2
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - Poly2Mask: {}
+  - RandomDistort: {}
+  - RandomCrop: {}
+  - RandomResize: {interp: 1,
+                   target_size: [[352, 852], [384, 852], [416, 852], [448, 852], [480, 852], [512, 852]],
+                   keep_ratio: True}
+  - RandomFlip: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  - Gt2Solov2Target: {num_grids: [40, 36, 24, 16, 12],
+                      scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]],
+                      coord_sigma: 0.2}
+  batch_size: 2
+  shuffle: true
+  drop_last: true
+
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Resize: {interp: 1, target_size: [512, 852], keep_ratio: True}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+
+
+TestReader:
+  sample_transforms:
+  - Decode: {}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Resize: {interp: 1, target_size: [512, 852], keep_ratio: True}
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
diff --git a/configs/solov2/solov2_r50_enhance_coco.yml b/configs/solov2/solov2_r50_enhance_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0cadd8783a3c45efc4c20f96fcd3241a0df8c02a
--- /dev/null
+++ b/configs/solov2/solov2_r50_enhance_coco.yml
@@ -0,0 +1,50 @@
+_BASE_: [
+  '../datasets/coco_instance.yml',
+  '../runtime.yml',
+  '_base_/solov2_r50_fpn.yml',
+  '_base_/optimizer_1x.yml',
+  '_base_/solov2_light_reader.yml',
+]
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams
+weights: output/solov2_r50_enhance_coco/model_final
+epoch: 36
+use_ema: true
+ema_decay: 0.9998
+
+ResNet:
+  depth: 50
+  variant: d
+  freeze_at: 0
+  freeze_norm: false
+  norm_type: sync_bn
+  return_idx: [0, 1, 2, 3]
+  dcn_v2_stages: [1, 2, 3]
+  lr_mult_list: [0.05, 0.05, 0.1, 0.15]
+  num_stages: 4
+
+SOLOv2Head:
+  seg_feat_channels: 256
+  stacked_convs: 3
+  num_grids: [40, 36, 24, 16, 12]
+  kernel_out_channels: 128
+  solov2_loss: SOLOv2Loss
+  mask_nms: MaskMatrixNMS
+  dcn_v2_stages: [2]
+  drop_block: True
+
+SOLOv2MaskHead:
+  mid_channels: 128
+  out_channels: 128
+  start_level: 0
+  end_level: 3
+  use_dcn_in_tower: True
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [24, 33]
+  - !LinearWarmup
+    start_factor: 0.
+    steps: 1000
diff --git a/ppdet/modeling/heads/solov2_head.py b/ppdet/modeling/heads/solov2_head.py
index 8338e53b460eca53a7cc6ec871d8671a516b7499..36c120cf4bfe811b587828fde63a8a6119e0f8fe 100644
--- a/ppdet/modeling/heads/solov2_head.py
+++ b/ppdet/modeling/heads/solov2_head.py
@@ -22,7 +22,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import Normal, Constant
 
-from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS
+from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS, DropBlock
 from ppdet.core.workspace import register
 from six.moves import zip
 
@@ -182,7 +182,8 @@ class SOLOv2Head(nn.Layer):
                  score_threshold=0.1,
                  mask_threshold=0.5,
                  mask_nms=None,
-                 norm_type='gn'):
+                 norm_type='gn',
+                 drop_block=False):
         super(SOLOv2Head, self).__init__()
         self.num_classes = num_classes
         self.in_channels = in_channels
@@ -198,6 +199,7 @@ class SOLOv2Head(nn.Layer):
         self.score_threshold = score_threshold
         self.mask_threshold = mask_threshold
         self.norm_type = norm_type
+        self.drop_block = drop_block
 
         self.kernel_pred_convs = []
         self.cate_pred_convs = []
@@ -250,6 +252,10 @@ class SOLOv2Head(nn.Layer):
                 bias_attr=ParamAttr(initializer=Constant(
                     value=float(-np.log((1 - 0.01) / 0.01))))))
 
+        if self.drop_block:
+            self.drop_block_fun = DropBlock(
+                block_size=3, keep_prob=0.9, name='solo_cate.dropblock')
+
     def _points_nms(self, heat, kernel_size=2):
         hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1)
         keep = paddle.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
@@ -318,10 +324,14 @@ class SOLOv2Head(nn.Layer):
         for kernel_layer in self.kernel_pred_convs:
             kernel_feat = F.relu(kernel_layer(kernel_feat))
+        if self.drop_block:
+            kernel_feat = self.drop_block_fun(kernel_feat)
         kernel_pred = self.solo_kernel(kernel_feat)
 
         # cate branch
         for cate_layer in self.cate_pred_convs:
             cate_feat = F.relu(cate_layer(cate_feat))
+        if self.drop_block:
+            cate_feat = self.drop_block_fun(cate_feat)
         cate_pred = self.solo_cate(cate_feat)
 
         if not self.training:
diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py
index 919c3d3531a630c02d5779862cb25a0ffb742209..ca945c2df0f83360d7ebb6fdb2df0d0b508c04e0 100644
--- a/ppdet/modeling/layers.py
+++ b/ppdet/modeling/layers.py
@@ -250,6 +250,47 @@ class LiteConv(nn.Layer):
         return out
 
 
+class DropBlock(nn.Layer):
+    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
+        """
+        DropBlock layer, see https://arxiv.org/abs/1810.12890
+
+        Args:
+            block_size (int): block size
+            keep_prob (float): keep probability
+            name (str): layer name
+            data_format (str): data format, NCHW or NHWC
+        """
+        super(DropBlock, self).__init__()
+        self.block_size = block_size
+        self.keep_prob = keep_prob
+        self.name = name
+        self.data_format = data_format
+
+    def forward(self, x):
+        if not self.training or self.keep_prob == 1:
+            return x
+        else:
+            gamma = (1. - self.keep_prob) / (self.block_size**2)
+            if self.data_format == 'NCHW':
+                shape = x.shape[2:]
+            else:
+                shape = x.shape[1:3]
+            for s in shape:
+                gamma *= s / (s - self.block_size + 1)
+
+            matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
+            mask_inv = F.max_pool2d(
+                matrix,
+                self.block_size,
+                stride=1,
+                padding=self.block_size // 2,
+                data_format=self.data_format)
+            mask = 1. - mask_inv
+            y = x * mask * (mask.numel() / mask.sum())
+            return y
+
+
 @register
 @serializable
 class AnchorGeneratorSSD(object):
diff --git a/ppdet/modeling/necks/yolo_fpn.py b/ppdet/modeling/necks/yolo_fpn.py
index 1ce6531033aca52dae564dcf672276907b9d56cd..d3197f0586b4170cbef6760c3fb945578cd73ccc 100644
--- a/ppdet/modeling/necks/yolo_fpn.py
+++ b/ppdet/modeling/necks/yolo_fpn.py
@@ -17,6 +17,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import ParamAttr
 from ppdet.core.workspace import register, serializable
+from ppdet.modeling.layers import DropBlock
 from ..backbones.darknet import ConvBNLayer
 from ..shape_spec import ShapeSpec
 
@@ -173,47 +174,6 @@ class SPP(nn.Layer):
         return y
 
 
-class DropBlock(nn.Layer):
-    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
-        """
-        DropBlock layer, see https://arxiv.org/abs/1810.12890
-
-        Args:
-            block_size (int): block size
-            keep_prob (int): keep probability
-            name (str): layer name
-            data_format (str): data format, NCHW or NHWC
-        """
-        super(DropBlock, self).__init__()
-        self.block_size = block_size
-        self.keep_prob = keep_prob
-        self.name = name
-        self.data_format = data_format
-
-    def forward(self, x):
-        if not self.training or self.keep_prob == 1:
-            return x
-        else:
-            gamma = (1. - self.keep_prob) / (self.block_size**2)
-            if self.data_format == 'NCHW':
-                shape = x.shape[2:]
-            else:
-                shape = x.shape[1:3]
-            for s in shape:
-                gamma *= s / (s - self.block_size + 1)
-
-            matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
-            mask_inv = F.max_pool2d(
-                matrix,
-                self.block_size,
-                stride=1,
-                padding=self.block_size // 2,
-                data_format=self.data_format)
-            mask = 1. - mask_inv
-            y = x * mask * (mask.numel() / mask.sum())
-            return y
-
-
 class CoordConv(nn.Layer):
     def __init__(self,
                  ch_in,
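
As a quick, editorial sanity check (not part of the patch), the relocated `DropBlock` layer can be exercised on its own. The constructor arguments follow the signature added in `ppdet/modeling/layers.py`; the tensor shape and layer name below are arbitrary, and a working `paddle`/`ppdet` installation is assumed:

```python
import paddle
from ppdet.modeling.layers import DropBlock  # new location introduced by this patch

x = paddle.rand([2, 128, 64, 64])  # arbitrary NCHW feature map
db = DropBlock(block_size=3, keep_prob=0.9, name='demo.dropblock')  # name is arbitrary

db.eval()
print(paddle.allclose(db(x), x))  # True: the layer is an identity at inference time

db.train()
y = db(x)  # 3x3 blocks are zeroed and the surviving activations are rescaled
print(float((y == 0).astype('float32').mean()))  # dropped fraction, roughly 1 - keep_prob
```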