Unverified commit 4f96dc2f, authored by Guanghua Yu, committed by GitHub

add solov2 enhance model (#3517)

* add solov2 enhance model
Parent 5f9b0bc3
......@@ -27,6 +27,20 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framework
- SOLOv2 is trained on the COCO train2017 dataset and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
## Enhanced model
| Backbone | Input size | Lr schd | V100 FP32(FPS) | Mask AP<sup>val</sup> | Download | Configs |
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
| Light-R50-VD-DCN-FPN | 512 | 3x | 38.6 | 39.0 | [model](https://paddledet.bj.bcebos.com/models/solov2_r50_enhance_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_enhance_coco.yml) |
**Optimization methods used in the enhanced model:**
- Better backbone network: ResNet50vd-DCN
- A stronger pretrained model obtained by knowledge distillation
- [Exponential Moving Average](https://www.investopedia.com/terms/e/ema.asp) (a minimal sketch follows this list)
- Synchronized Batch Normalization
- Multi-scale training
- More data augmentation methods
- DropBlock
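
The EMA entry above keeps an exponentially smoothed copy of the model weights alongside the ones the optimizer updates (the config in this commit sets `use_ema: true` and `ema_decay: 0.9998`). A minimal sketch of the idea, with plain floats standing in for parameter tensors; this is illustrative, not PaddleDetection's actual EMA implementation:

```python
class SimpleEMA:
    """Exponential moving average of model parameters (illustrative sketch)."""

    def __init__(self, params, decay=0.9998):
        self.decay = decay
        self.shadow = dict(params)  # smoothed copy of the weights

    def update(self, params):
        # shadow <- decay * shadow + (1 - decay) * current
        for k, v in params.items():
            self.shadow[k] = self.decay * self.shadow[k] + (1.0 - self.decay) * v


params = {"w": 1.0}
ema = SimpleEMA(params)
params["w"] = 0.0          # one optimizer step changes the raw weight
ema.update(params)
print(ema.shadow["w"])     # prints 0.9998: the averaged weight moves slowly
```

At evaluation and export time the smoothed weights are typically used in place of the raw ones, which is what makes EMA a cheap accuracy boost.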
## Citations
```
@article{wang2020solov2,
......
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- Poly2Mask: {}
- RandomDistort: {}
- RandomCrop: {}
- RandomResize: {interp: 1,
target_size: [[352, 852], [384, 852], [416, 852], [448, 852], [480, 852], [512, 852]],
keep_ratio: True}
- RandomFlip: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- Gt2Solov2Target: {num_grids: [40, 36, 24, 16, 12],
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]],
coord_sigma: 0.2}
batch_size: 2
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [512, 852], keep_ratio: True}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: false
drop_last: false
TestReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [512, 852], keep_ratio: True}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: false
drop_last: false
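
In the `TrainReader` above, multi-scale training comes from `RandomResize` sampling one `[short, long]` pair per image, while `keep_ratio: True` preserves the aspect ratio. The helper below is only a rough sketch of that scale computation, assuming the common short-side/long-side convention rather than the exact ppdet transform:

```python
import random

def sample_keep_ratio_scale(im_h, im_w, target_sizes):
    """Pick one (short, long) target and return the aspect-preserving scale."""
    short_target, long_target = random.choice(target_sizes)
    short_side, long_side = min(im_h, im_w), max(im_h, im_w)
    # Scale the short side up to its target unless the long side would
    # then exceed its cap.
    scale = min(short_target / short_side, long_target / long_side)
    return scale, (round(im_h * scale), round(im_w * scale))

sizes = [[352, 852], [384, 852], [416, 852], [448, 852], [480, 852], [512, 852]]
print(sample_keep_ratio_scale(480, 640, sizes))
```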
_BASE_: [
'../datasets/coco_instance.yml',
'../runtime.yml',
'_base_/solov2_r50_fpn.yml',
'_base_/optimizer_1x.yml',
'_base_/solov2_light_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams
weights: output/solov2_r50_fpn_3x_coco/model_final
epoch: 36
use_ema: true
ema_decay: 0.9998
ResNet:
depth: 50
variant: d
freeze_at: 0
freeze_norm: false
norm_type: sync_bn
return_idx: [0,1,2,3]
dcn_v2_stages: [1,2,3]
lr_mult_list: [0.05, 0.05, 0.1, 0.15]
num_stages: 4
SOLOv2Head:
seg_feat_channels: 256
stacked_convs: 3
num_grids: [40, 36, 24, 16, 12]
kernel_out_channels: 128
solov2_loss: SOLOv2Loss
mask_nms: MaskMatrixNMS
dcn_v2_stages: [2]
drop_block: True
SOLOv2MaskHead:
mid_channels: 128
out_channels: 128
start_level: 0
end_level: 3
use_dcn_in_tower: True
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [24, 33]
- !LinearWarmup
start_factor: 0.
steps: 1000
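
The schedule above warms the learning rate up linearly from 0 over the first 1000 iterations and then decays it by `gamma: 0.1` at epochs 24 and 33. A small sketch of the resulting curve; `steps_per_epoch` is a placeholder, and this is not the ppdet scheduler itself:

```python
def learning_rate(step, steps_per_epoch, base_lr=0.01, gamma=0.1,
                  milestones=(24, 33), warmup_steps=1000, start_factor=0.0):
    """Illustrative linear warmup + piecewise decay."""
    # Piecewise decay: multiply by gamma for every milestone epoch passed.
    epoch = step / steps_per_epoch
    lr = base_lr * gamma ** sum(epoch >= m for m in milestones)
    if step < warmup_steps:
        # Linear ramp from start_factor * lr up to lr.
        alpha = step / warmup_steps
        lr = lr * (start_factor + (1.0 - start_factor) * alpha)
    return lr

print(learning_rate(500, steps_per_epoch=7000))        # mid-warmup: 0.005
print(learning_rate(30 * 7000, steps_per_epoch=7000))  # after the first decay: 0.001
```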
......@@ -22,7 +22,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS
from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS, DropBlock
from ppdet.core.workspace import register
from six.moves import zip
......@@ -182,7 +182,8 @@ class SOLOv2Head(nn.Layer):
score_threshold=0.1,
mask_threshold=0.5,
mask_nms=None,
norm_type='gn'):
norm_type='gn',
drop_block=False):
super(SOLOv2Head, self).__init__()
self.num_classes = num_classes
self.in_channels = in_channels
......@@ -198,6 +199,7 @@ class SOLOv2Head(nn.Layer):
self.score_threshold = score_threshold
self.mask_threshold = mask_threshold
self.norm_type = norm_type
self.drop_block = drop_block
self.kernel_pred_convs = []
self.cate_pred_convs = []
......@@ -250,6 +252,10 @@ class SOLOv2Head(nn.Layer):
bias_attr=ParamAttr(initializer=Constant(
value=float(-np.log((1 - 0.01) / 0.01))))))
if self.drop_block:
self.drop_block_fun = DropBlock(
block_size=3, keep_prob=0.9, name='solo_cate.dropblock')
def _points_nms(self, heat, kernel_size=2):
hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1)
keep = paddle.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
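
The hunk above is the start of `_points_nms`: the category heatmap is max-pooled with a 2x2 window and only positions equal to their local maximum are kept, which suppresses duplicate activations on the grid. A hedged sketch of the full operation (the method presumably finishes by multiplying `heat` with `keep`):

```python
import paddle
import paddle.nn.functional as F

def points_nms_sketch(heat, kernel_size=2):
    # Keep a grid cell only if it equals the maximum of its 2x2 neighbourhood.
    hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1)
    keep = paddle.cast(hmax[:, :, :-1, :-1] == heat, 'float32')
    return heat * keep

scores = paddle.rand([1, 80, 40, 40])   # category scores on a 40 x 40 grid
print(points_nms_sketch(scores).shape)  # [1, 80, 40, 40]
```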
......@@ -318,10 +324,14 @@ class SOLOv2Head(nn.Layer):
for kernel_layer in self.kernel_pred_convs:
kernel_feat = F.relu(kernel_layer(kernel_feat))
if self.drop_block:
kernel_feat = self.drop_block_fun(kernel_feat)
kernel_pred = self.solo_kernel(kernel_feat)
# cate branch
for cate_layer in self.cate_pred_convs:
cate_feat = F.relu(cate_layer(cate_feat))
if self.drop_block:
cate_feat = self.drop_block_fun(cate_feat)
cate_pred = self.solo_cate(cate_feat)
if not self.training:
......
......@@ -250,6 +250,47 @@ class LiteConv(nn.Layer):
return out
class DropBlock(nn.Layer):
    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
        """
        DropBlock layer, see https://arxiv.org/abs/1810.12890

        Args:
            block_size (int): block size
            keep_prob (float): keep probability
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        if not self.training or self.keep_prob == 1:
            return x
        else:
            # Per-element seed rate chosen so the expected drop rate is
            # (1 - keep_prob) once each seed is expanded to a block, with a
            # correction for blocks clipped at the feature-map border.
            gamma = (1. - self.keep_prob) / (self.block_size**2)
            if self.data_format == 'NCHW':
                shape = x.shape[2:]
            else:
                shape = x.shape[1:3]
            for s in shape:
                gamma *= s / (s - self.block_size + 1)

            # Sample block centers, then grow each center into a
            # block_size x block_size square via max pooling.
            matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
            mask_inv = F.max_pool2d(
                matrix,
                self.block_size,
                stride=1,
                padding=self.block_size // 2,
                data_format=self.data_format)
            mask = 1. - mask_inv
            # Rescale the surviving activations to keep the expected sum.
            y = x * mask * (mask.numel() / mask.sum())
            return y
@register
@serializable
class AnchorGeneratorSSD(object):
......
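
Since `DropBlock` now lives in `ppdet.modeling.layers` and is imported by both `solov2_head.py` above and `yolo_fpn.py` below, here is a short usage sketch; the feature shape is only an example:

```python
import paddle
from ppdet.modeling.layers import DropBlock

# Regularize a feature map during training: contiguous 3x3 regions are zeroed
# with keep_prob 0.9, matching the SOLOv2 head settings in this commit.
drop = DropBlock(block_size=3, keep_prob=0.9, name='demo.dropblock')
feat = paddle.rand([2, 256, 128, 128])  # example NCHW feature map
drop.train()
out = drop(feat)    # masked and rescaled by numel / kept elements
drop.eval()
same = drop(feat)   # identity at inference time
```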
......@@ -17,6 +17,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import DropBlock
from ..backbones.darknet import ConvBNLayer
from ..shape_spec import ShapeSpec
......@@ -173,47 +174,6 @@ class SPP(nn.Layer):
return y
class DropBlock(nn.Layer):
def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
"""
DropBlock layer, see https://arxiv.org/abs/1810.12890
Args:
block_size (int): block size
keep_prob (int): keep probability
name (str): layer name
data_format (str): data format, NCHW or NHWC
"""
super(DropBlock, self).__init__()
self.block_size = block_size
self.keep_prob = keep_prob
self.name = name
self.data_format = data_format
def forward(self, x):
if not self.training or self.keep_prob == 1:
return x
else:
gamma = (1. - self.keep_prob) / (self.block_size**2)
if self.data_format == 'NCHW':
shape = x.shape[2:]
else:
shape = x.shape[1:3]
for s in shape:
gamma *= s / (s - self.block_size + 1)
matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
mask_inv = F.max_pool2d(
matrix,
self.block_size,
stride=1,
padding=self.block_size // 2,
data_format=self.data_format)
mask = 1. - mask_inv
y = x * mask * (mask.numel() / mask.sum())
return y
class CoordConv(nn.Layer):
def __init__(self,
ch_in,
......