From 926efec28ce2ee39119296fa5e0b53ef2ca424ca Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 9 Feb 2021 15:57:02 +0800 Subject: [PATCH] Update cascade (#2179) * update cascade, test=dygraph * clean code, test=dygraph * add cascade_rcnn, test=dygraph * update according to reviews, test=dygraph * support dy2static, test=dygraph --- .../_base_/cascade_fpn_reader.yml | 32 +-- .../_base_/cascade_mask_fpn_reader.yml | 14 +- .../_base_/cascade_mask_rcnn_r50_fpn.yml | 125 +++++------ .../_base_/cascade_rcnn_r50_fpn.yml | 94 +++----- .../cascade_rcnn/_base_/optimizer_1x.yml | 4 +- .../faster_rcnn/_base_/faster_rcnn_r50.yml | 4 +- .../_base_/faster_rcnn_r50_fpn.yml | 4 +- .../mask_rcnn/_base_/mask_fpn_reader.yml | 2 +- .../mask_rcnn/_base_/mask_rcnn_r50.yml | 4 +- .../mask_rcnn/_base_/mask_rcnn_r50_fpn.yml | 4 +- .../configs/mask_rcnn/_base_/mask_reader.yml | 2 +- dygraph/ppdet/data/transform/operator.py | 10 +- .../modeling/architectures/cascade_rcnn.py | 170 ++++++--------- dygraph/ppdet/modeling/backbones/resnet.py | 7 +- dygraph/ppdet/modeling/bbox_utils.py | 4 +- dygraph/ppdet/modeling/heads/__init__.py | 2 + dygraph/ppdet/modeling/heads/bbox_head.py | 52 +++-- dygraph/ppdet/modeling/heads/cascade_head.py | 205 ++++++++++++++++++ dygraph/ppdet/modeling/layers.py | 13 +- dygraph/ppdet/modeling/ops.py | 4 +- dygraph/ppdet/modeling/post_process.py | 4 + .../modeling/proposal_generator/rpn_head.py | 1 + .../modeling/proposal_generator/target.py | 70 +++--- .../proposal_generator/target_layer.py | 20 +- dygraph/ppdet/utils/checkpoint.py | 2 +- 25 files changed, 490 insertions(+), 363 deletions(-) create mode 100644 dygraph/ppdet/modeling/heads/cascade_head.py diff --git a/dygraph/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml b/dygraph/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml index 8a8fb1ebb..5e380eb76 100644 --- a/dygraph/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml +++ b/dygraph/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml @@ -1,13 +1,13 @@ worker_num: 2 TrainReader: sample_transforms: - - DecodeOp: { } - - RandomFlipImage: {prob: 0.5} - - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true} - - Permute: {to_bgr: false, channel_first: true} + - DecodeOp: {} + - RandomResizeOp: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True} + - RandomFlipOp: {prob: 0.5} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - PermuteOp: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true} + - PadBatchOp: {pad_to_stride: 32, pad_gt: true} batch_size: 1 shuffle: true drop_last: true @@ -15,12 +15,12 @@ TrainReader: EvalReader: sample_transforms: - - DecodeOp: { } - - NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] } - - ResizeOp: { interp: 1, target_size: [ 800, 1333 ], keep_ratio: True } - - PermuteOp: { } + - DecodeOp: {} + - ResizeOp: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - PermuteOp: {} batch_transforms: - - PadBatchOp: { pad_to_stride: 32, pad_gt: false } + - PadBatchOp: {pad_to_stride: 32, pad_gt: false} batch_size: 1 shuffle: false drop_last: false @@ -29,12 +29,12 @@ EvalReader: TestReader: sample_transforms: - - DecodeOp: { } - 
- NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] } - - ResizeOp: { interp: 1, target_size: [ 800, 1333 ], keep_ratio: True } - - PermuteOp: { } + - DecodeOp: {} + - ResizeOp: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - PermuteOp: {} batch_transforms: - - PadBatchOp: { pad_to_stride: 32, pad_gt: false } + - PadBatchOp: {pad_to_stride: 32, pad_gt: false} batch_size: 1 shuffle: false drop_last: false diff --git a/dygraph/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml b/dygraph/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml index 1d73c7f31..5e380eb76 100644 --- a/dygraph/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml +++ b/dygraph/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml @@ -2,12 +2,12 @@ worker_num: 2 TrainReader: sample_transforms: - DecodeOp: {} - - RandomFlipImage: {prob: 0.5, is_mask_flip: true} - - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true} - - Permute: {to_bgr: false, channel_first: true} + - RandomResizeOp: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True} + - RandomFlipOp: {prob: 0.5} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - PermuteOp: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: true} + - PadBatchOp: {pad_to_stride: 32, pad_gt: true} batch_size: 1 shuffle: true drop_last: true @@ -16,8 +16,8 @@ TrainReader: EvalReader: sample_transforms: - DecodeOp: {} + - ResizeOp: {interp: 2, target_size: [800, 1333], keep_ratio: True} - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True} - PermuteOp: {} batch_transforms: - PadBatchOp: {pad_to_stride: 32, pad_gt: false} @@ -30,8 +30,8 @@ EvalReader: TestReader: sample_transforms: - DecodeOp: {} + - ResizeOp: {interp: 2, target_size: [800, 1333], keep_ratio: True} - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - ResizeOp: {interp: 1, target_size: [800, 1333], keep_ratio: True} - PermuteOp: {} batch_transforms: - PadBatchOp: {pad_to_stride: 32, pad_gt: false} diff --git a/dygraph/configs/cascade_rcnn/_base_/cascade_mask_rcnn_r50_fpn.yml b/dygraph/configs/cascade_rcnn/_base_/cascade_mask_rcnn_r50_fpn.yml index 7c45b6a6b..f13370362 100644 --- a/dygraph/configs/cascade_rcnn/_base_/cascade_mask_rcnn_r50_fpn.yml +++ b/dygraph/configs/cascade_rcnn/_base_/cascade_mask_rcnn_r50_fpn.yml @@ -1,19 +1,13 @@ architecture: CascadeRCNN pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar load_static_weights: True -roi_stages: 3 -# Model Achitecture + CascadeRCNN: - # model anchor info flow - anchor: Anchor - proposal: Proposal - mask: Mask - # model feat info flow backbone: ResNet neck: FPN rpn_head: RPNHead - bbox_head: BBoxHead + bbox_head: CascadeHead mask_head: MaskHead # post process bbox_post_process: BBoxPostProcess @@ -28,97 +22,78 @@ ResNet: num_stages: 4 FPN: - in_channels: [256, 512, 1024, 2048] out_channel: 256 - min_level: 0 - max_level: 4 - spatial_scale: [0.25, 0.125, 0.0625, 0.03125] RPNHead: - rpn_feat: - name: RPNFeat - feat_in: 256 - feat_out: 
256 - anchor_per_position: 3 - rpn_channel: 256 - -Anchor: anchor_generator: - name: AnchorGeneratorRPN aspect_ratios: [0.5, 1.0, 2.0] - anchor_start_size: 32 - stride: [4., 4.] - anchor_target_generator: - name: AnchorTargetGeneratorRPN + anchor_sizes: [[32], [64], [128], [256], [512]] + strides: [4, 8, 16, 32, 64] + rpn_target_assign: batch_size_per_im: 256 fg_fraction: 0.5 negative_overlap: 0.3 positive_overlap: 0.7 - straddle_thresh: 0.0 - -Proposal: - proposal_generator: - name: ProposalGenerator + use_random: True + train_proposal: min_size: 0.0 nms_thresh: 0.7 - train_pre_nms_top_n: 2000 - train_post_nms_top_n: 2000 - infer_pre_nms_top_n: 1000 - infer_post_nms_top_n: 1000 - proposal_target_generator: - name: ProposalTargetGenerator - batch_size_per_im: 512 - bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] - bg_thresh_hi: [0.5, 0.6, 0.7] - bg_thresh_lo: [0.0, 0.0, 0.0] - fg_thresh: [0.5, 0.6, 0.7] - fg_fraction: 0.25 - is_cls_agnostic: true + pre_nms_top_n: 2000 + post_nms_top_n: 2000 + topk_after_collect: True + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + -BBoxHead: - bbox_feat: - name: BBoxFeat - roi_extractor: - name: RoIAlign - resolution: 7 - sampling_ratio: 2 - head_feat: - name: TwoFCHead - in_dim: 256 - mlp_dim: 1024 - in_feat: 1024 - cls_agnostic: true +CascadeHead: + head: CascadeTwoFCHead + roi_extractor: + resolution: 7 + sampling_ratio: 0 + aligned: True + bbox_assigner: BBoxAssigner + +BBoxAssigner: + batch_size_per_im: 512 + bg_thresh: 0.5 + fg_thresh: 0.5 + fg_fraction: 0.25 + cascade_iou: [0.5, 0.6, 0.7] + use_random: True + +CascadeTwoFCHead: + mlp_dim: 1024 BBoxPostProcess: decode: name: RCNNBox - num_classes: 81 - batch_size: 1 - var_weight: 3. + prior_box_var: [30.0, 30.0, 15.0, 15.0] nms: name: MultiClassNMS keep_top_k: 100 score_threshold: 0.05 nms_threshold: 0.5 + normalized: true -Mask: - mask_target_generator: - name: MaskTargetGenerator - mask_resolution: 28 MaskHead: - mask_feat: - name: MaskFeat - num_convs: 4 - feat_in: 256 - feat_out: 256 - mask_roi_extractor: - name: RoIAlign - resolution: 14 - sampling_ratio: 2 - share_bbox_feat: False - feat_in: 256 + head: MaskFeat + roi_extractor: + resolution: 14 + sampling_ratio: 0 + aligned: True + mask_assigner: MaskAssigner + share_bbox_feat: False +MaskFeat: + num_convs: 4 + out_channels: 256 -MaskPostProcess: +MaskAssigner: mask_resolution: 28 + +MaskPostProcess: + binary_thresh: 0.5 diff --git a/dygraph/configs/cascade_rcnn/_base_/cascade_rcnn_r50_fpn.yml b/dygraph/configs/cascade_rcnn/_base_/cascade_rcnn_r50_fpn.yml index 26d867c31..155f5f500 100644 --- a/dygraph/configs/cascade_rcnn/_base_/cascade_rcnn_r50_fpn.yml +++ b/dygraph/configs/cascade_rcnn/_base_/cascade_rcnn_r50_fpn.yml @@ -1,18 +1,13 @@ architecture: CascadeRCNN pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar load_static_weights: True -roi_stages: 3 -# Model Achitecture + CascadeRCNN: - # model anchor info flow - anchor: Anchor - proposal: Proposal - # model feat info flow backbone: ResNet neck: FPN rpn_head: RPNHead - bbox_head: BBoxHead + bbox_head: CascadeHead # post process bbox_post_process: BBoxPostProcess @@ -25,75 +20,58 @@ ResNet: num_stages: 4 FPN: - in_channels: [256, 512, 1024, 2048] out_channel: 256 - min_level: 0 - max_level: 4 - spatial_scale: [0.25, 0.125, 0.0625, 0.03125] RPNHead: - rpn_feat: - name: RPNFeat - feat_in: 256 - feat_out: 256 - anchor_per_position: 3 - rpn_channel: 256 - -Anchor: anchor_generator: - name: AnchorGeneratorRPN 
aspect_ratios: [0.5, 1.0, 2.0] - anchor_start_size: 32 - stride: [4., 4.] - anchor_target_generator: - name: AnchorTargetGeneratorRPN + anchor_sizes: [[32], [64], [128], [256], [512]] + strides: [4, 8, 16, 32, 64] + rpn_target_assign: batch_size_per_im: 256 fg_fraction: 0.5 negative_overlap: 0.3 positive_overlap: 0.7 - straddle_thresh: 0.0 - -Proposal: - proposal_generator: - name: ProposalGenerator + use_random: True + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 2000 + post_nms_top_n: 2000 + topk_after_collect: True + test_proposal: min_size: 0.0 nms_thresh: 0.7 - train_pre_nms_top_n: 2000 - train_post_nms_top_n: 2000 - infer_pre_nms_top_n: 1000 - infer_post_nms_top_n: 1000 - proposal_target_generator: - name: ProposalTargetGenerator - batch_size_per_im: 512 - bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] - bg_thresh_hi: [0.5, 0.6, 0.7] - bg_thresh_lo: [0.0, 0.0, 0.0] - fg_thresh: [0.5, 0.6, 0.7] - fg_fraction: 0.25 - is_cls_agnostic: true + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + + +CascadeHead: + head: CascadeTwoFCHead + roi_extractor: + resolution: 7 + sampling_ratio: 0 + aligned: True + bbox_assigner: BBoxAssigner + +BBoxAssigner: + batch_size_per_im: 512 + bg_thresh: 0.5 + fg_thresh: 0.5 + fg_fraction: 0.25 + cascade_iou: [0.5, 0.6, 0.7] + use_random: True -BBoxHead: - bbox_feat: - name: BBoxFeat - roi_extractor: - name: RoIAlign - resolution: 7 - sampling_ratio: 2 - head_feat: - name: TwoFCHead - in_dim: 256 - mlp_dim: 1024 - in_feat: 1024 - cls_agnostic: true +CascadeTwoFCHead: + mlp_dim: 1024 BBoxPostProcess: decode: name: RCNNBox - num_classes: 81 - batch_size: 1 - var_weight: 3. + prior_box_var: [30.0, 30.0, 15.0, 15.0] nms: name: MultiClassNMS keep_top_k: 100 score_threshold: 0.05 nms_threshold: 0.5 + normalized: true diff --git a/dygraph/configs/cascade_rcnn/_base_/optimizer_1x.yml b/dygraph/configs/cascade_rcnn/_base_/optimizer_1x.yml index d28b0947b..63f898e9c 100644 --- a/dygraph/configs/cascade_rcnn/_base_/optimizer_1x.yml +++ b/dygraph/configs/cascade_rcnn/_base_/optimizer_1x.yml @@ -7,8 +7,8 @@ LearningRate: gamma: 0.1 milestones: [8, 11] - !LinearWarmup - start_factor: 0.3333333333333333 - steps: 500 + start_factor: 0.001 + steps: 1000 OptimizerBuilder: optimizer: diff --git a/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50.yml b/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50.yml index 5ee7bffb4..3719f6d02 100644 --- a/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50.yml +++ b/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50.yml @@ -53,8 +53,8 @@ BBoxHead: BBoxAssigner: batch_size_per_im: 512 - bg_thresh: [0.5,] - fg_thresh: [0.5,] + bg_thresh: 0.5 + fg_thresh: 0.5 fg_fraction: 0.25 use_random: True diff --git a/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50_fpn.yml b/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50_fpn.yml index 881485798..f846e544f 100644 --- a/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50_fpn.yml +++ b/dygraph/configs/faster_rcnn/_base_/faster_rcnn_r50_fpn.yml @@ -56,8 +56,8 @@ BBoxHead: BBoxAssigner: batch_size_per_im: 512 - bg_thresh: [0.5,] - fg_thresh: [0.5,] + bg_thresh: 0.5 + fg_thresh: 0.5 fg_fraction: 0.25 use_random: True diff --git a/dygraph/configs/mask_rcnn/_base_/mask_fpn_reader.yml b/dygraph/configs/mask_rcnn/_base_/mask_fpn_reader.yml index a3d663558..6e830449d 100644 --- a/dygraph/configs/mask_rcnn/_base_/mask_fpn_reader.yml +++ b/dygraph/configs/mask_rcnn/_base_/mask_fpn_reader.yml @@ -3,7 +3,7 @@ TrainReader: sample_transforms: - DecodeOp: {} - RandomResizeOp: {target_size: [[640, 1333], 
[672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True} - - RandomFlipOp: {prob: 0.5, is_mask_flip: true} + - RandomFlipOp: {prob: 0.5} - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - PermuteOp: {} batch_transforms: diff --git a/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50.yml b/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50.yml index 75191261d..9a6942c9e 100644 --- a/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50.yml +++ b/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50.yml @@ -54,8 +54,8 @@ BBoxHead: BBoxAssigner: batch_size_per_im: 512 - bg_thresh: [0.5,] - fg_thresh: [0.5,] + bg_thresh: 0.5 + fg_thresh: 0.5 fg_fraction: 0.25 use_random: True diff --git a/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50_fpn.yml b/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50_fpn.yml index 650184f21..1be1f22fe 100644 --- a/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50_fpn.yml +++ b/dygraph/configs/mask_rcnn/_base_/mask_rcnn_r50_fpn.yml @@ -56,8 +56,8 @@ BBoxHead: BBoxAssigner: batch_size_per_im: 512 - bg_thresh: [0.5,] - fg_thresh: [0.5,] + bg_thresh: 0.5 + fg_thresh: 0.5 fg_fraction: 0.25 use_random: True diff --git a/dygraph/configs/mask_rcnn/_base_/mask_reader.yml b/dygraph/configs/mask_rcnn/_base_/mask_reader.yml index 7da1f9af9..48fef1456 100644 --- a/dygraph/configs/mask_rcnn/_base_/mask_reader.yml +++ b/dygraph/configs/mask_rcnn/_base_/mask_reader.yml @@ -3,7 +3,7 @@ TrainReader: sample_transforms: - DecodeOp: {} - RandomResizeOp: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True} - - RandomFlipOp: {prob: 0.5, is_mask_flip: true} + - RandomFlipOp: {prob: 0.5} - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - PermuteOp: {} batch_transforms: diff --git a/dygraph/ppdet/data/transform/operator.py b/dygraph/ppdet/data/transform/operator.py index 90b69dccb..bdb38ff64 100644 --- a/dygraph/ppdet/data/transform/operator.py +++ b/dygraph/ppdet/data/transform/operator.py @@ -484,17 +484,14 @@ class AutoAugmentOp(BaseOperator): @register_op class RandomFlipOp(BaseOperator): - def __init__(self, prob=0.5, is_mask_flip=False): + def __init__(self, prob=0.5): """ Args: prob (float): the probability of flipping image - is_mask_flip (bool): whether flip the segmentation """ super(RandomFlipOp, self).__init__() self.prob = prob - self.is_mask_flip = is_mask_flip - if not (isinstance(self.prob, float) and - isinstance(self.is_mask_flip, bool)): + if not (isinstance(self.prob, float)): raise TypeError("{}: input type is invalid.".format(self)) def apply_segm(self, segms, height, width): @@ -557,8 +554,7 @@ class RandomFlipOp(BaseOperator): im = self.apply_image(im) if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], width) - if self.is_mask_flip and 'gt_poly' in sample and len(sample[ - 'gt_poly']) > 0: + if 'gt_poly' in sample and len(sample['gt_poly']) > 0: sample['gt_poly'] = self.apply_segm(sample['gt_poly'], height, width) if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0: diff --git a/dygraph/ppdet/modeling/architectures/cascade_rcnn.py b/dygraph/ppdet/modeling/architectures/cascade_rcnn.py index 7b9e93484..987d7a77a 100644 --- a/dygraph/ppdet/modeling/architectures/cascade_rcnn.py +++ b/dygraph/ppdet/modeling/architectures/cascade_rcnn.py @@ -17,7 +17,7 @@ from __future__ import division from __future__ import print_function import paddle -from 
ppdet.core.workspace import register +from ppdet.core.workspace import register, create from .meta_arch import BaseArch __all__ = ['CascadeRCNN'] @@ -26,142 +26,106 @@ __all__ = ['CascadeRCNN'] @register class CascadeRCNN(BaseArch): __category__ = 'architecture' - __shared__ = ['roi_stages'] __inject__ = [ - 'anchor', - 'proposal', - 'mask', - 'backbone', - 'neck', - 'rpn_head', - 'bbox_head', - 'mask_head', 'bbox_post_process', 'mask_post_process', ] def __init__(self, - anchor, - proposal, backbone, rpn_head, bbox_head, bbox_post_process, neck=None, - mask=None, mask_head=None, - mask_post_process=None, - roi_stages=3): + mask_post_process=None): super(CascadeRCNN, self).__init__() - self.anchor = anchor - self.proposal = proposal self.backbone = backbone self.rpn_head = rpn_head self.bbox_head = bbox_head self.bbox_post_process = bbox_post_process self.neck = neck - self.mask = mask self.mask_head = mask_head self.mask_post_process = mask_post_process - self.roi_stages = roi_stages - self.with_mask = mask is not None + self.with_mask = mask_head is not None + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + kwargs = {'input_shape': backbone.out_shape} + neck = cfg['neck'] and create(cfg['neck'], **kwargs) + + out_shape = neck and neck.out_shape or backbone.out_shape + kwargs = {'input_shape': out_shape} + rpn_head = create(cfg['rpn_head'], **kwargs) + bbox_head = create(cfg['bbox_head'], **kwargs) + + out_shape = neck and out_shape or bbox_head.get_head().out_shape + kwargs = {'input_shape': out_shape} + mask_head = cfg['mask_head'] and create(cfg['mask_head'], **kwargs) + return { + 'backbone': backbone, + 'neck': neck, + "rpn_head": rpn_head, + "bbox_head": bbox_head, + "mask_head": mask_head, + } - def model_arch(self, ): - # Backbone + def _forward(self): body_feats = self.backbone(self.inputs) - - # Neck if self.neck is not None: - body_feats, spatial_scale = self.neck(body_feats) - - # RPN - # rpn_head returns two list: rpn_feat, rpn_head_out - # each element in rpn_feats contains rpn feature on each level, - # and the length is 1 when the neck is not applied. 
- # each element in rpn_head_out contains (rpn_rois_score, rpn_rois_delta) - rpn_feat, self.rpn_head_out = self.rpn_head(self.inputs, body_feats) - - # Anchor - # anchor_out returns a list, - # each element contains (anchor, anchor_var) - self.anchor_out = self.anchor(rpn_feat) - - # Proposal RoI - # compute targets here when training - rois = None - bbox_head_out = None - max_overlap = None - self.bbox_head_list = [] - rois_list = [] - for i in range(self.roi_stages): - # Proposal BBox - rois = self.proposal( - self.inputs, - self.rpn_head_out, - self.anchor_out, - self.training, - i, - rois, - bbox_head_out, - max_overlap=max_overlap) - rois_list.append(rois) - max_overlap = self.proposal.get_max_overlap() - # BBox Head - bbox_feat, bbox_head_out, _ = self.bbox_head(body_feats, rois, - spatial_scale, i) - self.bbox_head_list.append(bbox_head_out) - - if not self.training: - bbox_pred, bboxes = self.bbox_head.get_cascade_prediction( - self.bbox_head_list, rois_list) - self.bboxes = self.bbox_post_process(bbox_pred, bboxes, - self.inputs['im_shape'], - self.inputs['scale_factor']) - - if self.with_mask: - rois = rois_list[-1] - rois_has_mask_int32 = None - if self.training: - bbox_targets = self.proposal.get_targets()[-1] - self.bboxes, rois_has_mask_int32 = self.mask(self.inputs, rois, - bbox_targets) - # Mask Head - self.mask_head_out = self.mask_head( - self.inputs, body_feats, self.bboxes, bbox_feat, - rois_has_mask_int32, spatial_scale) + body_feats = self.neck(body_feats) + + if self.training: + rois, rois_num, rpn_loss = self.rpn_head(body_feats, self.inputs) + bbox_loss, bbox_feat = self.bbox_head(body_feats, rois, rois_num, + self.inputs) + rois, rois_num = self.bbox_head.get_assigned_rois() + bbox_targets = self.bbox_head.get_assigned_targets() + if self.with_mask: + mask_loss = self.mask_head(body_feats, rois, rois_num, + self.inputs, bbox_targets, bbox_feat) + return rpn_loss, bbox_loss, mask_loss + else: + return rpn_loss, bbox_loss, {} + else: + rois, rois_num, _ = self.rpn_head(body_feats, self.inputs) + preds, _ = self.bbox_head(body_feats, rois, rois_num, self.inputs) + refined_rois = self.bbox_head.get_refined_rois() + + im_shape = self.inputs['im_shape'] + scale_factor = self.inputs['scale_factor'] + + bbox, bbox_num = self.bbox_post_process( + preds, (refined_rois, rois_num), im_shape, scale_factor) + # rescale the prediction back to origin image + bbox_pred = self.bbox_post_process.get_pred(bbox, bbox_num, + im_shape, scale_factor) + if not self.with_mask: + return bbox_pred, bbox_num, None + mask_out = self.mask_head(body_feats, bbox, bbox_num, self.inputs) + origin_shape = self.bbox_post_process.get_origin_shape() + mask_pred = self.mask_post_process(mask_out[:, 0, :, :], bbox_pred, + bbox_num, origin_shape) + return bbox_pred, bbox_num, mask_pred def get_loss(self, ): + rpn_loss, bbox_loss, mask_loss = self._forward() loss = {} - - # RPN loss - rpn_loss_inputs = self.anchor.generate_loss_inputs( - self.inputs, self.rpn_head_out, self.anchor_out) - loss_rpn = self.rpn_head.get_loss(rpn_loss_inputs) - loss.update(loss_rpn) - - # BBox loss - bbox_targets_list = self.proposal.get_targets() - loss_bbox = self.bbox_head.get_loss(self.bbox_head_list, - bbox_targets_list) - loss.update(loss_bbox) - + loss.update(rpn_loss) + loss.update(bbox_loss) if self.with_mask: - # Mask loss - mask_targets = self.mask.get_targets() - loss_mask = self.mask_head.get_loss(self.mask_head_out, - mask_targets) - loss.update(loss_mask) - + loss.update(mask_loss) total_loss = 
paddle.add_n(list(loss.values())) loss.update({'loss': total_loss}) return loss def get_pred(self): - bbox, bbox_num = self.bboxes + bbox_pred, bbox_num, mask_pred = self._forward() output = { - 'bbox': bbox, + 'bbox': bbox_pred, 'bbox_num': bbox_num, } if self.with_mask: - output.update({'mask': self.mask_head_out}) + output.update({'mask': mask_pred}) return output diff --git a/dygraph/ppdet/modeling/backbones/resnet.py b/dygraph/ppdet/modeling/backbones/resnet.py index 9a9e2181b..86a3c9b96 100755 --- a/dygraph/ppdet/modeling/backbones/resnet.py +++ b/dygraph/ppdet/modeling/backbones/resnet.py @@ -547,11 +547,8 @@ class Res5Head(nn.Layer): if depth < 50: feat_in = 256 na = NameAdapter(self) - self.res5 = self.add_sublayer( - 'res5_roi_feat', - Blocks( - depth, feat_in, feat_out, count=3, name_adapter=na, - stage_num=5)) + self.res5 = Blocks( + depth, feat_in, feat_out, count=3, name_adapter=na, stage_num=5) self.feat_out = feat_out if depth < 50 else feat_out * 4 @property diff --git a/dygraph/ppdet/modeling/bbox_utils.py b/dygraph/ppdet/modeling/bbox_utils.py index 505cea8e5..a8566140d 100644 --- a/dygraph/ppdet/modeling/bbox_utils.py +++ b/dygraph/ppdet/modeling/bbox_utils.py @@ -64,7 +64,7 @@ def delta2bbox(deltas, boxes, weights): pred_boxes.append(pred_ctr_y - 0.5 * pred_h) pred_boxes.append(pred_ctr_x + 0.5 * pred_w) pred_boxes.append(pred_ctr_y + 0.5 * pred_h) - pred_boxes = paddle.stack(pred_boxes, axis=-1) + pred_boxes = paddle.concat(pred_boxes, axis=-1) return pred_boxes @@ -88,7 +88,7 @@ def expand_bbox(bboxes, scale): def clip_bbox(boxes, im_shape): - h, w = im_shape + h, w = im_shape[0], im_shape[1] x1 = boxes[:, 0].clip(0, w) y1 = boxes[:, 1].clip(0, h) x2 = boxes[:, 2].clip(0, w) diff --git a/dygraph/ppdet/modeling/heads/__init__.py b/dygraph/ppdet/modeling/heads/__init__.py index 14b587700..d9e41bd81 100644 --- a/dygraph/ppdet/modeling/heads/__init__.py +++ b/dygraph/ppdet/modeling/heads/__init__.py @@ -20,6 +20,7 @@ from . import ssd_head from . import fcos_head from . import solov2_head from . import ttf_head +from . import cascade_head from .bbox_head import * from .mask_head import * @@ -29,3 +30,4 @@ from .ssd_head import * from .fcos_head import * from .solov2_head import * from .ttf_head import * +from .cascade_head import * diff --git a/dygraph/ppdet/modeling/heads/bbox_head.py b/dygraph/ppdet/modeling/heads/bbox_head.py index f897d799d..3ad260e3e 100644 --- a/dygraph/ppdet/modeling/heads/bbox_head.py +++ b/dygraph/ppdet/modeling/heads/bbox_head.py @@ -33,19 +33,16 @@ class TwoFCHead(nn.Layer): self.in_dim = in_dim self.mlp_dim = mlp_dim fan = in_dim * resolution * resolution - lr_factor = 1. self.fc6 = nn.Linear( in_dim * resolution * resolution, mlp_dim, weight_attr=paddle.ParamAttr( - learning_rate=lr_factor, initializer=XavierUniform(fan_out=fan))) self.fc7 = nn.Linear( mlp_dim, mlp_dim, - weight_attr=paddle.ParamAttr( - learning_rate=lr_factor, initializer=XavierUniform())) + weight_attr=paddle.ParamAttr(initializer=XavierUniform())) @classmethod def from_config(cls, cfg, input_shape): @@ -73,6 +70,12 @@ class BBoxHead(nn.Layer): """ head (nn.Layer): Extract feature in bbox head in_channel (int): Input channel after RoI extractor + roi_extractor (object): The module of RoI Extractor + bbox_assigner (object): The module of Box Assigner, label and sample the + box. + with_pool (bool): Whether to use pooling for the RoI feature. 
+ num_classes (int): The number of classes + bbox_weight (List[float]): The weight to get the decode box """ def __init__(self, @@ -94,21 +97,17 @@ class BBoxHead(nn.Layer): self.num_classes = num_classes self.bbox_weight = bbox_weight - lr_factor = 1. self.bbox_score = nn.Linear( in_channel, self.num_classes + 1, - weight_attr=paddle.ParamAttr( - learning_rate=lr_factor, initializer=Normal( - mean=0.0, std=0.01))) + weight_attr=paddle.ParamAttr(initializer=Normal( + mean=0.0, std=0.01))) self.bbox_delta = nn.Linear( in_channel, 4 * self.num_classes, - weight_attr=paddle.ParamAttr( - learning_rate=lr_factor, - initializer=Normal( - mean=0.0, std=0.001))) + weight_attr=paddle.ParamAttr(initializer=Normal( + mean=0.0, std=0.001))) self.assigned_label = None self.assigned_rois = None @@ -128,14 +127,13 @@ class BBoxHead(nn.Layer): def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None): """ - body_feats (list[Tensor]): - rois (Tensor): - rois_num (Tensor): - inputs (dict{Tensor}): + body_feats (list[Tensor]): Feature maps from backbone + rois (Tensor): RoIs generated from RPN module + rois_num (Tensor): The number of RoIs in each image + inputs (dict{Tensor}): The ground-truth of image """ if self.training: - rois, rois_num, _, targets = self.bbox_assigner(rois, rois_num, - inputs) + rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs) self.assigned_rois = (rois, rois_num) self.assigned_targets = targets @@ -150,13 +148,14 @@ class BBoxHead(nn.Layer): deltas = self.bbox_delta(feat) if self.training: - loss = self.get_loss(scores, deltas, targets, rois) + loss = self.get_loss(scores, deltas, targets, rois, + self.bbox_weight) return loss, bbox_feat else: pred = self.get_prediction(scores, deltas) return pred, self.head - def get_loss(self, scores, deltas, targets, rois): + def get_loss(self, scores, deltas, targets, rois, bbox_weight): """ scores (Tensor): scores from bbox head outputs deltas (Tensor): deltas from bbox head outputs @@ -179,6 +178,14 @@ class BBoxHead(nn.Layer): paddle.logical_and(tgt_labels >= 0, tgt_labels < self.num_classes)).flatten() + cls_name = 'loss_bbox_cls' + reg_name = 'loss_bbox_reg' + loss_bbox = {} + + if fg_inds.numel() == 0: + loss_bbox[cls_name] = paddle.to_tensor(0., dtype='float32') + loss_bbox[reg_name] = paddle.to_tensor(0., dtype='float32') + return loss_bbox if cls_agnostic_bbox_reg: reg_delta = paddle.gather(deltas, fg_inds) else: @@ -198,16 +205,13 @@ class BBoxHead(nn.Layer): tgt_bboxes = paddle.concat(tgt_bboxes) if len( tgt_bboxes) > 1 else tgt_bboxes[0] - reg_target = bbox2delta(rois, tgt_bboxes, self.bbox_weight) + reg_target = bbox2delta(rois, tgt_bboxes, bbox_weight) reg_target = paddle.gather(reg_target, fg_inds) reg_target.stop_gradient = True loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum( ) / tgt_labels.shape[0] - cls_name = 'loss_bbox_cls' - reg_name = 'loss_bbox_reg' - loss_bbox = {} loss_bbox[cls_name] = loss_bbox_cls loss_bbox[reg_name] = loss_bbox_reg diff --git a/dygraph/ppdet/modeling/heads/cascade_head.py b/dygraph/ppdet/modeling/heads/cascade_head.py new file mode 100644 index 000000000..2ef88a3dc --- /dev/null +++ b/dygraph/ppdet/modeling/heads/cascade_head.py @@ -0,0 +1,205 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import Normal, XavierUniform +from paddle.regularizer import L2Decay + +from ppdet.core.workspace import register, create +from ppdet.modeling import ops + +from .bbox_head import BBoxHead, TwoFCHead +from .roi_extractor import RoIAlign +from ..shape_spec import ShapeSpec +from ..bbox_utils import bbox2delta, delta2bbox, clip_bbox, nonempty_bbox + + +@register +class CascadeTwoFCHead(nn.Layer): + __shared__ = ['num_cascade_stage'] + + def __init__(self, + in_dim=256, + mlp_dim=1024, + resolution=7, + num_cascade_stage=3): + super(CascadeTwoFCHead, self).__init__() + + self.in_dim = in_dim + self.mlp_dim = mlp_dim + + self.head_list = [] + for stage in range(num_cascade_stage): + head_per_stage = self.add_sublayer( + str(stage), TwoFCHead(in_dim, mlp_dim, resolution)) + self.head_list.append(head_per_stage) + + @classmethod + def from_config(cls, cfg, input_shape): + s = input_shape + s = s[0] if isinstance(s, (list, tuple)) else s + return {'in_dim': s.channels} + + @property + def out_shape(self): + return [ShapeSpec(channels=self.mlp_dim, )] + + def forward(self, rois_feat, stage=0): + out = self.head_list[stage](rois_feat) + return out + + +@register +class CascadeHead(BBoxHead): + __shared__ = ['num_classes', 'num_cascade_stages'] + __inject__ = ['bbox_assigner'] + """ + head (nn.Layer): Extract feature in bbox head + in_channel (int): Input channel after RoI extractor + roi_extractor (object): The module of RoI Extractor + bbox_assigner (object): The module of Box Assigner, label and sample the + box. 
+ num_classes (int): The number of classes + bbox_weight (List[List[float]]): The weight to get the decode box and the + length of weight is the number of cascade + stage + num_cascade_stages (int): THe number of stage to refine the box + """ + + def __init__(self, + head, + in_channel, + roi_extractor=RoIAlign().__dict__, + bbox_assigner='BboxAssigner', + num_classes=80, + bbox_weight=[[10., 10., 5., 5.], [20.0, 20.0, 10.0, 10.0], + [30.0, 30.0, 15.0, 15.0]], + num_cascade_stages=3): + nn.Layer.__init__(self, ) + self.head = head + self.roi_extractor = roi_extractor + if isinstance(roi_extractor, dict): + self.roi_extractor = RoIAlign(**roi_extractor) + self.bbox_assigner = bbox_assigner + + self.num_classes = num_classes + self.bbox_weight = bbox_weight + self.num_cascade_stages = num_cascade_stages + + self.bbox_score_list = [] + self.bbox_delta_list = [] + for i in range(num_cascade_stages): + score_name = 'bbox_score_stage{}'.format(i) + delta_name = 'bbox_delta_stage{}'.format(i) + bbox_score = self.add_sublayer( + score_name, + nn.Linear( + in_channel, + self.num_classes + 1, + weight_attr=paddle.ParamAttr(initializer=Normal( + mean=0.0, std=0.01)))) + + bbox_delta = self.add_sublayer( + delta_name, + nn.Linear( + in_channel, + 4, + weight_attr=paddle.ParamAttr(initializer=Normal( + mean=0.0, std=0.001)))) + self.bbox_score_list.append(bbox_score) + self.bbox_delta_list.append(bbox_delta) + self.assigned_label = None + self.assigned_rois = None + + def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None): + """ + body_feats (list[Tensor]): Feature maps from backbone + rois (Tensor): RoIs generated from RPN module + rois_num (Tensor): The number of RoIs in each image + inputs (dict{Tensor}): The ground-truth of image + """ + targets = [] + if self.training: + rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs) + targets_list = [targets] + self.assigned_rois = (rois, rois_num) + self.assigned_targets = targets + + pred_bbox = None + head_out_list = [] + for i in range(self.num_cascade_stages): + if i > 0: + rois, rois_num = self._get_rois_from_boxes(pred_bbox, + inputs['im_shape']) + if self.training: + rois, rois_num, targets = self.bbox_assigner( + rois, rois_num, inputs, i, is_cascade=True) + targets_list.append(targets) + + rois_feat = self.roi_extractor(body_feats, rois, rois_num) + bbox_feat = self.head(rois_feat, i) + scores = self.bbox_score_list[i](bbox_feat) + deltas = self.bbox_delta_list[i](bbox_feat) + head_out_list.append([scores, deltas, rois]) + pred_bbox = self._get_pred_bbox(deltas, rois, self.bbox_weight[i]) + + if self.training: + loss = {} + for stage, value in enumerate(zip(head_out_list, targets_list)): + (scores, deltas, rois), targets = value + loss_stage = self.get_loss(scores, deltas, targets, rois, + self.bbox_weight[stage]) + for k, v in loss_stage.items(): + loss[k + "_stage{}".format( + stage)] = v / self.num_cascade_stages + + return loss, bbox_feat + else: + scores, deltas, self.refined_rois = self.get_prediction( + head_out_list) + return (deltas, scores), self.head + + def _get_rois_from_boxes(self, boxes, im_shape): + rois = [] + for i, boxes_per_image in enumerate(boxes): + clip_box = clip_bbox(boxes_per_image, im_shape[i]) + if self.training: + keep = nonempty_bbox(clip_box) + clip_box = paddle.gather(clip_box, keep) + rois.append(clip_box) + rois_num = paddle.concat([paddle.shape(r)[0] for r in rois]) + return rois, rois_num + + def _get_pred_bbox(self, deltas, proposals, weights): + pred_proposals = 
paddle.concat(proposals) if len( + proposals) > 1 else proposals[0] + pred_bbox = delta2bbox(deltas, pred_proposals, weights) + num_prop = [p.shape[0] for p in proposals] + return pred_bbox.split(num_prop) + + def get_prediction(self, head_out_list): + """ + head_out_list(List[Tensor]): scores, deltas, rois + """ + pred_list = [] + scores_list = [F.softmax(head[0]) for head in head_out_list] + scores = paddle.add_n(scores_list) / self.num_cascade_stages + # Get deltas and rois from the last stage + _, deltas, rois = head_out_list[-1] + return scores, deltas, rois + + def get_refined_rois(self, ): + return self.refined_rois diff --git a/dygraph/ppdet/modeling/layers.py b/dygraph/ppdet/modeling/layers.py index a499d528d..967d6283b 100644 --- a/dygraph/ppdet/modeling/layers.py +++ b/dygraph/ppdet/modeling/layers.py @@ -291,14 +291,18 @@ class AnchorGeneratorSSD(object): @register @serializable class RCNNBox(object): + __shared__ = ['num_classes'] + def __init__(self, prior_box_var=[10., 10., 5., 5.], code_type="decode_center_size", - box_normalized=False): + box_normalized=False, + num_classes=80): super(RCNNBox, self).__init__() self.prior_box_var = prior_box_var self.code_type = code_type self.box_normalized = box_normalized + self.num_classes = num_classes def __call__(self, bbox_head_out, rois, im_shape, scale_factor): bbox_pred, cls_prob = bbox_head_out @@ -322,6 +326,13 @@ class RCNNBox(object): bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var) scores = cls_prob[:, :-1] + # [N*C, 4] + + bbox_num_class = bbox.shape[1] // 4 + bbox = paddle.reshape(bbox, [-1, bbox_num_class, 4]) + if bbox_num_class == 1: + bbox = paddle.tile(bbox, [1, self.num_classes, 1]) + origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1) origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1) zeros = paddle.zeros_like(origin_h) diff --git a/dygraph/ppdet/modeling/ops.py b/dygraph/ppdet/modeling/ops.py index 416277fe4..80ef92d0c 100644 --- a/dygraph/ppdet/modeling/ops.py +++ b/dygraph/ppdet/modeling/ops.py @@ -239,7 +239,7 @@ def roi_align(input, align_out = core.ops.roi_align( input, rois, rois_num, "pooled_height", pooled_height, "pooled_width", pooled_width, "spatial_scale", spatial_scale, - "sampling_ratio", sampling_ratio) #, "aligned", aligned) + "sampling_ratio", sampling_ratio, "aligned", aligned) return align_out else: @@ -265,7 +265,7 @@ def roi_align(input, "pooled_width": pooled_width, "spatial_scale": spatial_scale, "sampling_ratio": sampling_ratio, - #"aligned": aligned, + "aligned": aligned, }) return align_out diff --git a/dygraph/ppdet/modeling/post_process.py b/dygraph/ppdet/modeling/post_process.py index 2e3da37be..7e1e961bc 100644 --- a/dygraph/ppdet/modeling/post_process.py +++ b/dygraph/ppdet/modeling/post_process.py @@ -19,6 +19,10 @@ import paddle.nn.functional as F from ppdet.core.workspace import register from ppdet.modeling.bbox_utils import nonempty_bbox from . 
import ops +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence @register diff --git a/dygraph/ppdet/modeling/proposal_generator/rpn_head.py b/dygraph/ppdet/modeling/proposal_generator/rpn_head.py index 5658e9947..426d33bf2 100644 --- a/dygraph/ppdet/modeling/proposal_generator/rpn_head.py +++ b/dygraph/ppdet/modeling/proposal_generator/rpn_head.py @@ -169,6 +169,7 @@ class RPNHead(nn.Layer): rois_collect.append(topk_rois) rois_num_collect.append(paddle.shape(topk_rois)[0]) rois_num_collect = paddle.concat(rois_num_collect) + return rois_collect, rois_num_collect def get_loss(self, pred_scores, pred_deltas, anchors, inputs): diff --git a/dygraph/ppdet/modeling/proposal_generator/target.py b/dygraph/ppdet/modeling/proposal_generator/target.py index e5d4a1019..b66f0d9cd 100644 --- a/dygraph/ppdet/modeling/proposal_generator/target.py +++ b/dygraph/ppdet/modeling/proposal_generator/target.py @@ -37,7 +37,7 @@ def rpn_anchor_target(anchors, gt_bbox = gt_boxes[i] # Step1: match anchor and gt_bbox - matches, match_labels, matched_vals = label_box( + matches, match_labels = label_box( anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True) # Step2: sample anchor fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im, @@ -84,8 +84,7 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap, matches = matches.flatten() match_labels = match_labels.flatten() - matched_vals = matched_vals.flatten() - return matches, match_labels, matched_vals + return matches, match_labels def subsample_labels(labels, @@ -118,16 +117,6 @@ def subsample_labels(labels, return fg_inds, bg_inds -def filter_roi(rois, max_overlap): - ws = rois[:, 2] - rois[:, 0] - hs = rois[:, 3] - rois[:, 1] - valid_mask = paddle.logical_and(ws > 0, hs > 0, max_overlap < 1) - keep = paddle.nonzero(valid_mask) - if keep.numel() > 0: - return rois[keep[:, 1]] - return paddle.zeros((1, 4), dtype='float32') - - def generate_proposal_target(rpn_rois, gt_classes, gt_boxes, @@ -137,67 +126,68 @@ def generate_proposal_target(rpn_rois, bg_thresh, num_classes, use_random=True, - is_cascade_rcnn=False, - max_overlaps=None): + is_cascade=False, + cascade_iou=0.5): rois_with_gt = [] tgt_labels = [] tgt_bboxes = [] - sampled_max_overlaps = [] tgt_gt_inds = [] new_rois_num = [] + fg_thresh = cascade_iou if is_cascade else fg_thresh + bg_thresh = cascade_iou if is_cascade else bg_thresh for i, rpn_roi in enumerate(rpn_rois): - max_overlap = max_overlaps[i] if is_cascade_rcnn else None gt_bbox = gt_boxes[i] gt_class = gt_classes[i] - if is_cascade_rcnn: - rpn_roi = filter_roi(rpn_roi, max_overlap) - bbox = paddle.concat([rpn_roi, gt_bbox]) - - # Step1: label bbox - matches, match_labels, matched_vals = label_box( - bbox, gt_bbox, fg_thresh, bg_thresh, False) + if not is_cascade: + bbox = paddle.concat([rpn_roi, gt_bbox]) + else: + bbox = rpn_roi + + # Step1: label bbox + matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh, + False) # Step2: sample bbox sampled_inds, sampled_gt_classes = sample_bbox( matches, match_labels, gt_class, batch_size_per_im, fg_fraction, - num_classes, use_random) + num_classes, use_random, is_cascade) # Step3: make output - rois_per_image = paddle.gather(bbox, sampled_inds) - sampled_gt_ind = paddle.gather(matches, sampled_inds) + rois_per_image = bbox if is_cascade else paddle.gather(bbox, + sampled_inds) + sampled_gt_ind = matches if is_cascade else paddle.gather(matches, + sampled_inds) sampled_bbox = 
paddle.gather(gt_bbox, sampled_gt_ind) - sampled_overlap = paddle.gather(matched_vals, sampled_inds) rois_per_image.stop_gradient = True sampled_gt_ind.stop_gradient = True sampled_bbox.stop_gradient = True - sampled_overlap.stop_gradient = True - tgt_labels.append(sampled_gt_classes) tgt_bboxes.append(sampled_bbox) rois_with_gt.append(rois_per_image) - sampled_max_overlaps.append(sampled_overlap) tgt_gt_inds.append(sampled_gt_ind) new_rois_num.append(paddle.shape(sampled_inds)[0]) new_rois_num = paddle.concat(new_rois_num) - return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num, sampled_max_overlaps + return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num -def sample_bbox( - matches, - match_labels, - gt_classes, - batch_size_per_im, - fg_fraction, - num_classes, - use_random=True, ): +def sample_bbox(matches, + match_labels, + gt_classes, + batch_size_per_im, + fg_fraction, + num_classes, + use_random=True, + is_cascade=False): gt_classes = paddle.gather(gt_classes, matches) gt_classes = paddle.where(match_labels == 0, paddle.ones_like(gt_classes) * num_classes, gt_classes) gt_classes = paddle.where(match_labels == -1, paddle.ones_like(gt_classes) * -1, gt_classes) + if is_cascade: + return matches, gt_classes rois_per_image = int(batch_size_per_im) fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image, fg_fraction, diff --git a/dygraph/ppdet/modeling/proposal_generator/target_layer.py b/dygraph/ppdet/modeling/proposal_generator/target_layer.py index ed6c08651..4586cadf3 100644 --- a/dygraph/ppdet/modeling/proposal_generator/target_layer.py +++ b/dygraph/ppdet/modeling/proposal_generator/target_layer.py @@ -58,10 +58,11 @@ class BBoxAssigner(object): def __init__(self, batch_size_per_im=512, fg_fraction=.25, - fg_thresh=[.5, ], - bg_thresh=[.5, ], + fg_thresh=.5, + bg_thresh=.5, use_random=True, is_cls_agnostic=False, + cascade_iou=[0.5, 0.6, 0.7], num_classes=80): super(BBoxAssigner, self).__init__() self.batch_size_per_im = batch_size_per_im @@ -70,6 +71,7 @@ class BBoxAssigner(object): self.bg_thresh = bg_thresh self.use_random = use_random self.is_cls_agnostic = is_cls_agnostic + self.cascade_iou = cascade_iou self.num_classes = num_classes def __call__(self, @@ -77,22 +79,20 @@ class BBoxAssigner(object): rpn_rois_num, inputs, stage=0, - max_overlap=None): - is_cascade = True if stage > 0 else False + is_cascade=False): gt_classes = inputs['gt_class'] gt_boxes = inputs['gt_bbox'] # rois, tgt_labels, tgt_bboxes, tgt_gt_inds - # new_rois_num, sampled_max_overlaps + # new_rois_num outs = generate_proposal_target( rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im, - self.fg_fraction, self.fg_thresh[stage], self.bg_thresh[stage], - self.num_classes, self.use_random, is_cascade, max_overlap) + self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes, + self.use_random, is_cascade, self.cascade_iou[stage]) rois = outs[0] - rois_num = outs[-2] - max_overlaps = outs[-1] + rois_num = outs[-1] # tgt_labels, tgt_bboxes, tgt_gt_inds targets = outs[1:4] - return rois, rois_num, max_overlaps, targets + return rois, rois_num, targets @register diff --git a/dygraph/ppdet/utils/checkpoint.py b/dygraph/ppdet/utils/checkpoint.py index fd25bd1ad..834e3d9cd 100644 --- a/dygraph/ppdet/utils/checkpoint.py +++ b/dygraph/ppdet/utils/checkpoint.py @@ -92,7 +92,6 @@ def load_weight(model, weight, optimizer=None): param_state_dict = paddle.load(pdparam_path) model_dict = model.state_dict() - model_weight = {} incorrect_keys = 0 @@ -129,6 +128,7 @@ def 
load_pretrain_weight(model, pretrain_weight, load_static_weights=False, weight_type='pretrain'): + assert weight_type in ['pretrain', 'finetune'] if is_url(pretrain_weight): pretrain_weight = get_weights_path_dist(pretrain_weight) -- GitLab
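
For orientation (not part of the patch itself): the new CascadeHead refines RoIs over num_cascade_stages=3 stages, re-assigning training targets at increasing IoU thresholds (cascade_iou: [0.5, 0.6, 0.7]) and decoding each stage's box deltas with progressively tighter weights (bbox_weight). Below is a minimal, self-contained NumPy sketch of just that decode-and-refine loop; the dummy per-stage deltas and the simplified single-box delta2bbox are illustrative assumptions, not the library's exact implementation (the real bbox_utils.delta2bbox also handles per-class deltas and clipping).

    # Standalone sketch of the cascade refinement loop (simplified; the real
    # CascadeHead also samples/labels RoIs and predicts class scores per stage).
    import numpy as np

    def delta2bbox(deltas, boxes, weights):
        # Decode (dx, dy, dw, dh) deltas against proposal boxes,
        # loosely following bbox_utils.delta2bbox.
        wx, wy, ww, wh = weights
        w = boxes[:, 2] - boxes[:, 0]
        h = boxes[:, 3] - boxes[:, 1]
        cx = boxes[:, 0] + 0.5 * w
        cy = boxes[:, 1] + 0.5 * h
        dx, dy, dw, dh = [deltas[:, i] / s for i, s in enumerate((wx, wy, ww, wh))]
        pred_cx = dx * w + cx
        pred_cy = dy * h + cy
        pred_w = np.exp(dw) * w
        pred_h = np.exp(dh) * h
        return np.stack([pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h,
                         pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h], axis=-1)

    # Per-stage delta weights and assigner IoU thresholds from the config above.
    bbox_weight = [[10., 10., 5., 5.], [20., 20., 10., 10.], [30., 30., 15., 15.]]
    cascade_iou = [0.5, 0.6, 0.7]

    rois = np.array([[10., 10., 50., 60.]])  # proposals from the RPN
    for stage in range(3):
        # A trained stage head would predict these deltas from RoI features;
        # a fixed dummy prediction is used here just to show the decode step.
        deltas = np.array([[1.0, 1.0, 0.5, 0.5]])
        rois = delta2bbox(deltas, rois, bbox_weight[stage])
        print('stage', stage, 'iou_thresh', cascade_iou[stage], 'refined rois', rois)

Each pass feeds the refined boxes back in as the next stage's proposals, which is why later stages can sample training targets at the stricter IoU thresholds.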