未验证 提交 3cf5a126 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add server side rcnn model (#433)

* add server side rcnn model, 62 fps with coco mAP 41.6% and 20fps with coco mAP 47.8%
上级 c06f1ea0
# Practical Server-side detection method base on RCNN
## Introduction
* This is developed by PaddleDetection. Many useful tricks are utilized for the model training process. More details can be seen in the configuration file.
## Model Zoo
| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download |
| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-FPN-Dcnv2 | Faster | 2 | 3x | 61.425 | 41.6 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_dcn_r50_vd_fpn_3x_server_side.tar) |
| ResNet50-vd-FPN-Dcnv2 | Cascade Faster | 2 | 3x | 20.001 | 47.8 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/cascade_rcnn_dcn_r50_vd_fpn_3x_server_side.tar) |
architecture: CascadeRCNN
max_iters: 270000
snapshot_iter: 30000
use_gpu: true
log_smooth_window: 20
log_iter: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_82.39_pretrained.tar
weights: output/cascade_rcnn_dcn_r50_vd_fpn_3x_server_side/model_final
metric: COCO
num_classes: 81
CascadeRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: CascadeBBoxHead
bbox_assigner: CascadeBBoxAssigner
ResNet:
norm_type: bn
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
variant: d
dcn_v2_stages: [3, 4, 5]
lr_mult_list: [0.05, 0.05, 0.1, 0.15]
FPN:
max_level: 6
min_level: 2
num_chan: 64
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
min_level: 2
max_level: 6
num_chan: 64
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_positive_overlap: 0.7
rpn_negative_overlap: 0.3
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 500
post_nms_top_n: 300
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
min_level: 2
max_level: 5
box_resolution: 7
sampling_ratio: 2
CascadeBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [10, 20, 30]
bg_thresh_lo: [0.0, 0.0, 0.0]
bg_thresh_hi: [0.5, 0.6, 0.7]
fg_thresh: [0.5, 0.6, 0.7]
fg_fraction: 0.25
CascadeBBoxHead:
head: CascadeTwoFCHead
bbox_loss: BalancedL1Loss
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
BalancedL1Loss:
alpha: 0.5
gamma: 1.5
beta: 1.0
loss_weight: 1.0
CascadeTwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [180000, 240000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !AutoAugmentImage
autoaug_type: v1
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: [640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024]
max_size: 1500
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
batch_size: 2
shuffle: true
worker_num: 2
use_process: false
TestReader:
inputs_def:
# set image_shape if needed
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1500
target_size: 1000
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
# for voc
#fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1500
target_size: 1000
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
drop_empty: false
worker_num: 2
architecture: FasterRCNN
max_iters: 270000
snapshot_iter: 30000
use_gpu: true
log_smooth_window: 20
log_iter: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_82.39_pretrained.tar
weights: output/faster_rcnn_dcn_r50_vd_fpn_3x_server_side/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: LibraBBoxAssigner
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
dcn_v2_stages: [3, 4, 5]
lr_mult_list: [0.05, 0.05, 0.1, 0.15]
FPN:
max_level: 6
min_level: 2
num_chan: 64
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 64
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 300
pre_nms_top_n: 500
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
LibraBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
bbox_loss: DiouLoss
DiouLoss:
loss_weight: 10.0
is_cls_agnostic: false
use_complete_iou_loss: true
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [180000, 240000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !AutoAugmentImage
autoaug_type: v1
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: [384, 416, 448, 480, 512, 544, 576, 608, 640, 672]
max_size: 1000
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
batch_size: 2
shuffle: true
worker_num: 2
use_process: false
TestReader:
inputs_def:
# set image_shape if needed
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 640
target_size: 640
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
# for voc
#fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 640
target_size: 640
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: true
batch_size: 1
shuffle: false
drop_empty: false
worker_num: 2
......@@ -53,6 +53,9 @@ class ResNet(object):
gcb_params (dict): gc blocks config, includes ratio(default as 1.0/16),
pooling_type(default as "att") and
fusion_types(default as ['channel_add'])
lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
lower learning rate ratio is need for pretrained model
got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
"""
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
......@@ -68,7 +71,8 @@ class ResNet(object):
weight_prefix_name='',
nonlocal_stages=[],
gcb_stages=[],
gcb_params=dict()):
gcb_params=dict(),
lr_mult_list=[1., 1., 1., 1.]):
super(ResNet, self).__init__()
if isinstance(feature_maps, Integral):
......@@ -82,6 +86,9 @@ class ResNet(object):
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34"
assert len(lr_mult_list
) == 4, "lr_mult_list length must be 4 but got {}".format(
len(lr_mult_list))
self.depth = depth
self.freeze_at = freeze_at
......@@ -116,6 +123,10 @@ class ResNet(object):
self.gcb_stages = gcb_stages
self.gcb_params = gcb_params
self.lr_mult_list = lr_mult_list
# var denoting curr stage
self.stage_num = -1
def _conv_offset(self,
input,
filter_size,
......@@ -148,6 +159,13 @@ class ResNet(object):
name=None,
dcn_v2=False):
_name = self.prefix_name + name if self.prefix_name != '' else name
# need fine lr for distilled model, default as 1.0
lr_mult = 1.0
mult_idx = max(self.stage_num - 2, 0)
mult_idx = min(self.stage_num - 2, 3)
lr_mult = self.lr_mult_list[mult_idx]
if not dcn_v2:
conv = fluid.layers.conv2d(
input=input,
......@@ -157,7 +175,8 @@ class ResNet(object):
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=_name + "_weights"),
param_attr=ParamAttr(
name=_name + "_weights", learning_rate=lr_mult),
bias_attr=False,
name=_name + '.conv2d.output.1')
else:
......@@ -187,14 +206,15 @@ class ResNet(object):
groups=groups,
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(name=_name + "_weights"),
param_attr=ParamAttr(
name=_name + "_weights", learning_rate=lr_mult),
bias_attr=False,
name=_name + ".conv2d.output.1")
bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else 1.
norm_lr = 0. if self.freeze_norm else lr_mult
norm_decay = self.norm_decay
pattr = ParamAttr(
name=bn_name + '_scale',
......@@ -365,6 +385,8 @@ class ResNet(object):
"""
assert stage_num in [2, 3, 4, 5]
self.stage_num = stage_num
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
......
......@@ -552,6 +552,8 @@ class BBoxAssigner(object):
@register
class LibraBBoxAssigner(object):
__shared__ = ['num_classes']
def __init__(self,
batch_size_per_im=512,
fg_fraction=.25,
......@@ -797,6 +799,7 @@ class LibraBBoxAssigner(object):
hs = boxes[:, 3] - boxes[:, 1] + 1
keep = np.where((ws > 0) & (hs > 0))[0]
boxes = boxes[keep]
max_overlaps = max_overlaps[keep]
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
max_overlaps >= bg_thresh_lo))[0]
......
......@@ -23,6 +23,7 @@ from paddle.fluid.initializer import MSRA
from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm
from ppdet.modeling.losses import SmoothL1Loss
from ppdet.core.workspace import register
__all__ = ['CascadeBBoxHead']
......@@ -38,16 +39,24 @@ class CascadeBBoxHead(object):
nms (object): `MultiClassNMS` instance
num_classes: number of output classes
"""
__inject__ = ['head', 'nms']
__inject__ = ['head', 'nms', 'bbox_loss']
__shared__ = ['num_classes']
def __init__(self, head, nms=MultiClassNMS().__dict__, num_classes=81):
def __init__(
self,
head,
nms=MultiClassNMS().__dict__,
bbox_loss=SmoothL1Loss().__dict__,
num_classes=81, ):
super(CascadeBBoxHead, self).__init__()
self.head = head
self.nms = nms
self.bbox_loss = bbox_loss
self.num_classes = num_classes
if isinstance(nms, dict):
self.nms = MultiClassNMS(**nms)
if isinstance(bbox_loss, dict):
self.bbox_loss = SmoothL1Loss(**bbox_loss)
def get_output(self,
roi_feat,
......@@ -123,13 +132,11 @@ class CascadeBBoxHead(object):
loss_cls = fluid.layers.reduce_mean(
loss_cls, name='loss_cls_' + str(i)) * rcnn_loss_weight_list[i]
loss_bbox = fluid.layers.smooth_l1(
loss_bbox = self.bbox_loss(
x=rcnn_pred[1],
y=rcnn_target[2],
inside_weight=rcnn_target[3],
outside_weight=rcnn_target[4],
sigma=1.0, # detectron use delta = 1./sigma**2
)
outside_weight=rcnn_target[4])
loss_bbox = fluid.layers.reduce_mean(
loss_bbox,
name='loss_bbox_' + str(i)) * rcnn_loss_weight_list[i]
......
......@@ -21,7 +21,11 @@ from paddle import fluid
from ppdet.core.workspace import register
from ppdet.modeling.ops import BBoxAssigner, MaskAssigner
__all__ = ['BBoxAssigner', 'MaskAssigner', 'CascadeBBoxAssigner']
__all__ = [
'BBoxAssigner',
'MaskAssigner',
'CascadeBBoxAssigner',
]
@register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册