diff --git a/configs/cascade_mask_rcnn_r50_fpn_1x.yml b/configs/cascade_mask_rcnn_r50_fpn_1x.yml new file mode 100644 index 0000000000000000000000000000000000000000..9a7b7a8dce4d1b4e48ed794e8cc2a458cbf363b3 --- /dev/null +++ b/configs/cascade_mask_rcnn_r50_fpn_1x.yml @@ -0,0 +1,145 @@ +architecture: CascadeMaskRCNN +train_feed: MaskRCNNTrainFeed +eval_feed: MaskRCNNEvalFeed +test_feed: MaskRCNNTestFeed +use_gpu: true +max_iters: 180000 +snapshot_iter: 10000 +log_smooth_window: 20 +save_dir: output +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar +metric: COCO +weights: output/cascade_mask_rcnn_r50_fpn_1x/model_final/ +num_classes: 81 + +CascadeMaskRCNN: + backbone: ResNet + fpn: FPN + rpn_head: FPNRPNHead + roi_extractor: FPNRoIAlign + bbox_head: CascadeBBoxHead + bbox_assigner: CascadeBBoxAssigner + mask_assigner: MaskAssigner + mask_head: MaskHead + +ResNet: + depth: 50 + feature_maps: [2, 3, 4, 5] + freeze_at: 2 + norm_type: affine_channel + +FPN: + max_level: 6 + min_level: 2 + num_chan: 256 + spatial_scale: [0.03125, 0.0625, 0.125, 0.25] + +FPNRPNHead: + anchor_generator: + aspect_ratios: [0.5, 1.0, 2.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_start_size: 32 + max_level: 6 + min_level: 2 + num_chan: 256 + rpn_target_assign: + rpn_batch_size_per_im: 256 + rpn_fg_fraction: 0.5 + rpn_negative_overlap: 0.3 + rpn_positive_overlap: 0.7 + rpn_straddle_thresh: 0.0 + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 2000 + post_nms_top_n: 2000 + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + +FPNRoIAlign: + canconical_level: 4 + canonical_size: 224 + max_level: 5 + min_level: 2 + sampling_ratio: 2 + box_resolution: 7 + mask_resolution: 14 + +MaskHead: + dilation: 1 + conv_dim: 256 + num_convs: 4 + resolution: 28 + +CascadeBBoxAssigner: + batch_size_per_im: 512 + bbox_reg_weights: [10, 20, 30] + bg_thresh_hi: [0.5, 0.6, 0.7] + bg_thresh_lo: [0.0, 0.0, 0.0] 
+ fg_fraction: 0.25 + fg_thresh: [0.5, 0.6, 0.7] + +MaskAssigner: + resolution: 28 + +CascadeBBoxHead: + head: FC6FC7Head + nms: + keep_top_k: 100 + nms_threshold: 0.5 + score_threshold: 0.05 + +FC6FC7Head: + num_chan: 1024 + +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [120000, 160000] + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 500 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +MaskRCNNTrainFeed: + batch_size: 1 + dataset: + dataset_dir: dataset/coco + annotation: annotations/instances_train2017.json + image_dir: train2017 + batch_transforms: + - !PadBatch + pad_to_stride: 32 + num_workers: 2 + +MaskRCNNEvalFeed: + batch_size: 1 + dataset: + dataset_dir: dataset/coco + annotation: annotations/instances_val2017.json + image_dir: val2017 + batch_transforms: + - !PadBatch + pad_to_stride: 32 + num_workers: 2 + +MaskRCNNTestFeed: + batch_size: 1 + dataset: + annotation: dataset/coco/annotations/instances_val2017.json + batch_transforms: + - !PadBatch + pad_to_stride: 32 + num_workers: 2 diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index dc894129d3e5cfb51d0795433871106ea3a10cf6..981982514d430d0cceb42f42387645b516a721d1 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from . import faster_rcnn from . import mask_rcnn from . import cascade_rcnn +from . import cascade_mask_rcnn from . import yolov3 from . import ssd from . import retinanet @@ -24,6 +25,7 @@ from . 
import retinanet from .faster_rcnn import * from .mask_rcnn import * from .cascade_rcnn import * +from .cascade_mask_rcnn import * from .yolov3 import * from .ssd import * from .retinanet import * diff --git a/ppdet/modeling/architectures/cascade_mask_rcnn.py b/ppdet/modeling/architectures/cascade_mask_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..b5fb9017b84a62ab1da7f05518d622903cbce53d --- /dev/null +++ b/ppdet/modeling/architectures/cascade_mask_rcnn.py @@ -0,0 +1,256 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid

from ppdet.core.workspace import register

__all__ = ['CascadeMaskRCNN']


@register
class CascadeMaskRCNN(object):
    """
    Cascade Mask R-CNN architecture, see https://arxiv.org/abs/1712.00726

    Three cascaded bbox stages refine the RPN proposals; the mask branch
    runs on the proposals of the last stage (training) or on the final
    detections (inference).

    Args:
        backbone (object): backbone instance
        rpn_head (object): `RPNhead` instance
        roi_extractor (object): ROI extractor instance
        bbox_head (object): `BBoxHead` instance
        bbox_assigner (object): `BBoxAssigner` instance; its
            `bbox_reg_weights` must hold exactly three values, one per
            cascade stage
        mask_assigner (object): `MaskAssigner` instance
        mask_head (object): `MaskHead` instance
        fpn (object): feature pyramid network instance, must not be None
        rpn_only (bool): when True, test mode stops after the RPN and
            returns only the rescaled proposals
    """

    __category__ = 'architecture'
    __inject__ = [
        'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
        'mask_assigner', 'mask_head', 'fpn'
    ]

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_extractor='FPNRoIAlign',
                 bbox_head='CascadeBBoxHead',
                 bbox_assigner='CascadeBBoxAssigner',
                 mask_assigner='MaskAssigner',
                 mask_head='MaskHead',
                 fpn='FPN',
                 rpn_only=False):
        super(CascadeMaskRCNN, self).__init__()
        assert fpn is not None, "cascade RCNN requires FPN"
        self.backbone = backbone
        self.fpn = fpn
        self.rpn_head = rpn_head
        self.bbox_assigner = bbox_assigner
        self.roi_extractor = roi_extractor
        self.bbox_head = bbox_head
        self.mask_assigner = mask_assigner
        self.mask_head = mask_head
        # `build` reads this flag on every non-train pass, so it must always
        # exist; it defaults to the full detection + mask pipeline.
        self.rpn_only = rpn_only
        # Cascade local cfg
        self.cls_agnostic_bbox_reg = 2
        # Unpacking enforces exactly three per-stage regression weights.
        (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights
        self.cascade_bbox_reg_weights = [
            [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0],
            [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1],
            [1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2]
        ]
        self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25]

    def build(self, feed_vars, mode='train'):
        """
        Build the network graph for the given mode.

        Args:
            feed_vars (dict): input variables keyed by field name. 'image'
                and 'im_info' are always read; train mode additionally
                requires 'gt_label', 'gt_box', 'gt_mask' and 'is_crowd',
                test mode additionally requires 'im_shape'.
            mode (str): either 'train' or 'test'.

        Returns:
            dict: in train mode, the individual loss terms plus their sum
            under the key 'loss'; in test mode, {'bbox': ..., 'mask': ...},
            or {'proposal': ...} when `rpn_only` is set.
        """
        im = feed_vars['image']
        assert mode in ['train', 'test'], \
            "only 'train' and 'test' mode is supported"

        if mode == 'train':
            required_fields = [
                'gt_label', 'gt_box', 'gt_mask', 'is_crowd', 'im_info'
            ]
        else:
            required_fields = ['im_shape', 'im_info']

        for var in required_fields:
            assert var in feed_vars, \
                "{} has no {} field".format(feed_vars, var)

        if mode == 'train':
            gt_box = feed_vars['gt_box']
            is_crowd = feed_vars['is_crowd']

        im_info = feed_vars['im_info']

        # backbone
        body_feats = self.backbone(im)

        # FPN
        if self.fpn is not None:
            body_feats, spatial_scale = self.fpn.get_output(body_feats)

        # rpn proposals
        rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)

        if mode == 'train':
            rpn_loss = self.rpn_head.get_loss(im_info, gt_box, is_crowd)
        elif self.rpn_only:
            # Divide the RPN proposals by the third column of im_info
            # (presumably the resize scale -- confirm against the reader)
            # and return them without running the cascade or mask heads.
            im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
            im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois)
            rois = rpn_rois / im_scale
            return {'proposal': rois}

        proposal_list = []
        roi_feat_list = []
        rcnn_pred_list = []
        rcnn_target_list = []

        proposals = None
        bbox_pred = None
        for i in range(3):
            if i > 0:
                # Later stages start from the boxes decoded out of the
                # previous stage's regression output.
                refined_bbox = self._decode_box(
                    proposals,
                    bbox_pred,
                    curr_stage=i - 1, )
            else:
                refined_bbox = rpn_rois

            if mode == 'train':
                outs = self.bbox_assigner(
                    input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i)

                proposals = outs[0]
                rcnn_target_list.append(outs)
            else:
                proposals = refined_bbox
            proposal_list.append(proposals)

            # extract roi features
            roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale)
            roi_feat_list.append(roi_feat)

            # bbox head
            stage_name = ('_' + str(i + 1)) if i > 0 else ''
            cls_score, bbox_pred = self.bbox_head.get_output(
                roi_feat,
                wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i],
                name=stage_name)
            rcnn_pred_list.append((cls_score, bbox_pred))

        if mode == 'train':
            return self._build_train_loss(
                feed_vars, body_feats, spatial_scale, proposal_list,
                rcnn_pred_list, rcnn_target_list, rpn_loss)

        # The original code also extracted RoI features for the last-stage
        # proposals here but never used them; that dead op is dropped.
        last_feat = None
        if self.fpn is None:
            last_feat = body_feats[list(body_feats.keys())[-1]]

        bbox_pred = self.bbox_head.get_prediction(
            im_info, roi_feat_list, rcnn_pred_list, proposal_list,
            self.cascade_bbox_reg_weights, self.cls_agnostic_bbox_reg)

        bbox_pred = bbox_pred['bbox']

        # share weight
        bbox_shape = fluid.layers.shape(bbox_pred)
        bbox_size = fluid.layers.reduce_prod(bbox_shape)
        bbox_size = fluid.layers.reshape(bbox_size, [1, 1])
        size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32')
        # True when bbox_pred holds fewer than 6 elements in total; in that
        # case the mask head is skipped and bbox_pred is copied into
        # mask_pred instead.
        cond = fluid.layers.less_than(x=bbox_size, y=size)

        mask_pred = fluid.layers.create_global_var(
            shape=[1], value=0.0, dtype='float32', persistable=False)

        with fluid.layers.control_flow.Switch() as switch:
            with switch.case(cond):
                fluid.layers.assign(input=bbox_pred, output=mask_pred)
            with switch.default():
                # Columns 2..5 of bbox_pred are taken as box coordinates.
                bbox = fluid.layers.slice(
                    bbox_pred, [1], starts=[2], ends=[6])

                im_scale = fluid.layers.slice(
                    im_info, [1], starts=[2], ends=[3])
                im_scale = fluid.layers.sequence_expand(im_scale, bbox)

                mask_rois = bbox * im_scale
                if self.fpn is None:
                    mask_feat = self.roi_extractor(last_feat, mask_rois)
                    mask_feat = self.bbox_head.get_head_feat(mask_feat)
                else:
                    mask_feat = self.roi_extractor(
                        body_feats, mask_rois, spatial_scale, is_mask=True)

                mask_out = self.mask_head.get_prediction(mask_feat, bbox)
                fluid.layers.assign(input=mask_out, output=mask_pred)
        return {'bbox': bbox_pred, 'mask': mask_pred}

    def _build_train_loss(self, feed_vars, body_feats, spatial_scale,
                          proposal_list, rcnn_pred_list, rcnn_target_list,
                          rpn_loss):
        """Assemble the bbox, RPN and mask losses and their total."""
        loss = self.bbox_head.get_loss(rcnn_pred_list, rcnn_target_list,
                                       self.cascade_rcnn_loss_weight)
        loss.update(rpn_loss)

        # The mask branch trains on the last cascade stage's proposals and
        # the second output of that stage's bbox assigner.
        rois = proposal_list[2]
        labels_int32 = rcnn_target_list[2][1]

        mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner(
            rois=rois,
            gt_classes=feed_vars['gt_label'],
            is_crowd=feed_vars['is_crowd'],
            gt_segms=feed_vars['gt_mask'],
            im_info=feed_vars['im_info'],
            labels_int32=labels_int32)

        if self.fpn is None:
            bbox_head_feat = self.bbox_head.get_head_feat()
            feat = fluid.layers.gather(bbox_head_feat, roi_has_mask_int32)
        else:
            feat = self.roi_extractor(
                body_feats, mask_rois, spatial_scale, is_mask=True)
        mask_loss = self.mask_head.get_loss(feat, mask_int32)
        loss.update(mask_loss)

        total_loss = fluid.layers.sum(list(loss.values()))
        loss.update({'loss': total_loss})
        return loss

    def _decode_box(self, proposals, bbox_pred, curr_stage):
        """
        Decode one stage's regression output into refined boxes.

        Args:
            proposals: boxes the deltas in `bbox_pred` are relative to.
            bbox_pred: regression output of stage `curr_stage`.
            curr_stage (int): index into `cascade_bbox_reg_weights`.

        Returns:
            The decoded boxes reshaped to [-1, 4].
        """
        rcnn_loc_delta_r = fluid.layers.reshape(
            bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4))
        # only use fg box delta to decode box
        rcnn_loc_delta_s = fluid.layers.slice(
            rcnn_loc_delta_r, axes=[1], starts=[1], ends=[2])
        refined_bbox = fluid.layers.box_coder(
            prior_box=proposals,
            prior_box_var=self.cascade_bbox_reg_weights[curr_stage],
            target_box=rcnn_loc_delta_s,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1, )
        refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4])

        return refined_bbox

    def train(self, feed_vars):
        """Build the training graph; returns the loss dict."""
        return self.build(feed_vars, 'train')

    def eval(self, feed_vars):
        """Build the evaluation graph (same as test mode)."""
        return self.build(feed_vars, 'test')

    def test(self, feed_vars):
        """Build the inference graph; returns the prediction dict."""
        return self.build(feed_vars, 'test')