未验证 提交 26bf8416 编写于 作者: C cnn 提交者: GitHub

[dygraph] add faster_rcnn_r50_fpn and faster_rcnn_r50 (#1843)

* add faster_rcnn_r50_fpn and faster_rcnn_r50

* update TestReader in dygraph

* if gt_ploy is empty, continue

* fix typo

* judge gt_ploy is empty or not

* fix judge condition

* fix gt_ploy judge
上级 01a14cc7
architecture: FasterRCNN
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights: output/faster_rcnn_r50_1x_coco/model_final.pdparams
load_static_weights: True
# Model Achitecture
FasterRCNN:
# model anchor info flow
anchor: Anchor
proposal: Proposal
# model feat info flow
backbone: ResNet
rpn_head: RPNHead
bbox_head: BBoxHead
# post process
bbox_post_process: BBoxPostProcess
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [2]
num_stages: 3
RPNHead:
rpn_feat:
name: RPNFeat
feat_in: 1024
feat_out: 1024
anchor_per_position: 15
rpn_channel: 1024
Anchor:
anchor_generator:
name: AnchorGeneratorRPN
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_target_generator:
name: AnchorTargetGeneratorRPN
batch_size_per_im: 256
fg_fraction: 0.5
negative_overlap: 0.3
positive_overlap: 0.7
straddle_thresh: 0.0
Proposal:
proposal_generator:
name: ProposalGenerator
min_size: 0.0
nms_thresh: 0.7
train_pre_nms_top_n: 12000
train_post_nms_top_n: 2000
infer_pre_nms_top_n: 6000
infer_post_nms_top_n: 1000
proposal_target_generator:
name: ProposalTargetGenerator
batch_size_per_im: 512
bbox_reg_weights: [[0.1, 0.1, 0.2, 0.2],]
bg_thresh_hi: [0.5,]
bg_thresh_lo: [0.0,]
fg_thresh: [0.5,]
fg_fraction: 0.25
BBoxHead:
bbox_feat:
name: BBoxFeat
roi_extractor:
name: RoIAlign
resolution: 14
sampling_ratio: 0
start_level: 0
end_level: 0
head_feat:
name: Res5Head
feat_in: 1024
feat_out: 512
with_pool: true
in_feat: 2048
BBoxPostProcess:
decode:
name: RCNNBox
num_classes: 81
batch_size: 1
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
architecture: FasterRCNN
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights: output/faster_rcnn_r50_fpn_1x_coco/model_final.pdparams
load_static_weights: True
# Model Achitecture
FasterRCNN:
# model anchor info flow
anchor: Anchor
proposal: Proposal
# model feat info flow
backbone: ResNet
neck: FPN
rpn_head: RPNHead
bbox_head: BBoxHead
# post process
bbox_post_process: BBoxPostProcess
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [0,1,2,3]
num_stages: 4
FPN:
in_channels: [256, 512, 1024, 2048]
out_channel: 256
min_level: 0
max_level: 4
spatial_scale: [0.25, 0.125, 0.0625, 0.03125]
RPNHead:
rpn_feat:
name: RPNFeat
feat_in: 256
feat_out: 256
anchor_per_position: 3
rpn_channel: 256
Anchor:
anchor_generator:
name: AnchorGeneratorRPN
aspect_ratios: [0.5, 1.0, 2.0]
anchor_start_size: 32
stride: [4., 4.]
anchor_target_generator:
name: AnchorTargetGeneratorRPN
batch_size_per_im: 256
fg_fraction: 0.5
negative_overlap: 0.3
positive_overlap: 0.7
straddle_thresh: 0.0
Proposal:
proposal_generator:
name: ProposalGenerator
min_size: 0.0
nms_thresh: 0.7
train_pre_nms_top_n: 2000
train_post_nms_top_n: 2000
infer_pre_nms_top_n: 1000
infer_post_nms_top_n: 1000
proposal_target_generator:
name: ProposalTargetGenerator
batch_size_per_im: 512
bbox_reg_weights: [[0.1, 0.1, 0.2, 0.2],]
bg_thresh_hi: [0.5,]
bg_thresh_lo: [0.0,]
fg_thresh: [0.5,]
fg_fraction: 0.25
BBoxHead:
bbox_feat:
name: BBoxFeat
roi_extractor:
name: RoIAlign
resolution: 7
sampling_ratio: 2
head_feat:
name: TwoFCHead
in_dim: 256
mlp_dim: 1024
in_feat: 1024
BBoxPostProcess:
decode:
name: RCNNBox
num_classes: 81
batch_size: 1
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
worker_num: 2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
sample_transforms:
- DecodeImage: {to_rgb: true}
# check
- RandomFlipImage: {prob: 0.5}
- NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
- Permute: {to_bgr: false, channel_first: true}
batch_transforms:
- PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true}
batch_size: 1
shuffle: true
drop_last: true
EvalReader:
inputs_def:
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
sample_transforms:
- DecodeOp: { }
- NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
- ResizeOp: { interp: 1, target_size: [ 800, 1333 ] }
- PermuteOp: { }
batch_transforms:
- PadBatchOp: { pad_to_stride: 32, pad_gt: false }
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
inputs_def:
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
sample_transforms:
- DecodeOp: { }
- NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
- ResizeOp: { interp: 1, target_size: [ 800, 1333 ] }
- PermuteOp: { }
batch_transforms:
- PadBatchOp: { pad_to_stride: 32, pad_gt: false }
batch_size: 1
shuffle: false
drop_last: false
worker_num: 2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
sample_transforms:
- !DecodeImage
to_rgb: True
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
- DecodeImage: {to_rgb: true}
- RandomFlipImage: {prob: 0.5}
- NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
- Permute: {to_bgr: false, channel_first: true}
batch_transforms:
- !PadBatch
pad_to_stride: 0
use_padded_im_info: False
pad_gt: true
- PadBatch: {pad_to_stride: -1, use_padded_im_info: false, pad_gt: true}
batch_size: 1
shuffle: true
worker_num: 2
use_process: false
drop_last: true
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
sample_transforms:
- !DecodeImage
to_rgb: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
- DecodeOp: { }
- NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
- ResizeOp: { interp: 1, target_size: [ 800, 1333 ] }
- PermuteOp: { }
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
pad_gt: True
batch_size: 2
- PadBatchOp: { pad_to_stride: -1, pad_gt: false }
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
worker_num: 2
TestReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
- DecodeOp: { }
- NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
- ResizeOp: { interp: 1, target_size: [ 800, 1333 ] }
- PermuteOp: { }
batch_transforms:
- PadBatchOp: { pad_to_stride: -1, pad_gt: false }
batch_size: 1
shuffle: false
drop_last: false
_BASE_: [
'./_base_/models/faster_rcnn_r50.yml',
'./_base_/optimizers/rcnn_1x.yml',
'./_base_/datasets/coco.yml',
'./_base_/readers/faster_reader.yml',
'./_base_/runtime.yml',
]
_BASE_: [
'./_base_/models/faster_rcnn_r50_fpn.yml',
'./_base_/optimizers/rcnn_1x.yml',
'./_base_/datasets/coco.yml',
'./_base_/readers/faster_fpn_reader.yml',
'./_base_/runtime.yml',
]
......@@ -114,7 +114,10 @@ class COCODataSet(DetDataset):
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
if num_bbox <= 0:
continue
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
......@@ -123,6 +126,7 @@ class COCODataSet(DetDataset):
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox
has_segmentation = False
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = catid2clsid[catid]
......@@ -133,8 +137,9 @@ class COCODataSet(DetDataset):
gt_poly[i] = [[0.0, 0.0], ]
elif 'segmentation' in box:
gt_poly[i] = box['segmentation']
has_segmentation = True
if not any(gt_poly):
if has_segmentation and not any(gt_poly):
continue
coco_rec.update({
......
......@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
......@@ -13,66 +13,91 @@ __all__ = ['FasterRCNN']
class FasterRCNN(BaseArch):
__category__ = 'architecture'
__inject__ = [
'anchor',
'proposal',
'backbone',
'rpn_head',
'bbox_head',
'anchor', 'proposal', 'backbone', 'neck', 'rpn_head', 'bbox_head',
'bbox_post_process'
]
def __init__(self, anchor, proposal, backbone, rpn_head, bbox_head, *args,
**kwargs):
super(FasterRCNN, self).__init__(*args, **kwargs)
def __init__(self,
anchor,
proposal,
backbone,
rpn_head,
bbox_head,
bbox_post_process,
neck=None):
super(FasterRCNN, self).__init__()
self.anchor = anchor
self.proposal = proposal
self.backbone = backbone
self.rpn_head = rpn_head
self.bbox_head = bbox_head
self.bbox_post_process = bbox_post_process
self.neck = neck
def model_arch(self, ):
def model_arch(self):
# Backbone
bb_out = self.backbone(self.gbd)
self.gbd.update(bb_out)
body_feats = self.backbone(self.inputs)
spatial_scale = 0.0625
# Neck
if self.neck is not None:
body_feats, spatial_scale = self.neck(body_feats)
# RPN
rpn_head_out = self.rpn_head(self.gbd)
self.gbd.update(rpn_head_out)
# rpn_head returns two list: rpn_feat, rpn_head_out
# each element in rpn_feats contains rpn feature on each level,
# and the length is 1 when the neck is not applied.
# each element in rpn_head_out contains (rpn_rois_score, rpn_rois_delta)
rpn_feat, self.rpn_head_out = self.rpn_head(self.inputs, body_feats)
# Anchor
anchor_out = self.anchor(self.gbd)
self.gbd.update(anchor_out)
# Proposal BBox
self.gbd['stage'] = 0
proposal_out = self.proposal(self.gbd)
self.gbd.update({'proposal_0': proposal_out})
# anchor_out returns a list,
# each element contains (anchor, anchor_var)
self.anchor_out = self.anchor(rpn_feat)
# Proposal RoI
# compute targets here when training
rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out)
# BBox Head
bboxhead_out = self.bbox_head(self.gbd)
self.gbd.update({'bbox_head_0': bboxhead_out})
bbox_feat, self.bbox_head_out, self.bbox_head_feat_func = self.bbox_head(
body_feats, rois, spatial_scale)
if self.gbd['mode'] == 'infer':
bbox_out = self.proposal.post_process(self.gbd)
self.gbd.update(bbox_out)
if self.inputs['mode'] == 'infer':
bbox_pred, bboxes = self.bbox_head.get_prediction(
self.bbox_head_out, rois)
# Refine bbox by the output from bbox_head at test stage
self.bboxes = self.bbox_post_process(bbox_pred, bboxes,
self.inputs['im_shape'],
self.inputs['scale_factor'])
else:
# Proposal RoI for Mask branch
# bboxes update at training stage only
bbox_targets = self.proposal.get_targets()[0]
def get_loss(self, ):
rpn_cls_loss, rpn_reg_loss = self.rpn_head.get_loss(self.gbd)
bbox_cls_loss, bbox_reg_loss = self.bbox_head.get_loss(self.gbd)
losses = [rpn_cls_loss, rpn_reg_loss, bbox_cls_loss, bbox_reg_loss]
loss = fluid.layers.sum(losses)
out = {
'loss': loss,
'loss_rpn_cls': rpn_cls_loss,
'loss_rpn_reg': rpn_reg_loss,
'loss_bbox_cls': bbox_cls_loss,
'loss_bbox_reg': bbox_reg_loss,
}
return out
loss = {}
def get_pred(self, ):
outs = {
"bbox": self.gbd['predicted_bbox'].numpy(),
"bbox_nums": self.gbd['predicted_bbox_nums'].numpy(),
'im_id': self.gbd['im_id'].numpy()
# RPN loss
rpn_loss_inputs = self.anchor.generate_loss_inputs(
self.inputs, self.rpn_head_out, self.anchor_out)
loss_rpn = self.rpn_head.get_loss(rpn_loss_inputs)
loss.update(loss_rpn)
# BBox loss
bbox_targets = self.proposal.get_targets()
loss_bbox = self.bbox_head.get_loss(self.bbox_head_out, bbox_targets)
loss.update(loss_bbox)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
return loss
def get_pred(self, return_numpy=True):
bbox, bbox_num = self.bboxes
output = {
'bbox': bbox.numpy(),
'bbox_num': bbox_num.numpy(),
'im_id': self.inputs['im_id'].numpy()
}
return outs
return output
......@@ -36,7 +36,6 @@ class Anchor(object):
anchor = fluid.layers.reshape(anchor, shape=(-1, 4))
var = fluid.layers.reshape(var, shape=(-1, 4))
rpn_score_list.append(rpn_score)
rpn_delta_list.append(rpn_delta)
anchor_list.append(anchor)
......
......@@ -36,7 +36,7 @@ class RoIAlign(object):
def __call__(self, feats, rois, spatial_scale):
roi, rois_num = rois
cur_l = 0
if self.start_level == self.end_level:
rois_feat = ops.roi_align(
feats[self.start_level],
......@@ -55,7 +55,6 @@ class RoIAlign(object):
self.canconical_level,
self.canonical_size,
rois_num=rois_num)
rois_feat_list = []
for lvl in range(self.start_level, self.end_level + 1):
roi_feat = ops.roi_align(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册