diff --git a/configs/queryinst/README.md b/configs/queryinst/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..568135328ba43780a3829977b839169126fe0b10
--- /dev/null
+++ b/configs/queryinst/README.md
@@ -0,0 +1,41 @@
+# QueryInst: Instances as Queries
+
+## Introduction
+
+QueryInst is a multi-stage end-to-end system that treats instances of interest as learnable queries, enabling
+query-based object detectors, e.g., Sparse R-CNN, to achieve strong instance segmentation performance. The
+attributes of instances, such as categories, bounding boxes, instance masks, and instance association embeddings,
+are represented by queries in a unified manner. In QueryInst, a query is shared by both detection and segmentation
+via dynamic convolutions and is driven by parallelly supervised multi-stage learning.
+
+## Model Zoo
+
+| Backbone     | Lr schd | Proposals | MultiScale | RandomCrop | bbox AP | mask AP | Download | Config |
+|:------------:|:-------:|:---------:|:----------:|:----------:|:-------:|:-------:|------------------------------------------------------------------------------------------------------|----------------------------------------------------------|
+| ResNet50-FPN | 1x      | 100       | ×          | ×          | 42.1    | 37.8    | [model](https://bj.bcebos.com/v1/paddledet/models/queryinst_r50_fpn_1x_pro100_coco.pdparams)           | [config](./queryinst_r50_fpn_1x_pro100_coco.yml)          |
+| ResNet50-FPN | 3x      | 300       | √          | √          | 47.9    | 42.1    | [model](https://bj.bcebos.com/v1/paddledet/models/queryinst_r50_fpn_ms_crop_3x_pro300_coco.pdparams)   | [config](./queryinst_r50_fpn_ms_crop_3x_pro300_coco.yml)  |
+
+- Results are evaluated on the COCO val set.
+- These configs are trained on 4 GPUs.
+
+Adjust the following parameters to match your environment:
+
+```yaml
+worker_num: 4
+TrainReader:
+  use_shared_memory: true
+find_unused_parameters: true
+```
+
+## Citations
+
+```
+@InProceedings{Fang_2021_ICCV,
+    author    = {Fang, Yuxin and Yang, Shusheng and Wang, Xinggang and Li, Yu and Fang, Chen and Shan, Ying and Feng, Bin and Liu, Wenyu},
+    title     = {Instances As Queries},
+    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+    month     = {October},
+    year      = {2021},
+    pages     = {6910-6919}
+}
+```
diff --git a/configs/queryinst/_base_/optimizer_1x.yml b/configs/queryinst/_base_/optimizer_1x.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a7c0f5cb16311f046adec9e11f7cd0cc4a93e3d9
--- /dev/null
+++ b/configs/queryinst/_base_/optimizer_1x.yml
@@ -0,0 +1,17 @@
+epoch: 12
+
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [8, 11]
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 1000
+
+OptimizerBuilder:
+  clip_grad_by_norm: 0.1
+  optimizer:
+    type: AdamW
+    weight_decay: 0.0001
diff --git a/configs/queryinst/_base_/queryinst_r50_fpn.yml b/configs/queryinst/_base_/queryinst_r50_fpn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..05ab1c02f8a02308cfd47d441697c4a548c32f1a
--- /dev/null
+++ b/configs/queryinst/_base_/queryinst_r50_fpn.yml
@@ -0,0 +1,74 @@
+num_proposals: &num_proposals 100
+proposal_embedding_dim: &proposal_embedding_dim 256
+bbox_resolution: &bbox_resolution 7
+mask_resolution: &mask_resolution 14
+
+architecture: QueryInst
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+
+QueryInst:
+  backbone: ResNet
+  neck: FPN
+  rpn_head: EmbeddingRPNHead
+  roi_head: SparseRoIHead
+  post_process: SparsePostProcess
+
+ResNet:
+  depth: 50
+  norm_type: bn
+ freeze_at: 0 + return_idx: [ 0, 1, 2, 3 ] + num_stages: 4 + lr_mult_list: [ 0.1, 0.1, 0.1, 0.1 ] + +FPN: + out_channel: *proposal_embedding_dim + extra_stage: 0 + +EmbeddingRPNHead: + num_proposals: *num_proposals + +SparseRoIHead: + num_stages: 6 + bbox_roi_extractor: + resolution: *bbox_resolution + sampling_ratio: 2 + aligned: True + mask_roi_extractor: + resolution: *mask_resolution + sampling_ratio: 2 + aligned: True + bbox_head: DIIHead + mask_head: DynamicMaskHead + loss_func: QueryInstLoss + +DIIHead: + feedforward_channels: 2048 + dynamic_feature_channels: 64 + roi_resolution: *bbox_resolution + num_attn_heads: 8 + dropout: 0.0 + num_ffn_fcs: 2 + num_cls_fcs: 1 + num_reg_fcs: 3 + +DynamicMaskHead: + dynamic_feature_channels: 64 + roi_resolution: *mask_resolution + num_convs: 4 + conv_kernel_size: 3 + conv_channels: 256 + upsample_method: 'deconv' + upsample_scale_factor: 2 + +QueryInstLoss: + focal_loss_alpha: 0.25 + focal_loss_gamma: 2.0 + class_weight: 2.0 + l1_weight: 5.0 + giou_weight: 2.0 + mask_weight: 8.0 + +SparsePostProcess: + num_proposals: *num_proposals + binary_thresh: 0.5 diff --git a/configs/queryinst/_base_/queryinst_reader.yml b/configs/queryinst/_base_/queryinst_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..e867cc27454efaf321e40bd09dc674d4f32c3a8d --- /dev/null +++ b/configs/queryinst/_base_/queryinst_reader.yml @@ -0,0 +1,43 @@ +worker_num: 4 + +TrainReader: + sample_transforms: + - Decode: {} + - Poly2Mask: {del_poly: True} + - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - RandomFlip: {prob: 0.5} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + - Gt2SparseTarget: {} + batch_size: 4 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: true + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + - Gt2SparseTarget: {} + batch_size: 1 + shuffle: false + drop_last: false + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + - Gt2SparseTarget: {} + batch_size: 1 + shuffle: false diff --git a/configs/queryinst/queryinst_r50_fpn_1x_pro100_coco.yml b/configs/queryinst/queryinst_r50_fpn_1x_pro100_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..1e61252b71d3373c2fc062207ef2b88d699d8a0b --- /dev/null +++ b/configs/queryinst/queryinst_r50_fpn_1x_pro100_coco.yml @@ -0,0 +1,12 @@ +_BASE_: [ + '../datasets/coco_instance.yml', + '../runtime.yml', + '_base_/optimizer_1x.yml', + '_base_/queryinst_r50_fpn.yml', + '_base_/queryinst_reader.yml', +] + +log_iter: 50 +find_unused_parameters: true + +weights: output/queryinst_r50_fpn_1x_pro100_coco/model_final diff --git a/configs/queryinst/queryinst_r50_fpn_ms_crop_3x_pro300_coco.yml b/configs/queryinst/queryinst_r50_fpn_ms_crop_3x_pro300_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..7dfa8997e3f71c0941b3b7626ad256333af2161a --- /dev/null +++ b/configs/queryinst/queryinst_r50_fpn_ms_crop_3x_pro300_coco.yml @@ -0,0 +1,45 @@ +_BASE_: [ + 
+  './queryinst_r50_fpn_1x_pro100_coco.yml',
+]
+
+weights: output/queryinst_r50_fpn_ms_crop_3x_pro300_coco/model_final
+
+EmbeddingRPNHead:
+  num_proposals: 300
+
+SparsePostProcess:
+  num_proposals: 300
+
+epoch: 36
+
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [27, 33]
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 1000
+
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - Poly2Mask: {del_poly: True}
+  - RandomFlip: {prob: 0.5}
+  - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ],
+                    transforms2: [
+                        RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ], max_size: 1333 },
+                        RandomSizeCrop: { min_size: 384, max_size: 600, keep_empty: true },
+                        RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ]
+    }
+  - NormalizeImage: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
+  - Permute: {}
+  batch_transforms:
+  - PadBatch: {pad_to_stride: 32}
+  - Gt2SparseTarget: {}
+  batch_size: 4
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+  use_shared_memory: true
diff --git a/configs/sparse_rcnn/_base_/sparse_rcnn_reader.yml b/configs/sparse_rcnn/_base_/sparse_rcnn_reader.yml
index 248ca39b741b0607ab053a911a954ef2719964c2..b9544b31c44159b66cd63df23a6d6a79aeb081bd 100644
--- a/configs/sparse_rcnn/_base_/sparse_rcnn_reader.yml
+++ b/configs/sparse_rcnn/_base_/sparse_rcnn_reader.yml
@@ -1,5 +1,4 @@
 worker_num: 4
-use_process: true
 
 TrainReader:
   sample_transforms:
@@ -10,12 +9,11 @@ TrainReader:
   - Permute: {}
   batch_transforms:
   - PadBatch: {pad_to_stride: 32}
-  - Gt2SparseRCNNTarget: {}
+  - Gt2SparseTarget: {use_padding_shape: True}
   batch_size: 4
   shuffle: true
   drop_last: true
   collate_batch: false
-  use_process: true
 
 EvalReader:
   sample_transforms:
@@ -25,11 +23,10 @@ EvalReader:
   - Permute: {}
   batch_transforms:
   - PadBatch: {pad_to_stride: 32}
-  - Gt2SparseRCNNTarget: {}
+  - Gt2SparseTarget: {use_padding_shape: True}
   batch_size: 1
   shuffle: false
   drop_last: false
-  use_process: true
 
 TestReader:
   sample_transforms:
@@ -39,6 +36,6 @@ TestReader:
   - Permute: {}
   batch_transforms:
   - PadBatch: {pad_to_stride: 32}
-  - Gt2SparseRCNNTarget: {}
+  - Gt2SparseTarget: {use_padding_shape: True}
   batch_size: 1
   shuffle: false
diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py
index bdb7989f0a5f1c90c39b086113c7a6f4166f53b9..92c211ee415dac977da211217bd61a5db9857153 100644
--- a/ppdet/data/transform/batch_operators.py
+++ b/ppdet/data/transform/batch_operators.py
@@ -44,7 +44,7 @@ __all__ = [
     'Gt2FCOSTarget',
     'Gt2TTFTarget',
     'Gt2Solov2Target',
-    'Gt2SparseRCNNTarget',
+    'Gt2SparseTarget',
    'PadMaskBatch',
     'Gt2GFLTarget',
     'Gt2CenterNetTarget',
@@ -916,27 +916,37 @@ class Gt2Solov2Target(BaseOperator):
 
 
 @register_op
-class Gt2SparseRCNNTarget(BaseOperator):
-    '''
-    Generate SparseRCNN targets by groud truth data
-    '''
-
-    def __init__(self):
-        super(Gt2SparseRCNNTarget, self).__init__()
+class Gt2SparseTarget(BaseOperator):
+    """
+    Generate SparseRCNN or QueryInst targets from the ground-truth data.
+    """
+
+    def __init__(self, use_padding_shape=False):
+        super(Gt2SparseTarget, self).__init__()
+        self.use_padding_shape = use_padding_shape
 
     def __call__(self, samples, context=None):
         for sample in samples:
-            im = sample["image"]
-            h, w = im.shape[1:3]
-            img_whwh = np.array([w, h, w, h], dtype=np.int32)
-            sample["img_whwh"] = img_whwh
-            if "scale_factor" in sample:
-                sample["scale_factor_wh"] = np.array(
-                    [sample["scale_factor"][1], sample["scale_factor"][0]],
-                    dtype=np.float32)
+            ori_h, ori_w = sample['h'], sample['w']
+            if self.use_padding_shape:
+                h, w = sample["image"].shape[1:3]
+                if "scale_factor" in sample:
+                    sf_w, sf_h = sample["scale_factor"][1], sample[
+                        "scale_factor"][0]
+                    sample["scale_factor_whwh"] = np.array(
+                        [sf_w, sf_h, sf_w, sf_h], dtype=np.float32)
+                else:
+                    sample["scale_factor_whwh"] = np.array(
+                        [1.0, 1.0, 1.0, 1.0], dtype=np.float32)
             else:
-                sample["scale_factor_wh"] = np.array(
-                    [1.0, 1.0], dtype=np.float32)
+                h, w = round(sample['im_shape'][0]), round(sample['im_shape'][
+                    1])
+                sample["scale_factor_whwh"] = np.array(
+                    [w / ori_w, h / ori_h, w / ori_w, h / ori_h],
+                    dtype=np.float32)
+
+            sample["img_whwh"] = np.array([w, h, w, h], dtype=np.float32)
+            sample["ori_shape"] = np.array([ori_h, ori_w], dtype=np.int32)
 
         return samples
 
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index a9dfb1647e6cede35544038c9a05e3642edce301..3a68282dee0fd83da4e61fd0445c3fa81eac8e8d 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -2097,13 +2097,16 @@ class Pad(BaseOperator):
 @register_op
 class Poly2Mask(BaseOperator):
     """
-    gt poly to mask annotations
+    gt poly to mask annotations.
+    Args:
+        del_poly (bool): Whether to delete the polygon annotation after generating the mask. Default: False.
     """
 
-    def __init__(self):
+    def __init__(self, del_poly=False):
         super(Poly2Mask, self).__init__()
         import pycocotools.mask as maskUtils
         self.maskutils = maskUtils
+        self.del_poly = del_poly
 
     def _poly2mask(self, mask_ann, img_h, img_w):
         if isinstance(mask_ann, list):
@@ -2122,13 +2125,15 @@ class Poly2Mask(BaseOperator):
 
     def apply(self, sample, context=None):
         assert 'gt_poly' in sample
-        im_h = sample['h']
-        im_w = sample['w']
+        im_h, im_w = sample['im_shape']
         masks = [
             self._poly2mask(gt_poly, im_h, im_w)
             for gt_poly in sample['gt_poly']
         ]
         sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
+        if self.del_poly:
+            del sample['gt_poly']
+
         return sample
 
 
@@ -2677,12 +2682,21 @@ class RandomShortSideResize(BaseOperator):
 class RandomSizeCrop(BaseOperator):
     """
     Cut the image randomly according to `min_size` and `max_size`
+    Args:
+        min_size (int): Minimum size for the edges of the cropped image.
+        max_size (int): Maximum size for the edges of the cropped image.
+            If it is larger than the corresponding edge of the input
+            image, the output keeps the original length.
+        keep_empty (bool): Whether to keep a cropped result that contains
+            no objects. If set to False, such a result is discarded and
+            the original input is returned instead.
     """
 
-    def __init__(self, min_size, max_size):
+    def __init__(self, min_size, max_size, keep_empty=True):
         super(RandomSizeCrop, self).__init__()
         self.min_size = min_size
         self.max_size = max_size
+        self.keep_empty = keep_empty
 
         from paddle.vision.transforms.functional import crop as paddle_crop
         self.paddle_crop = paddle_crop
@@ -2712,17 +2726,20 @@ class RandomSizeCrop(BaseOperator):
         return i, j, th, tw
 
     def crop(self, sample, region):
-        image_shape = sample['image'].shape[:2]
-        sample['image'] = self.paddle_crop(sample['image'], *region)
         keep_index = None
-
-        # apply bbox
+        # apply bbox and check whether the cropped result is valid
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
-            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], region)
-            bbox = sample['gt_bbox'].reshape([-1, 2, 2])
+            cropped_bbox = self.apply_bbox(sample['gt_bbox'], region)
+            bbox = cropped_bbox.reshape([-1, 2, 2])
             area = (bbox[:, 1, :] - bbox[:, 0, :]).prod(axis=1)
             keep_index = np.where(area > 0)[0]
-            sample['gt_bbox'] = sample['gt_bbox'][keep_index] if len(
+
+            if not self.keep_empty and len(keep_index) == 0:
+                # When keep_empty is False, a crop that contains no objects
+                # is discarded and the original sample is returned instead.
+                return sample
+
+            sample['gt_bbox'] = cropped_bbox[keep_index] if len(
                 keep_index) > 0 else np.zeros(
                     [0, 4], dtype=np.float32)
             sample['gt_class'] = sample['gt_class'][keep_index] if len(
@@ -2737,17 +2754,24 @@ class RandomSizeCrop(BaseOperator):
                 keep_index) > 0 else np.zeros(
                     [0, 1], dtype=np.float32)
 
+        image_shape = sample['image'].shape[:2]
+        sample['image'] = self.paddle_crop(sample['image'], *region)
+        sample['im_shape'] = np.array(
+            sample['image'].shape[:2], dtype=np.float32)
+
         # apply polygon
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(sample['gt_poly'], region,
                                                 image_shape)
-            if keep_index is not None:
+            sample['gt_poly'] = np.array(sample['gt_poly'])
+            if keep_index is not None and len(keep_index) > 0:
                 sample['gt_poly'] = sample['gt_poly'][keep_index]
+            sample['gt_poly'] = sample['gt_poly'].tolist()
 
         # apply gt_segm
         if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
             i, j, h, w = region
             sample['gt_segm'] = sample['gt_segm'][:, i:i + h, j:j + w]
-            if keep_index is not None:
+            if keep_index is not None and len(keep_index) > 0:
                 sample['gt_segm'] = sample['gt_segm'][keep_index]
 
         return sample
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index c89106c4f5f0b69fca11676fd297177938d99fa1..5efdec0775ed46c4328ec90e7ccf483f79e64725 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -40,6 +40,7 @@ from . import yolox
 from . import yolof
 from . import pose3d_metro
 from . import centertrack
+from . import queryinst
 
 from .meta_arch import *
 from .faster_rcnn import *
@@ -70,3 +71,4 @@ from .yolox import *
 from .yolof import *
 from .pose3d_metro import *
 from .centertrack import *
+from .queryinst import *
diff --git a/ppdet/modeling/architectures/queryinst.py b/ppdet/modeling/architectures/queryinst.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a65ed3a2565d638546f4b3deb09670bd809c1c
--- /dev/null
+++ b/ppdet/modeling/architectures/queryinst.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle + +from ppdet.core.workspace import register, create +from .meta_arch import BaseArch + +__all__ = ['QueryInst'] + + +@register +class QueryInst(BaseArch): + __category__ = 'architecture' + __inject__ = ['post_process'] + + def __init__(self, + backbone, + neck, + rpn_head, + roi_head, + post_process='SparsePostProcess'): + super(QueryInst, self).__init__() + self.backbone = backbone + self.neck = neck + self.rpn_head = rpn_head + self.roi_head = roi_head + self.post_process = post_process + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + kwargs = {'input_shape': backbone.out_shape} + neck = create(cfg['neck'], **kwargs) + + kwargs = {'input_shape': neck.out_shape} + rpn_head = create(cfg['rpn_head'], **kwargs) + roi_head = create(cfg['roi_head'], **kwargs) + + return { + 'backbone': backbone, + 'neck': neck, + 'rpn_head': rpn_head, + "roi_head": roi_head + } + + def _forward(self, targets=None): + features = self.backbone(self.inputs) + features = self.neck(features) + + proposal_bboxes, proposal_features = self.rpn_head(self.inputs[ + 'img_whwh']) + outputs = self.roi_head(features, proposal_bboxes, proposal_features, + targets) + + if self.training: + return outputs + else: + bbox_pred, bbox_num, mask_pred = self.post_process( + outputs['class_logits'], outputs['bbox_pred'], + self.inputs['scale_factor_whwh'], self.inputs['ori_shape'], + outputs['mask_logits']) + return bbox_pred, bbox_num, mask_pred + + def get_loss(self): + targets = [] + for i in range(len(self.inputs['img_whwh'])): + boxes = self.inputs['gt_bbox'][i] + labels = self.inputs['gt_class'][i].squeeze(-1) + img_whwh = self.inputs['img_whwh'][i] + if boxes.shape[0] != 0: + img_whwh_tgt = img_whwh.unsqueeze(0).tile([boxes.shape[0], 1]) + else: + img_whwh_tgt = paddle.zeros_like(boxes) + gt_segm = self.inputs['gt_segm'][i].astype('float32') + targets.append({ + 'boxes': boxes, + 'labels': labels, + 'img_whwh': img_whwh, + 'img_whwh_tgt': img_whwh_tgt, + 'gt_segm': gt_segm + }) + losses = self._forward(targets) + losses.update({'loss': sum(losses.values())}) + return losses + + def get_pred(self): + bbox_pred, bbox_num, mask_pred = self._forward() + return {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred} diff --git a/ppdet/modeling/architectures/sparse_rcnn.py b/ppdet/modeling/architectures/sparse_rcnn.py index 34c29498b5270161ad9497bae51abc263b9cb2eb..2cbc85338eaf899f7344c415b09d8d49901972b0 100644 --- a/ppdet/modeling/architectures/sparse_rcnn.py +++ b/ppdet/modeling/architectures/sparse_rcnn.py @@ -60,10 +60,10 @@ class SparseRCNN(BaseArch): head_outs = self.head(fpn_feats, self.inputs["img_whwh"]) if not self.training: - bboxes = self.postprocess( + bbox_pred, bbox_num = self.postprocess( head_outs["pred_logits"], head_outs["pred_boxes"], - self.inputs["scale_factor_wh"], self.inputs["img_whwh"]) - return bboxes + self.inputs["scale_factor_whwh"], self.inputs["ori_shape"]) + return bbox_pred, bbox_num else: return 
head_outs diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index e2a42449d89074a377b5e21c8048b3c3d4fba080..576cbbf04bff806242f36f81785b84c921e523f0 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -143,8 +143,8 @@ def delta2bbox_v2(deltas, dw = paddle.clip(dw, max=clip_scale) dh = paddle.clip(dh, max=clip_scale) else: - dw = dw.clip(min=-ctr_clip, max=ctr_clip) - dh = dh.clip(min=-ctr_clip, max=ctr_clip) + dw = dw.clip(min=-clip_scale, max=clip_scale) + dh = dh.clip(min=-clip_scale, max=clip_scale) pred_ctr_x = dx + ctr_x.unsqueeze(1) pred_ctr_y = dy + ctr_y.unsqueeze(1) diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index bef87d2c17f9aecc70114100cdc3670092ca4a51..ecd15b2f139e1244e908cfec59592dafcb1f3ec4 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -39,6 +39,7 @@ from . import ld_gfl_head from . import yolof_head from . import ppyoloe_contrast_head from . import centertrack_head +from . import sparse_roi_head from .bbox_head import * from .mask_head import * @@ -67,3 +68,4 @@ from .ppyoloe_r_head import * from .yolof_head import * from .ppyoloe_contrast_head import * from .centertrack_head import * +from .sparse_roi_head import * diff --git a/ppdet/modeling/heads/sparse_roi_head.py b/ppdet/modeling/heads/sparse_roi_head.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc76a946a2a3a817376f7caba64d99200c81a97 --- /dev/null +++ b/ppdet/modeling/heads/sparse_roi_head.py @@ -0,0 +1,467 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# This code is referenced from: https://github.com/open-mmlab/mmdetection + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import paddle +from paddle import nn + +from ppdet.core.workspace import register +from ppdet.modeling import initializer as init +from .roi_extractor import RoIAlign +from ..bbox_utils import delta2bbox_v2 +from ..cls_utils import _get_class_default_kwargs +from ..layers import MultiHeadAttention + +__all__ = ['SparseRoIHead', 'DIIHead', 'DynamicMaskHead'] + + +class DynamicConv(nn.Layer): + def __init__(self, + in_channels=256, + feature_channels=64, + out_channels=None, + roi_resolution=7, + with_proj=True): + super(DynamicConv, self).__init__() + + self.in_channels = in_channels + self.feature_channels = feature_channels + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.in_channels * self.feature_channels + self.num_params_out = self.out_channels * self.feature_channels + self.dynamic_layer = nn.Linear(self.in_channels, + self.num_params_in + self.num_params_out) + + self.norm_in = nn.LayerNorm(self.feature_channels) + self.norm_out = nn.LayerNorm(self.out_channels) + + self.activation = nn.ReLU() + + self.with_proj = with_proj + if self.with_proj: + num_output = self.out_channels * roi_resolution**2 + self.fc_layer = nn.Linear(num_output, self.out_channels) + self.fc_norm = nn.LayerNorm(self.out_channels) + + def forward(self, param_feature, input_feature): + input_feature = input_feature.flatten(2).transpose([2, 0, 1]) + input_feature = input_feature.transpose([1, 0, 2]) + + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, :self.num_params_in].reshape( + [-1, self.in_channels, self.feature_channels]) + param_out = parameters[:, -self.num_params_out:].reshape( + [-1, self.feature_channels, self.out_channels]) + + features = paddle.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + features = paddle.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features + + +class FFN(nn.Layer): + def __init__(self, + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0, + add_identity=True): + super(FFN, self).__init__() + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + nn.Sequential( + nn.Linear(in_channels, feedforward_channels), + nn.ReLU(), nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(nn.Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = nn.Sequential(*layers) + + self.add_identity = add_identity + + def forward(self, x): + identity = x + out = self.layers(x) + if not self.add_identity: + return out + else: + return out + identity + + +@register +class DynamicMaskHead(nn.Layer): + __shared__ = ['num_classes', 'proposal_embedding_dim', 'norm_type'] + + def __init__(self, + num_classes=80, + proposal_embedding_dim=256, + dynamic_feature_channels=64, + roi_resolution=14, + num_convs=4, + conv_kernel_size=3, + conv_channels=256, + upsample_method='deconv', + upsample_scale_factor=2, + norm_type='bn'): + super(DynamicMaskHead, self).__init__() + + self.d_model = proposal_embedding_dim + + 
self.instance_interactive_conv = DynamicConv( + self.d_model, + dynamic_feature_channels, + roi_resolution=roi_resolution, + with_proj=False) + + self.convs = nn.LayerList() + for i in range(num_convs): + self.convs.append( + nn.Sequential( + nn.Conv2D( + self.d_model if i == 0 else conv_channels, + conv_channels, + conv_kernel_size, + padding='same', + bias_attr=False), + nn.BatchNorm2D(conv_channels), + nn.ReLU())) + if norm_type == 'sync_bn': + self.convs = nn.SyncBatchNorm.convert_sync_batchnorm(self.convs) + + self.upsample_method = upsample_method + if upsample_method is None: + self.upsample = None + elif upsample_method == 'deconv': + self.upsample = nn.Conv2DTranspose( + conv_channels if num_convs > 0 else self.d_model, + conv_channels, + upsample_scale_factor, + stride=upsample_scale_factor) + self.relu = nn.ReLU() + else: + self.upsample = nn.Upsample(None, upsample_scale_factor) + + cls_in_channels = conv_channels if num_convs > 0 else self.d_model + cls_in_channels = conv_channels if upsample_method == 'deconv' else cls_in_channels + self.conv_cls = nn.Conv2D(cls_in_channels, num_classes, 1) + + self._init_weights() + + def _init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + init.xavier_uniform_(p) + + init.constant_(self.conv_cls.bias, 0.) + + def forward(self, roi_features, attn_features): + attn_features = attn_features.reshape([-1, self.d_model]) + attn_features_iic = self.instance_interactive_conv(attn_features, + roi_features) + + x = attn_features_iic.transpose([0, 2, 1]).reshape(roi_features.shape) + + for conv in self.convs: + x = conv(x) + if self.upsample is not None: + x = self.upsample(x) + if self.upsample_method == 'deconv': + x = self.relu(x) + mask_pred = self.conv_cls(x) + return mask_pred + + +@register +class DIIHead(nn.Layer): + __shared__ = ['num_classes', 'proposal_embedding_dim'] + + def __init__(self, + num_classes=80, + proposal_embedding_dim=256, + feedforward_channels=2048, + dynamic_feature_channels=64, + roi_resolution=7, + num_attn_heads=8, + dropout=0.0, + num_ffn_fcs=2, + num_cls_fcs=1, + num_reg_fcs=3): + super(DIIHead, self).__init__() + + self.num_classes = num_classes + self.d_model = proposal_embedding_dim + + self.attention = MultiHeadAttention(self.d_model, num_attn_heads, + dropout) + self.attention_norm = nn.LayerNorm(self.d_model) + + self.instance_interactive_conv = DynamicConv( + self.d_model, + dynamic_feature_channels, + roi_resolution=roi_resolution, + with_proj=True) + self.instance_interactive_conv_dropout = nn.Dropout(dropout) + self.instance_interactive_conv_norm = nn.LayerNorm(self.d_model) + + self.ffn = FFN(self.d_model, feedforward_channels, num_ffn_fcs, dropout) + self.ffn_norm = nn.LayerNorm(self.d_model) + + self.cls_fcs = nn.LayerList() + for _ in range(num_cls_fcs): + self.cls_fcs.append( + nn.Linear( + self.d_model, self.d_model, bias_attr=False)) + self.cls_fcs.append(nn.LayerNorm(self.d_model)) + self.cls_fcs.append(nn.ReLU()) + self.fc_cls = nn.Linear(self.d_model, self.num_classes) + + self.reg_fcs = nn.LayerList() + for _ in range(num_reg_fcs): + self.reg_fcs.append( + nn.Linear( + self.d_model, self.d_model, bias_attr=False)) + self.reg_fcs.append(nn.LayerNorm(self.d_model)) + self.reg_fcs.append(nn.ReLU()) + self.fc_reg = nn.Linear(self.d_model, 4) + + self._init_weights() + + def _init_weights(self): + for p in self.parameters(): + if p.dim() > 1: + init.xavier_uniform_(p) + + bias_init = init.bias_init_with_prob(0.01) + init.constant_(self.fc_cls.bias, bias_init) + + def forward(self, 
roi_features, proposal_features): + N, num_proposals = proposal_features.shape[:2] + + proposal_features = proposal_features + self.attention( + proposal_features) + attn_features = self.attention_norm(proposal_features) + + proposal_features = attn_features.reshape([-1, self.d_model]) + proposal_features_iic = self.instance_interactive_conv( + proposal_features, roi_features) + proposal_features = proposal_features + self.instance_interactive_conv_dropout( + proposal_features_iic) + obj_features = self.instance_interactive_conv_norm(proposal_features) + + obj_features = self.ffn(obj_features) + obj_features = self.ffn_norm(obj_features) + + cls_feature = obj_features.clone() + reg_feature = obj_features.clone() + + for cls_layer in self.cls_fcs: + cls_feature = cls_layer(cls_feature) + class_logits = self.fc_cls(cls_feature) + for reg_layer in self.reg_fcs: + reg_feature = reg_layer(reg_feature) + bbox_deltas = self.fc_reg(reg_feature) + + class_logits = class_logits.reshape( + [N, num_proposals, self.num_classes]) + bbox_deltas = bbox_deltas.reshape([N, num_proposals, 4]) + obj_features = obj_features.reshape([N, num_proposals, self.d_model]) + + return class_logits, bbox_deltas, obj_features, attn_features + + @staticmethod + def refine_bboxes(proposal_bboxes, bbox_deltas): + pred_bboxes = delta2bbox_v2( + bbox_deltas.reshape([-1, 4]), + proposal_bboxes.reshape([-1, 4]), + delta_mean=[0.0, 0.0, 0.0, 0.0], + delta_std=[0.5, 0.5, 1.0, 1.0], + ctr_clip=None) + return pred_bboxes.reshape(proposal_bboxes.shape) + + +@register +class SparseRoIHead(nn.Layer): + __inject__ = ['bbox_head', 'mask_head', 'loss_func'] + + def __init__(self, + num_stages=6, + bbox_roi_extractor=_get_class_default_kwargs(RoIAlign), + mask_roi_extractor=_get_class_default_kwargs(RoIAlign), + bbox_head='DIIHead', + mask_head='DynamicMaskHead', + loss_func='QueryInstLoss'): + super(SparseRoIHead, self).__init__() + + self.num_stages = num_stages + + self.bbox_roi_extractor = bbox_roi_extractor + self.mask_roi_extractor = mask_roi_extractor + if isinstance(bbox_roi_extractor, dict): + self.bbox_roi_extractor = RoIAlign(**bbox_roi_extractor) + if isinstance(mask_roi_extractor, dict): + self.mask_roi_extractor = RoIAlign(**mask_roi_extractor) + + self.bbox_heads = nn.LayerList( + [copy.deepcopy(bbox_head) for _ in range(num_stages)]) + self.mask_heads = nn.LayerList( + [copy.deepcopy(mask_head) for _ in range(num_stages)]) + + self.loss_helper = loss_func + + @classmethod + def from_config(cls, cfg, input_shape): + bbox_roi_extractor = cfg['bbox_roi_extractor'] + mask_roi_extractor = cfg['mask_roi_extractor'] + assert isinstance(bbox_roi_extractor, dict) + assert isinstance(mask_roi_extractor, dict) + + kwargs = RoIAlign.from_config(cfg, input_shape) + bbox_roi_extractor.update(kwargs) + mask_roi_extractor.update(kwargs) + + return { + 'bbox_roi_extractor': bbox_roi_extractor, + 'mask_roi_extractor': mask_roi_extractor + } + + @staticmethod + def get_roi_features(features, bboxes, roi_extractor): + rois_list = [ + bboxes[i] for i in range(len(bboxes)) if len(bboxes[i]) > 0 + ] + rois_num = paddle.to_tensor( + [len(bboxes[i]) for i in range(len(bboxes))], dtype='int32') + + pos_ids = paddle.cast(rois_num, dtype='bool') + if pos_ids.sum() != len(rois_num): + rois_num = rois_num[pos_ids] + features = [features[i][pos_ids] for i in range(len(features))] + + return roi_extractor(features, rois_list, rois_num) + + def _forward_train(self, body_feats, pro_bboxes, pro_feats, targets): + all_stage_losses = {} + for stage in 
range(self.num_stages): + bbox_head = self.bbox_heads[stage] + mask_head = self.mask_heads[stage] + + roi_feats = self.get_roi_features(body_feats, pro_bboxes, + self.bbox_roi_extractor) + class_logits, bbox_deltas, pro_feats, attn_feats = bbox_head( + roi_feats, pro_feats) + bbox_pred = self.bbox_heads[stage].refine_bboxes(pro_bboxes, + bbox_deltas) + + indices = self.loss_helper.matcher({ + 'pred_logits': class_logits.detach(), + 'pred_boxes': bbox_pred.detach() + }, targets) + avg_factor = paddle.to_tensor( + [sum(len(tgt['labels']) for tgt in targets)], dtype='float32') + if paddle.distributed.get_world_size() > 1: + paddle.distributed.all_reduce(avg_factor) + avg_factor /= paddle.distributed.get_world_size() + avg_factor = paddle.clip(avg_factor, min=1.) + + loss_classes = self.loss_helper.loss_classes(class_logits, targets, + indices, avg_factor) + if sum(len(v['labels']) for v in targets) == 0: + loss_bboxes = { + 'loss_bbox': paddle.to_tensor([0.]), + 'loss_giou': paddle.to_tensor([0.]) + } + loss_masks = {'loss_mask': paddle.to_tensor([0.])} + else: + loss_bboxes = self.loss_helper.loss_bboxes(bbox_pred, targets, + indices, avg_factor) + + pos_attn_feats = paddle.concat([ + paddle.gather( + src, src_idx, axis=0) + for src, (src_idx, _) in zip(attn_feats, indices) + ]) + pos_bbox_pred = [ + paddle.gather( + src, src_idx, axis=0) + for src, (src_idx, _) in zip(bbox_pred.detach(), indices) + ] + pos_roi_feats = self.get_roi_features(body_feats, pos_bbox_pred, + self.mask_roi_extractor) + mask_logits = mask_head(pos_roi_feats, pos_attn_feats) + loss_masks = self.loss_helper.loss_masks( + pos_bbox_pred, mask_logits, targets, indices, avg_factor) + + for loss in [loss_classes, loss_bboxes, loss_masks]: + for key in loss.keys(): + all_stage_losses[f'stage{stage}_{key}'] = loss[key] + + pro_bboxes = bbox_pred.detach() + + return all_stage_losses + + def _forward_test(self, body_feats, pro_bboxes, pro_feats): + for stage in range(self.num_stages): + roi_feats = self.get_roi_features(body_feats, pro_bboxes, + self.bbox_roi_extractor) + class_logits, bbox_deltas, pro_feats, attn_feats = self.bbox_heads[ + stage](roi_feats, pro_feats) + bbox_pred = self.bbox_heads[stage].refine_bboxes(pro_bboxes, + bbox_deltas) + + pro_bboxes = bbox_pred.detach() + + roi_feats = self.get_roi_features(body_feats, bbox_pred, + self.mask_roi_extractor) + mask_logits = self.mask_heads[stage](roi_feats, attn_feats) + + return { + 'class_logits': class_logits, + 'bbox_pred': bbox_pred, + 'mask_logits': mask_logits + } + + def forward(self, + body_features, + proposal_bboxes, + proposal_features, + targets=None): + if self.training: + return self._forward_train(body_features, proposal_bboxes, + proposal_features, targets) + else: + return self._forward_test(body_features, proposal_bboxes, + proposal_features) diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 13bcd49503d9a071a2d23250012b46b3c78f9e03..0e6b31de8a8fcacce2cfa62242f458565540d0b6 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -30,6 +30,7 @@ from . import smooth_l1_loss from . import probiou_loss from . import cot_loss from . import supcontrast +from . 
import queryinst_loss from .yolo_loss import * from .iou_aware_loss import * @@ -49,4 +50,5 @@ from .smooth_l1_loss import * from .pose3d_loss import * from .probiou_loss import * from .cot_loss import * -from .supcontrast import * \ No newline at end of file +from .supcontrast import * +from .queryinst_loss import * diff --git a/ppdet/modeling/losses/queryinst_loss.py b/ppdet/modeling/losses/queryinst_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..640b9b4102de867cab6852ee1220971c4ef3a405 --- /dev/null +++ b/ppdet/modeling/losses/queryinst_loss.py @@ -0,0 +1,175 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn.functional as F + +from ppdet.core.workspace import register +from ppdet.modeling.losses.iou_loss import GIoULoss +from .sparsercnn_loss import HungarianMatcher + +__all__ = ['QueryInstLoss'] + + +@register +class QueryInstLoss(object): + __shared__ = ['num_classes'] + + def __init__(self, + num_classes=80, + focal_loss_alpha=0.25, + focal_loss_gamma=2.0, + class_weight=2.0, + l1_weight=5.0, + giou_weight=2.0, + mask_weight=8.0): + super(QueryInstLoss, self).__init__() + + self.num_classes = num_classes + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.loss_weights = { + "loss_cls": class_weight, + "loss_bbox": l1_weight, + "loss_giou": giou_weight, + "loss_mask": mask_weight + } + self.giou_loss = GIoULoss(eps=1e-6, reduction='sum') + + self.matcher = HungarianMatcher(focal_loss_alpha, focal_loss_gamma, + class_weight, l1_weight, giou_weight) + + def loss_classes(self, class_logits, targets, indices, avg_factor): + tgt_labels = paddle.full( + class_logits.shape[:2], self.num_classes, dtype='int32') + + if sum(len(v['labels']) for v in targets) > 0: + tgt_classes = paddle.concat([ + paddle.gather( + tgt['labels'], tgt_idx, axis=0) + for tgt, (_, tgt_idx) in zip(targets, indices) + ]) + batch_idx, src_idx = self._get_src_permutation_idx(indices) + for i, (batch_i, src_i) in enumerate(zip(batch_idx, src_idx)): + tgt_labels[int(batch_i), int(src_i)] = tgt_classes[i] + + tgt_labels = tgt_labels.flatten(0, 1).unsqueeze(-1) + + tgt_labels_onehot = paddle.cast( + tgt_labels == paddle.arange(0, self.num_classes), dtype='float32') + tgt_labels_onehot.stop_gradient = True + + src_logits = class_logits.flatten(0, 1) + + loss_cls = F.sigmoid_focal_loss( + src_logits, + tgt_labels_onehot, + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction='sum') / avg_factor + losses = {'loss_cls': loss_cls * self.loss_weights['loss_cls']} + return losses + + def loss_bboxes(self, bbox_pred, targets, indices, avg_factor): + bboxes = paddle.concat([ + paddle.gather( + src, src_idx, axis=0) + for src, (src_idx, _) in zip(bbox_pred, indices) + ]) + + tgt_bboxes = paddle.concat([ + paddle.gather( + tgt['boxes'], 
tgt_idx, axis=0) + for tgt, (_, tgt_idx) in zip(targets, indices) + ]) + tgt_bboxes.stop_gradient = True + + im_shapes = paddle.concat([tgt['img_whwh_tgt'] for tgt in targets]) + bboxes_norm = bboxes / im_shapes + tgt_bboxes_norm = tgt_bboxes / im_shapes + + loss_giou = self.giou_loss(bboxes, tgt_bboxes) / avg_factor + loss_bbox = F.l1_loss( + bboxes_norm, tgt_bboxes_norm, reduction='sum') / avg_factor + losses = { + 'loss_bbox': loss_bbox * self.loss_weights['loss_bbox'], + 'loss_giou': loss_giou * self.loss_weights['loss_giou'] + } + return losses + + def loss_masks(self, pos_bbox_pred, mask_logits, targets, indices, + avg_factor): + tgt_segm = [ + paddle.gather( + tgt['gt_segm'], tgt_idx, axis=0) + for tgt, (_, tgt_idx) in zip(targets, indices) + ] + + tgt_masks = [] + for i in range(len(indices)): + gt_segm = tgt_segm[i].unsqueeze(1) + if len(gt_segm) == 0: + continue + boxes = pos_bbox_pred[i] + boxes[:, 0::2] = paddle.clip( + boxes[:, 0::2], min=0, max=gt_segm.shape[3]) + boxes[:, 1::2] = paddle.clip( + boxes[:, 1::2], min=0, max=gt_segm.shape[2]) + boxes_num = paddle.to_tensor([1] * len(boxes), dtype='int32') + gt_mask = paddle.vision.ops.roi_align( + gt_segm, + boxes, + boxes_num, + output_size=mask_logits.shape[-2:], + aligned=True) + tgt_masks.append(gt_mask) + tgt_masks = paddle.concat(tgt_masks).squeeze(1) + tgt_masks = paddle.cast(tgt_masks >= 0.5, dtype='float32') + tgt_masks.stop_gradient = True + + tgt_labels = paddle.concat([ + paddle.gather( + tgt['labels'], tgt_idx, axis=0) + for tgt, (_, tgt_idx) in zip(targets, indices) + ]) + + mask_label = F.one_hot(tgt_labels, self.num_classes).unsqueeze([2, 3]) + mask_label = paddle.expand_as(mask_label, mask_logits) + mask_label.stop_gradient = True + + src_masks = paddle.gather_nd(mask_logits, paddle.nonzero(mask_label)) + shape = mask_logits.shape + src_masks = paddle.reshape(src_masks, [shape[0], shape[2], shape[3]]) + src_masks = F.sigmoid(src_masks) + + X = src_masks.flatten(1) + Y = tgt_masks.flatten(1) + inter = paddle.sum(X * Y, 1) + union = paddle.sum(X * X, 1) + paddle.sum(Y * Y, 1) + dice = (2 * inter) / (union + 2e-5) + + loss_mask = (1 - dice).sum() / avg_factor + losses = {'loss_mask': loss_mask * self.loss_weights['loss_mask']} + return losses + + @staticmethod + def _get_src_permutation_idx(indices): + batch_idx = paddle.concat( + [paddle.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = paddle.concat([src for (src, _) in indices]) + return batch_idx, src_idx diff --git a/ppdet/modeling/losses/sparsercnn_loss.py b/ppdet/modeling/losses/sparsercnn_loss.py index 8b7db92fada6f6e3f3dd7999fda35f6e750a1f12..ac9eba6feee5d07d722e4524c4d373b51e7834c4 100644 --- a/ppdet/modeling/losses/sparsercnn_loss.py +++ b/ppdet/modeling/losses/sparsercnn_loss.py @@ -284,6 +284,11 @@ class HungarianMatcher(nn.Layer): """ bs, num_queries = outputs["pred_logits"].shape[:2] + if sum(len(v["labels"]) for v in targets) == 0: + return [(paddle.to_tensor( + [], dtype=paddle.int64), paddle.to_tensor( + [], dtype=paddle.int64)) for _ in range(bs)] + # We flatten to compute the cost matrices in a batch out_prob = F.sigmoid(outputs["pred_logits"].flatten( start_axis=0, stop_axis=1)) diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index a0d6432d1638e098cde2f46de04ebcbe97e78538..933d012de1818d2ced2d43a3039d1e63f2740f9e 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -206,31 +206,6 @@ class MaskPostProcess(object): self.export_onnx = export_onnx self.assign_on_cpu = 
assign_on_cpu - def paste_mask(self, masks, boxes, im_h, im_w): - """ - Paste the mask prediction to the original image. - """ - x0_int, y0_int = 0, 0 - x1_int, y1_int = im_w, im_h - x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1) - N = masks.shape[0] - img_y = paddle.arange(y0_int, y1_int) + 0.5 - img_x = paddle.arange(x0_int, x1_int) + 0.5 - - img_y = (img_y - y0) / (y1 - y0) * 2 - 1 - img_x = (img_x - x0) / (x1 - x0) * 2 - 1 - # img_x, img_y have shapes (N, w), (N, h) - - if self.assign_on_cpu: - paddle.set_device('cpu') - gx = img_x[:, None, :].expand( - [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]]) - gy = img_y[:, :, None].expand( - [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]]) - grid = paddle.stack([gx, gy], axis=3) - img_masks = F.grid_sample(masks, grid, align_corners=False) - return img_masks[:, 0] - def __call__(self, mask_out, bboxes, bbox_num, origin_shape): """ Decode the mask_out and paste the mask to the origin image. @@ -253,8 +228,8 @@ class MaskPostProcess(object): if self.export_onnx: h, w = origin_shape[0][0], origin_shape[0][1] - mask_onnx = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:], - h, w) + mask_onnx = paste_mask(mask_out[:, None, :, :], bboxes[:, 2:], h, w, + self.assign_on_cpu) mask_onnx = mask_onnx >= self.binary_thresh pred_result = paddle.cast(mask_onnx, 'int32') @@ -270,9 +245,9 @@ class MaskPostProcess(object): mask_out_i = mask_out[id_start:id_start + bbox_num[i], :, :] im_h = origin_shape[i, 0] im_w = origin_shape[i, 1] - bbox_num_i = bbox_num[id_start] - pred_mask = self.paste_mask(mask_out_i[:, None, :, :], - bboxes_i[:, 2:], im_h, im_w) + pred_mask = paste_mask(mask_out_i[:, None, :, :], + bboxes_i[:, 2:], im_h, im_w, + self.assign_on_cpu) pred_mask = paddle.cast(pred_mask >= self.binary_thresh, 'int32') pred_result[id_start:id_start + bbox_num[i], :im_h, : @@ -542,89 +517,110 @@ class DETRBBoxPostProcess(object): @register class SparsePostProcess(object): - __shared__ = ['num_classes'] + __shared__ = ['num_classes', 'assign_on_cpu'] - def __init__(self, num_proposals, num_classes=80): + def __init__(self, + num_proposals, + num_classes=80, + binary_thresh=0.5, + assign_on_cpu=False): super(SparsePostProcess, self).__init__() self.num_classes = num_classes self.num_proposals = num_proposals + self.binary_thresh = binary_thresh + self.assign_on_cpu = assign_on_cpu - def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh): - """ - Arguments: - box_cls (Tensor): tensor of shape (batch_size, num_proposals, K). - The tensor predicts the classification probability for each proposal. - box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4). - The tensor predicts 4-vector (x,y,w,h) box - regression values for every proposal - scale_factor_wh (Tensor): tensors of shape [batch_size, 2] the scalor of per img - img_whwh (Tensor): tensors of shape [batch_size, 4] - Returns: - bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values: - [label, confidence, xmin, ymin, xmax, ymax] - bbox_num (Tensor): tensors of shape [batch_size] the number of RoIs in each image. - """ - assert len(box_cls) == len(scale_factor_wh) == len(img_whwh) - - img_wh = img_whwh[:, :2] - - scores = F.sigmoid(box_cls) - labels = paddle.arange(0, self.num_classes). 
\ - unsqueeze(0).tile([self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1) - - classes_all = [] - scores_all = [] - boxes_all = [] - for i, (scores_per_image, - box_pred_per_image) in enumerate(zip(scores, box_pred)): - - scores_per_image, topk_indices = scores_per_image.flatten( - 0, 1).topk( - self.num_proposals, sorted=False) - labels_per_image = paddle.gather(labels, topk_indices, axis=0) - - box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile( - [1, self.num_classes, 1]).reshape([-1, 4]) - box_pred_per_image = paddle.gather( - box_pred_per_image, topk_indices, axis=0) - - classes_all.append(labels_per_image) - scores_all.append(scores_per_image) - boxes_all.append(box_pred_per_image) - - bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32") - boxes_final = [] - - for i in range(len(scale_factor_wh)): - classes = classes_all[i] - boxes = boxes_all[i] - scores = scores_all[i] - - boxes[:, 0::2] = paddle.clip( - boxes[:, 0::2], min=0, max=img_wh[i][0]) / scale_factor_wh[i][0] - boxes[:, 1::2] = paddle.clip( - boxes[:, 1::2], min=0, max=img_wh[i][1]) / scale_factor_wh[i][1] - boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), ( - boxes[:, 3] - boxes[:, 1]).numpy() - - keep = (boxes_w > 1.) & (boxes_h > 1.) - - if (keep.sum() == 0): - bboxes = paddle.zeros([1, 6]).astype("float32") + def __call__(self, scores, bboxes, scale_factor, ori_shape, masks=None): + assert len(scores) == len(bboxes) == \ + len(ori_shape) == len(scale_factor) + device = paddle.device.get_device() + batch_size = len(ori_shape) + + scores = F.sigmoid(scores) + has_mask = masks is not None + if has_mask: + masks = F.sigmoid(masks) + masks = masks.reshape([batch_size, -1, *masks.shape[1:]]) + + bbox_pred = [] + mask_pred = [] if has_mask else None + bbox_num = paddle.zeros([batch_size], dtype='int32') + for i in range(batch_size): + score = scores[i] + bbox = bboxes[i] + score, indices = score.flatten(0, 1).topk( + self.num_proposals, sorted=False) + label = indices % self.num_classes + if has_mask: + mask = masks[i] + mask = mask.flatten(0, 1)[indices] + + H, W = ori_shape[i][0], ori_shape[i][1] + bbox = bbox[paddle.cast(indices / self.num_classes, indices.dtype)] + bbox /= scale_factor[i] + bbox[:, 0::2] = paddle.clip(bbox[:, 0::2], 0, W) + bbox[:, 1::2] = paddle.clip(bbox[:, 1::2], 0, H) + + keep = ((bbox[:, 2] - bbox[:, 0]).numpy() > 1.) & \ + ((bbox[:, 3] - bbox[:, 1]).numpy() > 1.) 
+ if keep.sum() == 0: + bbox = paddle.zeros([1, 6], dtype='float32') + if has_mask: + mask = paddle.zeros([1, H, W], dtype='uint8') else: - boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32") - classes = paddle.to_tensor(classes.numpy()[keep]).astype( - "float32").unsqueeze(-1) - scores = paddle.to_tensor(scores.numpy()[keep]).astype( - "float32").unsqueeze(-1) + label = paddle.to_tensor(label.numpy()[keep]).astype( + 'float32').unsqueeze(-1) + score = paddle.to_tensor(score.numpy()[keep]).astype( + 'float32').unsqueeze(-1) + bbox = paddle.to_tensor(bbox.numpy()[keep]).astype('float32') + if has_mask: + mask = paddle.to_tensor(mask.numpy()[keep]).astype( + 'float32').unsqueeze(1) + mask = paste_mask(mask, bbox, H, W, self.assign_on_cpu) + mask = paddle.cast(mask >= self.binary_thresh, 'uint8') + bbox = paddle.concat([label, score, bbox], axis=-1) + + bbox_num[i] = bbox.shape[0] + bbox_pred.append(bbox) + if has_mask: + mask_pred.append(mask) + + bbox_pred = paddle.concat(bbox_pred) + mask_pred = paddle.concat(mask_pred) if has_mask else None - bboxes = paddle.concat([classes, scores, boxes], axis=-1) + if self.assign_on_cpu: + paddle.set_device(device) - boxes_final.append(bboxes) - bbox_num[i] = bboxes.shape[0] + if has_mask: + return bbox_pred, bbox_num, mask_pred + else: + return bbox_pred, bbox_num - bbox_pred = paddle.concat(boxes_final) - return bbox_pred, bbox_num + +def paste_mask(masks, boxes, im_h, im_w, assign_on_cpu=False): + """ + Paste the mask prediction to the original image. + """ + x0_int, y0_int = 0, 0 + x1_int, y1_int = im_w, im_h + x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1) + N = masks.shape[0] + img_y = paddle.arange(y0_int, y1_int) + 0.5 + img_x = paddle.arange(x0_int, x1_int) + 0.5 + + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + + if assign_on_cpu: + paddle.set_device('cpu') + gx = img_x[:, None, :].expand( + [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]]) + gy = img_y[:, :, None].expand( + [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]]) + grid = paddle.stack([gx, gy], axis=3) + img_masks = F.grid_sample(masks, grid, align_corners=False) + return img_masks[:, 0] def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'): diff --git a/ppdet/modeling/proposal_generator/__init__.py b/ppdet/modeling/proposal_generator/__init__.py index 9fb518f2af6747ec25f3b5f7428891cbe89b95a8..f3ad19999ee3c606e0d64c47f9e33732260b1d0b 100644 --- a/ppdet/modeling/proposal_generator/__init__.py +++ b/ppdet/modeling/proposal_generator/__init__.py @@ -1,2 +1,5 @@ from . import rpn_head +from . import embedding_rpn_head + from .rpn_head import * +from .embedding_rpn_head import * diff --git a/ppdet/modeling/proposal_generator/embedding_rpn_head.py b/ppdet/modeling/proposal_generator/embedding_rpn_head.py new file mode 100644 index 0000000000000000000000000000000000000000..29174984b1d998a1a8a1cb6d872f0f9d89d89408 --- /dev/null +++ b/ppdet/modeling/proposal_generator/embedding_rpn_head.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is referenced from: https://github.com/open-mmlab/mmdetection + +import paddle +from paddle import nn + +from ppdet.core.workspace import register + +__all__ = ['EmbeddingRPNHead'] + + +@register +class EmbeddingRPNHead(nn.Layer): + __shared__ = ['proposal_embedding_dim'] + + def __init__(self, num_proposals, proposal_embedding_dim=256): + super(EmbeddingRPNHead, self).__init__() + + self.num_proposals = num_proposals + self.proposal_embedding_dim = proposal_embedding_dim + + self._init_layers() + self._init_weights() + + def _init_layers(self): + self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4) + self.init_proposal_features = nn.Embedding(self.num_proposals, + self.proposal_embedding_dim) + + def _init_weights(self): + init_bboxes = paddle.empty_like(self.init_proposal_bboxes.weight) + init_bboxes[:, :2] = 0.5 + init_bboxes[:, 2:] = 1.0 + self.init_proposal_bboxes.weight.set_value(init_bboxes) + + @staticmethod + def bbox_cxcywh_to_xyxy(x): + cxcy, wh = paddle.split(x, 2, axis=-1) + return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1) + + def forward(self, img_whwh): + proposal_bboxes = self.init_proposal_bboxes.weight.clone() + proposal_bboxes = self.bbox_cxcywh_to_xyxy(proposal_bboxes) + proposal_bboxes = proposal_bboxes.unsqueeze(0) * img_whwh.unsqueeze(1) + + proposal_features = self.init_proposal_features.weight.clone() + proposal_features = proposal_features.unsqueeze(0).tile( + [img_whwh.shape[0], 1, 1]) + + return proposal_bboxes, proposal_features
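
---

For context on the new `EmbeddingRPNHead` above: its learned proposal embeddings are initialized to whole-image boxes in relative `(cx, cy, w, h)` coordinates and scaled by each image's `img_whwh` in the forward pass. Below is a minimal standalone sketch of that computation (illustrative only, not part of this patch; the batch of image sizes is an assumption):

```python
import paddle

num_proposals = 100
# Mirrors _init_weights above: centers at (0.5, 0.5), extents (1.0, 1.0),
# i.e. every proposal initially covers the whole image in relative coords.
init_bboxes = paddle.concat(
    [
        paddle.full([num_proposals, 2], 0.5),  # cx, cy
        paddle.full([num_proposals, 2], 1.0),  # w, h
    ],
    axis=-1)

# cxcywh -> xyxy, as in bbox_cxcywh_to_xyxy.
cxcy, wh = paddle.split(init_bboxes, 2, axis=-1)
xyxy = paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1)

# Scale by per-image [w, h, w, h]; hypothetical batch of two image sizes.
img_whwh = paddle.to_tensor([[1333., 800., 1333., 800.],
                             [1066., 640., 1066., 640.]])
proposals = xyxy.unsqueeze(0) * img_whwh.unsqueeze(1)
print(proposals.shape)  # [2, 100, 4]: one full-image box per proposal
```

Starting every query from a full-image box means all proposals see the same content at stage 0; the six `DIIHead` stages in `SparseRoIHead` then iteratively refine these boxes via `refine_bboxes`, with each stage consuming the previous stage's detached predictions.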