未验证 提交 3cf5a126 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add server side rcnn model (#433)

* add server side rcnn model, 62 fps with coco mAP 41.6% and 20fps with coco mAP 47.8%
上级 c06f1ea0
# Practical Server-side detection method base on RCNN
## Introduction
* This is developed by PaddleDetection. Many useful tricks are utilized for the model training process. More details can be seen in the configuration file.
## Model Zoo
| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download |
| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-FPN-Dcnv2 | Faster | 2 | 3x | 61.425 | 41.6 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_3x_server_side.tar) |
| ResNet50-vd-FPN-Dcnv2 | Cascade Faster | 2 | 3x | 20.001 | 47.8 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_dcn_r50_vd_fpn_3x_server_side.tar) |
architecture: CascadeRCNN
max_iters: 270000
snapshot_iter: 30000
use_gpu: true
log_smooth_window: 20
log_iter: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_82.39_pretrained.tar
weights: output/cascade_rcnn_dcn_r50_vd_fpn_3x_server_side/model_final
metric: COCO
num_classes: 81
CascadeRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: CascadeBBoxHead
bbox_assigner: CascadeBBoxAssigner
ResNet:
norm_type: bn
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
variant: d
dcn_v2_stages: [3, 4, 5]
lr_mult_list: [0.05, 0.05, 0.1, 0.15]
FPN:
max_level: 6
min_level: 2
num_chan: 64
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
min_level: 2
max_level: 6
num_chan: 64
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_positive_overlap: 0.7
rpn_negative_overlap: 0.3
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 500
post_nms_top_n: 300
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
min_level: 2
max_level: 5
box_resolution: 7
sampling_ratio: 2
CascadeBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [10, 20, 30]
bg_thresh_lo: [0.0, 0.0, 0.0]
bg_thresh_hi: [0.5, 0.6, 0.7]
fg_thresh: [0.5, 0.6, 0.7]
fg_fraction: 0.25
CascadeBBoxHead:
head: CascadeTwoFCHead
bbox_loss: BalancedL1Loss
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
BalancedL1Loss:
alpha: 0.5
gamma: 1.5
beta: 1.0
loss_weight: 1.0
CascadeTwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [180000, 240000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !AutoAugmentImage
autoaug_type: v1
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: [640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024]
max_size: 1500
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
batch_size: 2
shuffle: true
worker_num: 2
use_process: false
TestReader:
inputs_def:
# set image_shape if needed
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1500
target_size: 1000
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
# for voc
#fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1500
target_size: 1000
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
drop_empty: false
worker_num: 2
architecture: FasterRCNN
max_iters: 270000
snapshot_iter: 30000
use_gpu: true
log_smooth_window: 20
log_iter: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_82.39_pretrained.tar
weights: output/faster_rcnn_dcn_r50_vd_fpn_3x_server_side/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: LibraBBoxAssigner
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
dcn_v2_stages: [3, 4, 5]
lr_mult_list: [0.05, 0.05, 0.1, 0.15]
FPN:
max_level: 6
min_level: 2
num_chan: 64
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 64
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 300
pre_nms_top_n: 500
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
LibraBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
bbox_loss: DiouLoss
DiouLoss:
loss_weight: 10.0
is_cls_agnostic: false
use_complete_iou_loss: true
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [180000, 240000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !AutoAugmentImage
autoaug_type: v1
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: [384, 416, 448, 480, 512, 544, 576, 608, 640, 672]
max_size: 1000
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
batch_size: 2
shuffle: true
worker_num: 2
use_process: false
TestReader:
inputs_def:
# set image_shape if needed
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 640
target_size: 640
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
# for voc
#fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 640
target_size: 640
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
drop_empty: false
worker_num: 2
...@@ -53,6 +53,9 @@ class ResNet(object): ...@@ -53,6 +53,9 @@ class ResNet(object):
gcb_params (dict): gc blocks config, includes ratio(default as 1.0/16), gcb_params (dict): gc blocks config, includes ratio(default as 1.0/16),
pooling_type(default as "att") and pooling_type(default as "att") and
fusion_types(default as ['channel_add']) fusion_types(default as ['channel_add'])
lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
lower learning rate ratio is need for pretrained model
got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
""" """
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name'] __shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
...@@ -68,7 +71,8 @@ class ResNet(object): ...@@ -68,7 +71,8 @@ class ResNet(object):
weight_prefix_name='', weight_prefix_name='',
nonlocal_stages=[], nonlocal_stages=[],
gcb_stages=[], gcb_stages=[],
gcb_params=dict()): gcb_params=dict(),
lr_mult_list=[1., 1., 1., 1.]):
super(ResNet, self).__init__() super(ResNet, self).__init__()
if isinstance(feature_maps, Integral): if isinstance(feature_maps, Integral):
...@@ -82,6 +86,9 @@ class ResNet(object): ...@@ -82,6 +86,9 @@ class ResNet(object):
assert norm_type in ['bn', 'sync_bn', 'affine_channel'] assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \ assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34" "non-local is not supported for resnet18 or resnet34"
assert len(lr_mult_list
) == 4, "lr_mult_list length must be 4 but got {}".format(
len(lr_mult_list))
self.depth = depth self.depth = depth
self.freeze_at = freeze_at self.freeze_at = freeze_at
...@@ -116,6 +123,10 @@ class ResNet(object): ...@@ -116,6 +123,10 @@ class ResNet(object):
self.gcb_stages = gcb_stages self.gcb_stages = gcb_stages
self.gcb_params = gcb_params self.gcb_params = gcb_params
self.lr_mult_list = lr_mult_list
# var denoting curr stage
self.stage_num = -1
def _conv_offset(self, def _conv_offset(self,
input, input,
filter_size, filter_size,
...@@ -148,6 +159,13 @@ class ResNet(object): ...@@ -148,6 +159,13 @@ class ResNet(object):
name=None, name=None,
dcn_v2=False): dcn_v2=False):
_name = self.prefix_name + name if self.prefix_name != '' else name _name = self.prefix_name + name if self.prefix_name != '' else name
# need fine lr for distilled model, default as 1.0
lr_mult = 1.0
mult_idx = max(self.stage_num - 2, 0)
mult_idx = min(self.stage_num - 2, 3)
lr_mult = self.lr_mult_list[mult_idx]
if not dcn_v2: if not dcn_v2:
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
...@@ -157,7 +175,8 @@ class ResNet(object): ...@@ -157,7 +175,8 @@ class ResNet(object):
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=_name + "_weights"), param_attr=ParamAttr(
name=_name + "_weights", learning_rate=lr_mult),
bias_attr=False, bias_attr=False,
name=_name + '.conv2d.output.1') name=_name + '.conv2d.output.1')
else: else:
...@@ -187,14 +206,15 @@ class ResNet(object): ...@@ -187,14 +206,15 @@ class ResNet(object):
groups=groups, groups=groups,
deformable_groups=1, deformable_groups=1,
im2col_step=1, im2col_step=1,
param_attr=ParamAttr(name=_name + "_weights"), param_attr=ParamAttr(
name=_name + "_weights", learning_rate=lr_mult),
bias_attr=False, bias_attr=False,
name=_name + ".conv2d.output.1") name=_name + ".conv2d.output.1")
bn_name = self.na.fix_conv_norm_name(name) bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else 1. norm_lr = 0. if self.freeze_norm else lr_mult
norm_decay = self.norm_decay norm_decay = self.norm_decay
pattr = ParamAttr( pattr = ParamAttr(
name=bn_name + '_scale', name=bn_name + '_scale',
...@@ -365,6 +385,8 @@ class ResNet(object): ...@@ -365,6 +385,8 @@ class ResNet(object):
""" """
assert stage_num in [2, 3, 4, 5] assert stage_num in [2, 3, 4, 5]
self.stage_num = stage_num
stages, block_func = self.depth_cfg[self.depth] stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2] count = stages[stage_num - 2]
......
...@@ -552,6 +552,8 @@ class BBoxAssigner(object): ...@@ -552,6 +552,8 @@ class BBoxAssigner(object):
@register @register
class LibraBBoxAssigner(object): class LibraBBoxAssigner(object):
__shared__ = ['num_classes']
def __init__(self, def __init__(self,
batch_size_per_im=512, batch_size_per_im=512,
fg_fraction=.25, fg_fraction=.25,
...@@ -797,6 +799,7 @@ class LibraBBoxAssigner(object): ...@@ -797,6 +799,7 @@ class LibraBBoxAssigner(object):
hs = boxes[:, 3] - boxes[:, 1] + 1 hs = boxes[:, 3] - boxes[:, 1] + 1
keep = np.where((ws > 0) & (hs > 0))[0] keep = np.where((ws > 0) & (hs > 0))[0]
boxes = boxes[keep] boxes = boxes[keep]
max_overlaps = max_overlaps[keep]
fg_inds = np.where(max_overlaps >= fg_thresh)[0] fg_inds = np.where(max_overlaps >= fg_thresh)[0]
bg_inds = np.where((max_overlaps < bg_thresh_hi) & ( bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
max_overlaps >= bg_thresh_lo))[0] max_overlaps >= bg_thresh_lo))[0]
......
...@@ -23,6 +23,7 @@ from paddle.fluid.initializer import MSRA ...@@ -23,6 +23,7 @@ from paddle.fluid.initializer import MSRA
from ppdet.modeling.ops import MultiClassNMS from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm from ppdet.modeling.ops import ConvNorm
from ppdet.modeling.losses import SmoothL1Loss
from ppdet.core.workspace import register from ppdet.core.workspace import register
__all__ = ['CascadeBBoxHead'] __all__ = ['CascadeBBoxHead']
...@@ -38,16 +39,24 @@ class CascadeBBoxHead(object): ...@@ -38,16 +39,24 @@ class CascadeBBoxHead(object):
nms (object): `MultiClassNMS` instance nms (object): `MultiClassNMS` instance
num_classes: number of output classes num_classes: number of output classes
""" """
__inject__ = ['head', 'nms'] __inject__ = ['head', 'nms', 'bbox_loss']
__shared__ = ['num_classes'] __shared__ = ['num_classes']
def __init__(self, head, nms=MultiClassNMS().__dict__, num_classes=81): def __init__(
self,
head,
nms=MultiClassNMS().__dict__,
bbox_loss=SmoothL1Loss().__dict__,
num_classes=81, ):
super(CascadeBBoxHead, self).__init__() super(CascadeBBoxHead, self).__init__()
self.head = head self.head = head
self.nms = nms self.nms = nms
self.bbox_loss = bbox_loss
self.num_classes = num_classes self.num_classes = num_classes
if isinstance(nms, dict): if isinstance(nms, dict):
self.nms = MultiClassNMS(**nms) self.nms = MultiClassNMS(**nms)
if isinstance(bbox_loss, dict):
self.bbox_loss = SmoothL1Loss(**bbox_loss)
def get_output(self, def get_output(self,
roi_feat, roi_feat,
...@@ -123,13 +132,11 @@ class CascadeBBoxHead(object): ...@@ -123,13 +132,11 @@ class CascadeBBoxHead(object):
loss_cls = fluid.layers.reduce_mean( loss_cls = fluid.layers.reduce_mean(
loss_cls, name='loss_cls_' + str(i)) * rcnn_loss_weight_list[i] loss_cls, name='loss_cls_' + str(i)) * rcnn_loss_weight_list[i]
loss_bbox = fluid.layers.smooth_l1( loss_bbox = self.bbox_loss(
x=rcnn_pred[1], x=rcnn_pred[1],
y=rcnn_target[2], y=rcnn_target[2],
inside_weight=rcnn_target[3], inside_weight=rcnn_target[3],
outside_weight=rcnn_target[4], outside_weight=rcnn_target[4])
sigma=1.0, # detectron use delta = 1./sigma**2
)
loss_bbox = fluid.layers.reduce_mean( loss_bbox = fluid.layers.reduce_mean(
loss_bbox, loss_bbox,
name='loss_bbox_' + str(i)) * rcnn_loss_weight_list[i] name='loss_bbox_' + str(i)) * rcnn_loss_weight_list[i]
......
...@@ -21,7 +21,11 @@ from paddle import fluid ...@@ -21,7 +21,11 @@ from paddle import fluid
from ppdet.core.workspace import register from ppdet.core.workspace import register
from ppdet.modeling.ops import BBoxAssigner, MaskAssigner from ppdet.modeling.ops import BBoxAssigner, MaskAssigner
__all__ = ['BBoxAssigner', 'MaskAssigner', 'CascadeBBoxAssigner'] __all__ = [
'BBoxAssigner',
'MaskAssigner',
'CascadeBBoxAssigner',
]
@register @register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册