Unverified commit e22eb8a1, authored by ucsk, committed by GitHub

[Paper Reproduction Camp] add QueryInst (#7585)

* add QueryInst

* update sparse config

* Clean up redundancy and update style

* update config

* add trained models & resolve conflict

* share bbox_cxcywh_to_xyxy

* Revert "share bbox_cxcywh_to_xyxy"

This reverts commit ccf811211eda0a59dc0446f1b716f4e7412367f9.

* repeat bbox_cxcywh_to_xyxy

* Poly2Mask: Add delete switch
RandomSizeCrop: keep origin randint style

* update copyright
Parent b7aa8a92
# QueryInst: Instances as Queries
## Introduction
QueryInst is a multi-stage, end-to-end system that treats instances of interest as learnable queries, enabling query-based object detectors, e.g. Sparse R-CNN, to achieve strong instance segmentation performance. Instance attributes such as categories, bounding boxes, instance masks, and instance association embeddings are represented by queries in a unified manner. In QueryInst, a query is shared by both detection and segmentation via dynamic convolutions and is driven by parallelly supervised multi-stage learning.
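The per-stage flow can be summarized as follows. This is a schematic numpy sketch with placeholder heads (`dii_head`, `mask_head`, and all shapes here are illustrative stand-ins, not the PaddleDetection API); the real implementation is `SparseRoIHead` in this PR:

```python
import numpy as np

num_proposals, embed_dim, num_stages = 100, 256, 6
boxes = np.zeros((num_proposals, 4))             # learned initial proposal boxes
queries = np.zeros((num_proposals, embed_dim))   # learned query embeddings

def dii_head(roi_feats, queries):                # placeholder for DIIHead
    return queries, np.zeros((num_proposals, 4))      # refined queries, box deltas

def mask_head(roi_feats, queries):               # placeholder for DynamicMaskHead
    return np.zeros((num_proposals, 80, 28, 28))      # per-class mask logits

for stage in range(num_stages):                  # every stage is supervised in parallel
    roi_feats = np.zeros((num_proposals, embed_dim, 7, 7))     # RoIAlign on current boxes
    queries, deltas = dii_head(roi_feats, queries)             # detection branch
    boxes = boxes + deltas                                     # stage-wise box refinement
    mask_feats = np.zeros((num_proposals, embed_dim, 14, 14))  # RoIAlign on refined boxes
    masks = mask_head(mask_feats, queries)       # the same query drives the mask branch
```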
## Model Zoo
| Backbone | Lr schd | Proposals | MultiScale | RandomCrop | bbox AP | mask AP | Download | Config |
|:------------:|:-------:|:---------:|:----------:|:----------:|:-------:|:-------:|------------------------------------------------------------------------------------------------------|----------------------------------------------------------|
| ResNet50-FPN | 1x | 100 | × | × | 42.1 | 37.8 | [model](https://bj.bcebos.com/v1/paddledet/models/queryinst_r50_fpn_1x_pro100_coco.pdparams) | [config](./queryinst_r50_fpn_1x_pro100_coco.yml) |
| ResNet50-FPN | 3x | 300 | √ | √ | 47.9 | 42.1 | [model](https://bj.bcebos.com/v1/paddledet/models/queryinst_r50_fpn_ms_crop_3x_pro300_coco.pdparams) | [config](./queryinst_r50_fpn_ms_crop_3x_pro300_coco.yml) |
- COCO val-set evaluation results.
- These configurations are for 4-GPU (4-card) training. Adjust the following parameters as appropriate:
```yaml
worker_num: 4
TrainReader:
  use_shared_memory: true
find_unused_parameters: true
```
## Citations
```
@InProceedings{Fang_2021_ICCV,
author = {Fang, Yuxin and Yang, Shusheng and Wang, Xinggang and Li, Yu and Fang, Chen and Shan, Ying and Feng, Bin and Liu, Wenyu},
title = {Instances As Queries},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2021},
pages = {6910-6919}
}
```
epoch: 12

LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [8, 11]
  - !LinearWarmup
    start_factor: 0.001
    steps: 1000

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  optimizer:
    type: AdamW
  weight_decay: 0.0001
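# Illustrative sketch in plain Python (not the PaddleDetection scheduler classes;
# the exact composition of LinearWarmup and PiecewiseDecay is assumed here) of the
# effective learning rate produced by the settings above: a linear ramp from
# start_factor * base_lr over the first 1000 iterations, then a step decay by
# gamma at epoch milestones 8 and 11.
def effective_lr(it, epoch, base_lr=1e-4, warmup_steps=1000,
                 start_factor=0.001, milestones=(8, 11), gamma=0.1):
    lr = base_lr * gamma ** sum(epoch >= m for m in milestones)
    if it < warmup_steps:
        lr *= start_factor + (1.0 - start_factor) * it / warmup_steps
    return lr

print(effective_lr(0, 0))       # 1e-07 (start of warmup)
print(effective_lr(2000, 0))    # 1e-04 (after warmup)
print(effective_lr(2000, 8))    # 1e-05 (after the first decay milestone)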
num_proposals: &num_proposals 100
proposal_embedding_dim: &proposal_embedding_dim 256
bbox_resolution: &bbox_resolution 7
mask_resolution: &mask_resolution 14

architecture: QueryInst
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams

QueryInst:
  backbone: ResNet
  neck: FPN
  rpn_head: EmbeddingRPNHead
  roi_head: SparseRoIHead
  post_process: SparsePostProcess

ResNet:
  depth: 50
  norm_type: bn
  freeze_at: 0
  return_idx: [0, 1, 2, 3]
  num_stages: 4
  lr_mult_list: [0.1, 0.1, 0.1, 0.1]

FPN:
  out_channel: *proposal_embedding_dim
  extra_stage: 0

EmbeddingRPNHead:
  num_proposals: *num_proposals

SparseRoIHead:
  num_stages: 6
  bbox_roi_extractor:
    resolution: *bbox_resolution
    sampling_ratio: 2
    aligned: True
  mask_roi_extractor:
    resolution: *mask_resolution
    sampling_ratio: 2
    aligned: True
  bbox_head: DIIHead
  mask_head: DynamicMaskHead
  loss_func: QueryInstLoss

DIIHead:
  feedforward_channels: 2048
  dynamic_feature_channels: 64
  roi_resolution: *bbox_resolution
  num_attn_heads: 8
  dropout: 0.0
  num_ffn_fcs: 2
  num_cls_fcs: 1
  num_reg_fcs: 3

DynamicMaskHead:
  dynamic_feature_channels: 64
  roi_resolution: *mask_resolution
  num_convs: 4
  conv_kernel_size: 3
  conv_channels: 256
  upsample_method: 'deconv'
  upsample_scale_factor: 2

QueryInstLoss:
  focal_loss_alpha: 0.25
  focal_loss_gamma: 2.0
  class_weight: 2.0
  l1_weight: 5.0
  giou_weight: 2.0
  mask_weight: 8.0

SparsePostProcess:
  num_proposals: *num_proposals
  binary_thresh: 0.5
worker_num: 4

TrainReader:
  sample_transforms:
  - Decode: {}
  - Poly2Mask: {del_poly: True}
  - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - RandomFlip: {prob: 0.5}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  - Gt2SparseTarget: {}
  batch_size: 4
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: true

EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  - Gt2SparseTarget: {}
  batch_size: 1
  shuffle: false
  drop_last: false

TestReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  - Gt2SparseTarget: {}
  batch_size: 1
  shuffle: false
_BASE_: [
  '../datasets/coco_instance.yml',
  '../runtime.yml',
  '_base_/optimizer_1x.yml',
  '_base_/queryinst_r50_fpn.yml',
  '_base_/queryinst_reader.yml',
]

log_iter: 50
find_unused_parameters: true
weights: output/queryinst_r50_fpn_1x_pro100_coco/model_final
_BASE_: [
  './queryinst_r50_fpn_1x_pro100_coco.yml',
]
weights: output/queryinst_r50_fpn_ms_crop_3x_pro300_coco/model_final

EmbeddingRPNHead:
  num_proposals: 300

QueryInstPostProcess:
  num_proposals: 300

epoch: 36

LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [27, 33]
  - !LinearWarmup
    start_factor: 0.001
    steps: 1000

TrainReader:
  sample_transforms:
  - Decode: {}
  - Poly2Mask: {del_poly: True}
  - RandomFlip: {prob: 0.5}
  - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800], max_size: 1333 } ],
                    transforms2: [
                        RandomShortSideResize: { short_side_sizes: [400, 500, 600], max_size: 1333 },
                        RandomSizeCrop: { min_size: 384, max_size: 600, keep_empty: true },
                        RandomShortSideResize: { short_side_sizes: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800], max_size: 1333 } ]
                  }
  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  - Gt2SparseTarget: {}
  batch_size: 4
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: true
worker_num: 4
use_process: true
TrainReader:
sample_transforms:
@@ -10,12 +9,11 @@ TrainReader:
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- Gt2SparseRCNNTarget: {}
- Gt2SparseTarget: {use_padding_shape: True}
batch_size: 4
shuffle: true
drop_last: true
collate_batch: false
use_process: true
EvalReader:
sample_transforms:
@@ -25,11 +23,10 @@ EvalReader:
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- Gt2SparseRCNNTarget: {}
- Gt2SparseTarget: {use_padding_shape: True}
batch_size: 1
shuffle: false
drop_last: false
use_process: true
TestReader:
sample_transforms:
@@ -39,6 +36,6 @@ TestReader:
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- Gt2SparseRCNNTarget: {}
- Gt2SparseTarget: {use_padding_shape: True}
batch_size: 1
shuffle: false
@@ -44,7 +44,7 @@ __all__ = [
'Gt2FCOSTarget',
'Gt2TTFTarget',
'Gt2Solov2Target',
'Gt2SparseRCNNTarget',
'Gt2SparseTarget',
'PadMaskBatch',
'Gt2GFLTarget',
'Gt2CenterNetTarget',
@@ -916,27 +916,33 @@ class Gt2Solov2Target(BaseOperator):
@register_op
class Gt2SparseRCNNTarget(BaseOperator):
'''
Generate SparseRCNN targets by groud truth data
'''
def __init__(self):
super(Gt2SparseRCNNTarget, self).__init__()
class Gt2SparseTarget(BaseOperator):
def __init__(self, use_padding_shape=False):
super(Gt2SparseTarget, self).__init__()
self.use_padding_shape = use_padding_shape
def __call__(self, samples, context=None):
for sample in samples:
im = sample["image"]
h, w = im.shape[1:3]
img_whwh = np.array([w, h, w, h], dtype=np.int32)
sample["img_whwh"] = img_whwh
if "scale_factor" in sample:
sample["scale_factor_wh"] = np.array(
[sample["scale_factor"][1], sample["scale_factor"][0]],
dtype=np.float32)
ori_h, ori_w = sample['h'], sample['w']
if self.use_padding_shape:
h, w = sample["image"].shape[1:3]
if "scale_factor" in sample:
sf_w, sf_h = sample["scale_factor"][1], sample[
"scale_factor"][0]
sample["scale_factor_whwh"] = np.array(
[sf_w, sf_h, sf_w, sf_h], dtype=np.float32)
else:
sample["scale_factor_whwh"] = np.array(
[1.0, 1.0, 1.0, 1.0], dtype=np.float32)
else:
sample["scale_factor_wh"] = np.array(
[1.0, 1.0], dtype=np.float32)
h, w = round(sample['im_shape'][0]), round(sample['im_shape'][
1])
sample["scale_factor_whwh"] = np.array(
[w / ori_w, h / ori_h, w / ori_w, h / ori_h],
dtype=np.float32)
sample["img_whwh"] = np.array([w, h, w, h], dtype=np.float32)
sample["ori_shape"] = np.array([ori_h, ori_w], dtype=np.int32)
return samples
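# Illustrative only: the fields Gt2SparseTarget adds to one sample with the default
# use_padding_shape=False (as configured in the QueryInst reader), assuming a
# hypothetical 640x480 image that Resize rescaled to 1024x768.
import numpy as np
ori_h, ori_w = 480, 640                  # original image size
h, w = 768, 1024                         # im_shape after Resize
img_whwh = np.array([w, h, w, h], dtype=np.float32)
scale_factor_whwh = np.array([w / ori_w, h / ori_h, w / ori_w, h / ori_h],
                             dtype=np.float32)
ori_shape = np.array([ori_h, ori_w], dtype=np.int32)
print(img_whwh, scale_factor_whwh, ori_shape)   # [1024. 768. 1024. 768.] [1.6 1.6 1.6 1.6] [480 640]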
......
@@ -2097,13 +2097,16 @@ class Pad(BaseOperator):
@register_op
class Poly2Mask(BaseOperator):
"""
gt poly to mask annotations
gt poly to mask annotations.
Args:
del_poly (bool): Whether to delete poly after generating mask. Default: False.
"""
def __init__(self):
def __init__(self, del_poly=False):
super(Poly2Mask, self).__init__()
import pycocotools.mask as maskUtils
self.maskutils = maskUtils
self.del_poly = del_poly
def _poly2mask(self, mask_ann, img_h, img_w):
if isinstance(mask_ann, list):
@@ -2122,13 +2125,15 @@ class Poly2Mask(BaseOperator):
def apply(self, sample, context=None):
assert 'gt_poly' in sample
im_h = sample['h']
im_w = sample['w']
im_h, im_w = sample['im_shape']
masks = [
self._poly2mask(gt_poly, im_h, im_w)
for gt_poly in sample['gt_poly']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
if self.del_poly:
del (sample['gt_poly'])
return sample
@@ -2677,12 +2682,21 @@ class RandomShortSideResize(BaseOperator):
class RandomSizeCrop(BaseOperator):
"""
Randomly crop the image according to `min_size` and `max_size`.
Args:
min_size (int): Minimum size for the edges of the cropped image.
max_size (int): Maximum size for the edges of the cropped image. If it
is larger than the length of the input image, the original
length is kept.
keep_empty (bool): Whether to keep a cropped result that contains no
objects. If set to False, such a crop is discarded and the
original input is returned instead.
"""
def __init__(self, min_size, max_size):
def __init__(self, min_size, max_size, keep_empty=True):
super(RandomSizeCrop, self).__init__()
self.min_size = min_size
self.max_size = max_size
self.keep_empty = keep_empty
from paddle.vision.transforms.functional import crop as paddle_crop
self.paddle_crop = paddle_crop
@@ -2712,17 +2726,20 @@ class RandomSizeCrop(BaseOperator):
return i, j, th, tw
def crop(self, sample, region):
image_shape = sample['image'].shape[:2]
sample['image'] = self.paddle_crop(sample['image'], *region)
keep_index = None
# apply bbox
# apply bbox and check whether the cropped result is valid
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], region)
bbox = sample['gt_bbox'].reshape([-1, 2, 2])
croped_bbox = self.apply_bbox(sample['gt_bbox'], region)
bbox = croped_bbox.reshape([-1, 2, 2])
area = (bbox[:, 1, :] - bbox[:, 0, :]).prod(axis=1)
keep_index = np.where(area > 0)[0]
sample['gt_bbox'] = sample['gt_bbox'][keep_index] if len(
if not self.keep_empty and len(keep_index) == 0:
# When keep_empty is False, a crop that leaves no objects is not applied
# and the original sample is returned.
return sample
sample['gt_bbox'] = croped_bbox[keep_index] if len(
keep_index) > 0 else np.zeros(
[0, 4], dtype=np.float32)
sample['gt_class'] = sample['gt_class'][keep_index] if len(
@@ -2737,17 +2754,24 @@ class RandomSizeCrop(BaseOperator):
keep_index) > 0 else np.zeros(
[0, 1], dtype=np.float32)
image_shape = sample['image'].shape[:2]
sample['image'] = self.paddle_crop(sample['image'], *region)
sample['im_shape'] = np.array(
sample['image'].shape[:2], dtype=np.float32)
# apply polygon
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], region,
image_shape)
if keep_index is not None:
sample['gt_poly'] = np.array(sample['gt_poly'])
if keep_index is not None and len(keep_index) > 0:
sample['gt_poly'] = sample['gt_poly'][keep_index]
sample['gt_poly'] = sample['gt_poly'].tolist()
# apply gt_segm
if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
i, j, h, w = region
sample['gt_segm'] = sample['gt_segm'][:, i:i + h, j:j + w]
if keep_index is not None:
if keep_index is not None and len(keep_index) > 0:
sample['gt_segm'] = sample['gt_segm'][keep_index]
return sample
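# Illustrative only: how the area filter above decides which boxes survive a crop.
# A box whose cropped width or height collapses to zero is dropped; with
# keep_empty=False a crop that would drop every box is skipped and the original
# sample is returned unchanged.
import numpy as np
cropped_boxes = np.array([[10., 20., 50., 60.],    # keeps a positive area
                          [30., 30., 30., 80.]])   # zero width after cropping
bbox = cropped_boxes.reshape([-1, 2, 2])
area = (bbox[:, 1, :] - bbox[:, 0, :]).prod(axis=1)
keep_index = np.where(area > 0)[0]
print(keep_index)   # [0] -> only the first box is kept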
......
@@ -40,6 +40,7 @@ from . import yolox
from . import yolof
from . import pose3d_metro
from . import centertrack
from . import queryinst
from .meta_arch import *
from .faster_rcnn import *
@@ -70,3 +71,4 @@ from .yolox import *
from .yolof import *
from .pose3d_metro import *
from .centertrack import *
from .queryinst import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ['QueryInst']
@register
class QueryInst(BaseArch):
__category__ = 'architecture'
__inject__ = ['post_process']
def __init__(self,
backbone,
neck,
rpn_head,
roi_head,
post_process='SparsePostProcess'):
super(QueryInst, self).__init__()
self.backbone = backbone
self.neck = neck
self.rpn_head = rpn_head
self.roi_head = roi_head
self.post_process = post_process
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
kwargs = {'input_shape': neck.out_shape}
rpn_head = create(cfg['rpn_head'], **kwargs)
roi_head = create(cfg['roi_head'], **kwargs)
return {
'backbone': backbone,
'neck': neck,
'rpn_head': rpn_head,
"roi_head": roi_head
}
def _forward(self, targets=None):
features = self.backbone(self.inputs)
features = self.neck(features)
proposal_bboxes, proposal_features = self.rpn_head(self.inputs[
'img_whwh'])
outputs = self.roi_head(features, proposal_bboxes, proposal_features,
targets)
if self.training:
return outputs
else:
bbox_pred, bbox_num, mask_pred = self.post_process(
outputs['class_logits'], outputs['bbox_pred'],
self.inputs['scale_factor_whwh'], self.inputs['ori_shape'],
outputs['mask_logits'])
return bbox_pred, bbox_num, mask_pred
def get_loss(self):
targets = []
for i in range(len(self.inputs['img_whwh'])):
boxes = self.inputs['gt_bbox'][i]
labels = self.inputs['gt_class'][i].squeeze(-1)
img_whwh = self.inputs['img_whwh'][i]
if boxes.shape[0] != 0:
img_whwh_tgt = img_whwh.unsqueeze(0).tile([boxes.shape[0], 1])
else:
img_whwh_tgt = paddle.zeros_like(boxes)
gt_segm = self.inputs['gt_segm'][i].astype('float32')
targets.append({
'boxes': boxes,
'labels': labels,
'img_whwh': img_whwh,
'img_whwh_tgt': img_whwh_tgt,
'gt_segm': gt_segm
})
losses = self._forward(targets)
losses.update({'loss': sum(losses.values())})
return losses
def get_pred(self):
bbox_pred, bbox_num, mask_pred = self._forward()
return {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred}
@@ -60,10 +60,10 @@ class SparseRCNN(BaseArch):
head_outs = self.head(fpn_feats, self.inputs["img_whwh"])
if not self.training:
bboxes = self.postprocess(
bbox_pred, bbox_num = self.postprocess(
head_outs["pred_logits"], head_outs["pred_boxes"],
self.inputs["scale_factor_wh"], self.inputs["img_whwh"])
return bboxes
self.inputs["scale_factor_whwh"], self.inputs["ori_shape"])
return bbox_pred, bbox_num
else:
return head_outs
......
@@ -143,8 +143,8 @@ def delta2bbox_v2(deltas,
dw = paddle.clip(dw, max=clip_scale)
dh = paddle.clip(dh, max=clip_scale)
else:
dw = dw.clip(min=-ctr_clip, max=ctr_clip)
dh = dh.clip(min=-ctr_clip, max=ctr_clip)
dw = dw.clip(min=-clip_scale, max=clip_scale)
dh = dh.clip(min=-clip_scale, max=clip_scale)
pred_ctr_x = dx + ctr_x.unsqueeze(1)
pred_ctr_y = dy + ctr_y.unsqueeze(1)
......
@@ -39,6 +39,7 @@ from . import ld_gfl_head
from . import yolof_head
from . import ppyoloe_contrast_head
from . import centertrack_head
from . import sparse_roi_head
from .bbox_head import *
from .mask_head import *
@@ -67,3 +68,4 @@ from .ppyoloe_r_head import *
from .yolof_head import *
from .ppyoloe_contrast_head import *
from .centertrack_head import *
from .sparse_roi_head import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is referenced from: https://github.com/open-mmlab/mmdetection
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import paddle
from paddle import nn
from ppdet.core.workspace import register
from ppdet.modeling import initializer as init
from .roi_extractor import RoIAlign
from ..bbox_utils import delta2bbox_v2
from ..cls_utils import _get_class_default_kwargs
from ..layers import MultiHeadAttention
__all__ = ['SparseRoIHead', 'DIIHead', 'DynamicMaskHead']
class DynamicConv(nn.Layer):
def __init__(self,
in_channels=256,
feature_channels=64,
out_channels=None,
roi_resolution=7,
with_proj=True):
super(DynamicConv, self).__init__()
self.in_channels = in_channels
self.feature_channels = feature_channels
self.out_channels = out_channels if out_channels else in_channels
self.num_params_in = self.in_channels * self.feature_channels
self.num_params_out = self.out_channels * self.feature_channels
self.dynamic_layer = nn.Linear(self.in_channels,
self.num_params_in + self.num_params_out)
self.norm_in = nn.LayerNorm(self.feature_channels)
self.norm_out = nn.LayerNorm(self.out_channels)
self.activation = nn.ReLU()
self.with_proj = with_proj
if self.with_proj:
num_output = self.out_channels * roi_resolution**2
self.fc_layer = nn.Linear(num_output, self.out_channels)
self.fc_norm = nn.LayerNorm(self.out_channels)
def forward(self, param_feature, input_feature):
input_feature = input_feature.flatten(2).transpose([2, 0, 1])
input_feature = input_feature.transpose([1, 0, 2])
parameters = self.dynamic_layer(param_feature)
param_in = parameters[:, :self.num_params_in].reshape(
[-1, self.in_channels, self.feature_channels])
param_out = parameters[:, -self.num_params_out:].reshape(
[-1, self.feature_channels, self.out_channels])
features = paddle.bmm(input_feature, param_in)
features = self.norm_in(features)
features = self.activation(features)
features = paddle.bmm(features, param_out)
features = self.norm_out(features)
features = self.activation(features)
if self.with_proj:
features = features.flatten(1)
features = self.fc_layer(features)
features = self.fc_norm(features)
features = self.activation(features)
return features
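# Illustrative shape walkthrough of the dynamic convolution above (numpy stand-in,
# not part of the module): each ROI's query generates two parameter matrices via
# dynamic_layer, and they are applied to the flattened ROI feature map as two
# batched matmuls (1x1 "dynamic" convolutions).
import numpy as np
num_rois, in_ch, feat_ch, res = 3, 256, 64, 7
roi_feature = np.random.randn(num_rois, in_ch, res, res)    # pooled ROI features
params = np.random.randn(num_rois, 2 * in_ch * feat_ch)     # stand-in for dynamic_layer(query)
param_in = params[:, :in_ch * feat_ch].reshape(num_rois, in_ch, feat_ch)
param_out = params[:, in_ch * feat_ch:].reshape(num_rois, feat_ch, in_ch)
x = roi_feature.reshape(num_rois, in_ch, res * res).transpose(0, 2, 1)  # [R, 49, 256]
x = x @ param_in      # [R, 49, 64]
x = x @ param_out     # [R, 49, 256]
print(x.shape)        # (3, 49, 256)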
class FFN(nn.Layer):
def __init__(self,
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.0,
add_identity=True):
super(FFN, self).__init__()
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
nn.Sequential(
nn.Linear(in_channels, feedforward_channels),
nn.ReLU(), nn.Dropout(ffn_drop)))
in_channels = feedforward_channels
layers.append(nn.Linear(feedforward_channels, embed_dims))
layers.append(nn.Dropout(ffn_drop))
self.layers = nn.Sequential(*layers)
self.add_identity = add_identity
def forward(self, x):
identity = x
out = self.layers(x)
if not self.add_identity:
return out
else:
return out + identity
@register
class DynamicMaskHead(nn.Layer):
__shared__ = ['num_classes', 'proposal_embedding_dim', 'norm_type']
def __init__(self,
num_classes=80,
proposal_embedding_dim=256,
dynamic_feature_channels=64,
roi_resolution=14,
num_convs=4,
conv_kernel_size=3,
conv_channels=256,
upsample_method='deconv',
upsample_scale_factor=2,
norm_type='bn'):
super(DynamicMaskHead, self).__init__()
self.d_model = proposal_embedding_dim
self.instance_interactive_conv = DynamicConv(
self.d_model,
dynamic_feature_channels,
roi_resolution=roi_resolution,
with_proj=False)
self.convs = nn.LayerList()
for i in range(num_convs):
self.convs.append(
nn.Sequential(
nn.Conv2D(
self.d_model if i == 0 else conv_channels,
conv_channels,
conv_kernel_size,
padding='same',
bias_attr=False),
nn.BatchNorm2D(conv_channels),
nn.ReLU()))
if norm_type == 'sync_bn':
self.convs = nn.SyncBatchNorm.convert_sync_batchnorm(self.convs)
self.upsample_method = upsample_method
if upsample_method is None:
self.upsample = None
elif upsample_method == 'deconv':
self.upsample = nn.Conv2DTranspose(
conv_channels if num_convs > 0 else self.d_model,
conv_channels,
upsample_scale_factor,
stride=upsample_scale_factor)
self.relu = nn.ReLU()
else:
self.upsample = nn.Upsample(None, upsample_scale_factor)
cls_in_channels = conv_channels if num_convs > 0 else self.d_model
cls_in_channels = conv_channels if upsample_method == 'deconv' else cls_in_channels
self.conv_cls = nn.Conv2D(cls_in_channels, num_classes, 1)
self._init_weights()
def _init_weights(self):
for p in self.parameters():
if p.dim() > 1:
init.xavier_uniform_(p)
init.constant_(self.conv_cls.bias, 0.)
def forward(self, roi_features, attn_features):
attn_features = attn_features.reshape([-1, self.d_model])
attn_features_iic = self.instance_interactive_conv(attn_features,
roi_features)
x = attn_features_iic.transpose([0, 2, 1]).reshape(roi_features.shape)
for conv in self.convs:
x = conv(x)
if self.upsample is not None:
x = self.upsample(x)
if self.upsample_method == 'deconv':
x = self.relu(x)
mask_pred = self.conv_cls(x)
return mask_pred
@register
class DIIHead(nn.Layer):
__shared__ = ['num_classes', 'proposal_embedding_dim']
def __init__(self,
num_classes=80,
proposal_embedding_dim=256,
feedforward_channels=2048,
dynamic_feature_channels=64,
roi_resolution=7,
num_attn_heads=8,
dropout=0.0,
num_ffn_fcs=2,
num_cls_fcs=1,
num_reg_fcs=3):
super(DIIHead, self).__init__()
self.num_classes = num_classes
self.d_model = proposal_embedding_dim
self.attention = MultiHeadAttention(self.d_model, num_attn_heads,
dropout)
self.attention_norm = nn.LayerNorm(self.d_model)
self.instance_interactive_conv = DynamicConv(
self.d_model,
dynamic_feature_channels,
roi_resolution=roi_resolution,
with_proj=True)
self.instance_interactive_conv_dropout = nn.Dropout(dropout)
self.instance_interactive_conv_norm = nn.LayerNorm(self.d_model)
self.ffn = FFN(self.d_model, feedforward_channels, num_ffn_fcs, dropout)
self.ffn_norm = nn.LayerNorm(self.d_model)
self.cls_fcs = nn.LayerList()
for _ in range(num_cls_fcs):
self.cls_fcs.append(
nn.Linear(
self.d_model, self.d_model, bias_attr=False))
self.cls_fcs.append(nn.LayerNorm(self.d_model))
self.cls_fcs.append(nn.ReLU())
self.fc_cls = nn.Linear(self.d_model, self.num_classes)
self.reg_fcs = nn.LayerList()
for _ in range(num_reg_fcs):
self.reg_fcs.append(
nn.Linear(
self.d_model, self.d_model, bias_attr=False))
self.reg_fcs.append(nn.LayerNorm(self.d_model))
self.reg_fcs.append(nn.ReLU())
self.fc_reg = nn.Linear(self.d_model, 4)
self._init_weights()
def _init_weights(self):
for p in self.parameters():
if p.dim() > 1:
init.xavier_uniform_(p)
bias_init = init.bias_init_with_prob(0.01)
init.constant_(self.fc_cls.bias, bias_init)
def forward(self, roi_features, proposal_features):
N, num_proposals = proposal_features.shape[:2]
proposal_features = proposal_features + self.attention(
proposal_features)
attn_features = self.attention_norm(proposal_features)
proposal_features = attn_features.reshape([-1, self.d_model])
proposal_features_iic = self.instance_interactive_conv(
proposal_features, roi_features)
proposal_features = proposal_features + self.instance_interactive_conv_dropout(
proposal_features_iic)
obj_features = self.instance_interactive_conv_norm(proposal_features)
obj_features = self.ffn(obj_features)
obj_features = self.ffn_norm(obj_features)
cls_feature = obj_features.clone()
reg_feature = obj_features.clone()
for cls_layer in self.cls_fcs:
cls_feature = cls_layer(cls_feature)
class_logits = self.fc_cls(cls_feature)
for reg_layer in self.reg_fcs:
reg_feature = reg_layer(reg_feature)
bbox_deltas = self.fc_reg(reg_feature)
class_logits = class_logits.reshape(
[N, num_proposals, self.num_classes])
bbox_deltas = bbox_deltas.reshape([N, num_proposals, 4])
obj_features = obj_features.reshape([N, num_proposals, self.d_model])
return class_logits, bbox_deltas, obj_features, attn_features
@staticmethod
def refine_bboxes(proposal_bboxes, bbox_deltas):
pred_bboxes = delta2bbox_v2(
bbox_deltas.reshape([-1, 4]),
proposal_bboxes.reshape([-1, 4]),
delta_mean=[0.0, 0.0, 0.0, 0.0],
delta_std=[0.5, 0.5, 1.0, 1.0],
ctr_clip=None)
return pred_bboxes.reshape(proposal_bboxes.shape)
@register
class SparseRoIHead(nn.Layer):
__inject__ = ['bbox_head', 'mask_head', 'loss_func']
def __init__(self,
num_stages=6,
bbox_roi_extractor=_get_class_default_kwargs(RoIAlign),
mask_roi_extractor=_get_class_default_kwargs(RoIAlign),
bbox_head='DIIHead',
mask_head='DynamicMaskHead',
loss_func='QueryInstLoss'):
super(SparseRoIHead, self).__init__()
self.num_stages = num_stages
self.bbox_roi_extractor = bbox_roi_extractor
self.mask_roi_extractor = mask_roi_extractor
if isinstance(bbox_roi_extractor, dict):
self.bbox_roi_extractor = RoIAlign(**bbox_roi_extractor)
if isinstance(mask_roi_extractor, dict):
self.mask_roi_extractor = RoIAlign(**mask_roi_extractor)
self.bbox_heads = nn.LayerList(
[copy.deepcopy(bbox_head) for _ in range(num_stages)])
self.mask_heads = nn.LayerList(
[copy.deepcopy(mask_head) for _ in range(num_stages)])
self.loss_helper = loss_func
@classmethod
def from_config(cls, cfg, input_shape):
bbox_roi_extractor = cfg['bbox_roi_extractor']
mask_roi_extractor = cfg['mask_roi_extractor']
assert isinstance(bbox_roi_extractor, dict)
assert isinstance(mask_roi_extractor, dict)
kwargs = RoIAlign.from_config(cfg, input_shape)
bbox_roi_extractor.update(kwargs)
mask_roi_extractor.update(kwargs)
return {
'bbox_roi_extractor': bbox_roi_extractor,
'mask_roi_extractor': mask_roi_extractor
}
@staticmethod
def get_roi_features(features, bboxes, roi_extractor):
rois_list = [
bboxes[i] for i in range(len(bboxes)) if len(bboxes[i]) > 0
]
rois_num = paddle.to_tensor(
[len(bboxes[i]) for i in range(len(bboxes))], dtype='int32')
pos_ids = paddle.cast(rois_num, dtype='bool')
if pos_ids.sum() != len(rois_num):
rois_num = rois_num[pos_ids]
features = [features[i][pos_ids] for i in range(len(features))]
return roi_extractor(features, rois_list, rois_num)
def _forward_train(self, body_feats, pro_bboxes, pro_feats, targets):
all_stage_losses = {}
for stage in range(self.num_stages):
bbox_head = self.bbox_heads[stage]
mask_head = self.mask_heads[stage]
roi_feats = self.get_roi_features(body_feats, pro_bboxes,
self.bbox_roi_extractor)
class_logits, bbox_deltas, pro_feats, attn_feats = bbox_head(
roi_feats, pro_feats)
bbox_pred = self.bbox_heads[stage].refine_bboxes(pro_bboxes,
bbox_deltas)
indices = self.loss_helper.matcher({
'pred_logits': class_logits.detach(),
'pred_boxes': bbox_pred.detach()
}, targets)
avg_factor = paddle.to_tensor(
[sum(len(tgt['labels']) for tgt in targets)], dtype='float32')
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(avg_factor)
avg_factor /= paddle.distributed.get_world_size()
avg_factor = paddle.clip(avg_factor, min=1.)
loss_classes = self.loss_helper.loss_classes(class_logits, targets,
indices, avg_factor)
if sum(len(v['labels']) for v in targets) == 0:
loss_bboxes = {
'loss_bbox': paddle.to_tensor([0.]),
'loss_giou': paddle.to_tensor([0.])
}
loss_masks = {'loss_mask': paddle.to_tensor([0.])}
else:
loss_bboxes = self.loss_helper.loss_bboxes(bbox_pred, targets,
indices, avg_factor)
pos_attn_feats = paddle.concat([
paddle.gather(
src, src_idx, axis=0)
for src, (src_idx, _) in zip(attn_feats, indices)
])
pos_bbox_pred = [
paddle.gather(
src, src_idx, axis=0)
for src, (src_idx, _) in zip(bbox_pred.detach(), indices)
]
pos_roi_feats = self.get_roi_features(body_feats, pos_bbox_pred,
self.mask_roi_extractor)
mask_logits = mask_head(pos_roi_feats, pos_attn_feats)
loss_masks = self.loss_helper.loss_masks(
pos_bbox_pred, mask_logits, targets, indices, avg_factor)
for loss in [loss_classes, loss_bboxes, loss_masks]:
for key in loss.keys():
all_stage_losses[f'stage{stage}_{key}'] = loss[key]
pro_bboxes = bbox_pred.detach()
return all_stage_losses
def _forward_test(self, body_feats, pro_bboxes, pro_feats):
for stage in range(self.num_stages):
roi_feats = self.get_roi_features(body_feats, pro_bboxes,
self.bbox_roi_extractor)
class_logits, bbox_deltas, pro_feats, attn_feats = self.bbox_heads[
stage](roi_feats, pro_feats)
bbox_pred = self.bbox_heads[stage].refine_bboxes(pro_bboxes,
bbox_deltas)
pro_bboxes = bbox_pred.detach()
roi_feats = self.get_roi_features(body_feats, bbox_pred,
self.mask_roi_extractor)
mask_logits = self.mask_heads[stage](roi_feats, attn_feats)
return {
'class_logits': class_logits,
'bbox_pred': bbox_pred,
'mask_logits': mask_logits
}
def forward(self,
body_features,
proposal_bboxes,
proposal_features,
targets=None):
if self.training:
return self._forward_train(body_features, proposal_bboxes,
proposal_features, targets)
else:
return self._forward_test(body_features, proposal_bboxes,
proposal_features)
@@ -30,6 +30,7 @@ from . import smooth_l1_loss
from . import probiou_loss
from . import cot_loss
from . import supcontrast
from . import queryinst_loss
from .yolo_loss import *
from .iou_aware_loss import *
@@ -49,4 +50,5 @@ from .smooth_l1_loss import *
from .pose3d_loss import *
from .probiou_loss import *
from .cot_loss import *
from .supcontrast import *
\ No newline at end of file
from .supcontrast import *
from .queryinst_loss import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.losses.iou_loss import GIoULoss
from .sparsercnn_loss import HungarianMatcher
__all__ = ['QueryInstLoss']
@register
class QueryInstLoss(object):
__shared__ = ['num_classes']
def __init__(self,
num_classes=80,
focal_loss_alpha=0.25,
focal_loss_gamma=2.0,
class_weight=2.0,
l1_weight=5.0,
giou_weight=2.0,
mask_weight=8.0):
super(QueryInstLoss, self).__init__()
self.num_classes = num_classes
self.focal_loss_alpha = focal_loss_alpha
self.focal_loss_gamma = focal_loss_gamma
self.loss_weights = {
"loss_cls": class_weight,
"loss_bbox": l1_weight,
"loss_giou": giou_weight,
"loss_mask": mask_weight
}
self.giou_loss = GIoULoss(eps=1e-6, reduction='sum')
self.matcher = HungarianMatcher(focal_loss_alpha, focal_loss_gamma,
class_weight, l1_weight, giou_weight)
def loss_classes(self, class_logits, targets, indices, avg_factor):
tgt_labels = paddle.full(
class_logits.shape[:2], self.num_classes, dtype='int32')
if sum(len(v['labels']) for v in targets) > 0:
tgt_classes = paddle.concat([
paddle.gather(
tgt['labels'], tgt_idx, axis=0)
for tgt, (_, tgt_idx) in zip(targets, indices)
])
batch_idx, src_idx = self._get_src_permutation_idx(indices)
for i, (batch_i, src_i) in enumerate(zip(batch_idx, src_idx)):
tgt_labels[int(batch_i), int(src_i)] = tgt_classes[i]
tgt_labels = tgt_labels.flatten(0, 1).unsqueeze(-1)
tgt_labels_onehot = paddle.cast(
tgt_labels == paddle.arange(0, self.num_classes), dtype='float32')
tgt_labels_onehot.stop_gradient = True
src_logits = class_logits.flatten(0, 1)
loss_cls = F.sigmoid_focal_loss(
src_logits,
tgt_labels_onehot,
alpha=self.focal_loss_alpha,
gamma=self.focal_loss_gamma,
reduction='sum') / avg_factor
losses = {'loss_cls': loss_cls * self.loss_weights['loss_cls']}
return losses
def loss_bboxes(self, bbox_pred, targets, indices, avg_factor):
bboxes = paddle.concat([
paddle.gather(
src, src_idx, axis=0)
for src, (src_idx, _) in zip(bbox_pred, indices)
])
tgt_bboxes = paddle.concat([
paddle.gather(
tgt['boxes'], tgt_idx, axis=0)
for tgt, (_, tgt_idx) in zip(targets, indices)
])
tgt_bboxes.stop_gradient = True
im_shapes = paddle.concat([tgt['img_whwh_tgt'] for tgt in targets])
bboxes_norm = bboxes / im_shapes
tgt_bboxes_norm = tgt_bboxes / im_shapes
loss_giou = self.giou_loss(bboxes, tgt_bboxes) / avg_factor
loss_bbox = F.l1_loss(
bboxes_norm, tgt_bboxes_norm, reduction='sum') / avg_factor
losses = {
'loss_bbox': loss_bbox * self.loss_weights['loss_bbox'],
'loss_giou': loss_giou * self.loss_weights['loss_giou']
}
return losses
def loss_masks(self, pos_bbox_pred, mask_logits, targets, indices,
avg_factor):
tgt_segm = [
paddle.gather(
tgt['gt_segm'], tgt_idx, axis=0)
for tgt, (_, tgt_idx) in zip(targets, indices)
]
tgt_masks = []
for i in range(len(indices)):
gt_segm = tgt_segm[i].unsqueeze(1)
if len(gt_segm) == 0:
continue
boxes = pos_bbox_pred[i]
boxes[:, 0::2] = paddle.clip(
boxes[:, 0::2], min=0, max=gt_segm.shape[3])
boxes[:, 1::2] = paddle.clip(
boxes[:, 1::2], min=0, max=gt_segm.shape[2])
boxes_num = paddle.to_tensor([1] * len(boxes), dtype='int32')
gt_mask = paddle.vision.ops.roi_align(
gt_segm,
boxes,
boxes_num,
output_size=mask_logits.shape[-2:],
aligned=True)
tgt_masks.append(gt_mask)
tgt_masks = paddle.concat(tgt_masks).squeeze(1)
tgt_masks = paddle.cast(tgt_masks >= 0.5, dtype='float32')
tgt_masks.stop_gradient = True
tgt_labels = paddle.concat([
paddle.gather(
tgt['labels'], tgt_idx, axis=0)
for tgt, (_, tgt_idx) in zip(targets, indices)
])
mask_label = F.one_hot(tgt_labels, self.num_classes).unsqueeze([2, 3])
mask_label = paddle.expand_as(mask_label, mask_logits)
mask_label.stop_gradient = True
src_masks = paddle.gather_nd(mask_logits, paddle.nonzero(mask_label))
shape = mask_logits.shape
src_masks = paddle.reshape(src_masks, [shape[0], shape[2], shape[3]])
src_masks = F.sigmoid(src_masks)
X = src_masks.flatten(1)
Y = tgt_masks.flatten(1)
inter = paddle.sum(X * Y, 1)
union = paddle.sum(X * X, 1) + paddle.sum(Y * Y, 1)
dice = (2 * inter) / (union + 2e-5)
loss_mask = (1 - dice).sum() / avg_factor
losses = {'loss_mask': loss_mask * self.loss_weights['loss_mask']}
return losses
@staticmethod
def _get_src_permutation_idx(indices):
batch_idx = paddle.concat(
[paddle.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = paddle.concat([src for (src, _) in indices])
return batch_idx, src_idx
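# Quick numeric check (illustrative only) of the dice term used in loss_masks above:
# a prediction close to the binary target yields a dice score near 1 and therefore
# a near-zero mask loss.
import numpy as np
X = np.array([[0.9, 0.1, 0.8, 0.2]])   # flattened predicted mask probabilities
Y = np.array([[1.0, 0.0, 1.0, 0.0]])   # flattened binary target mask
inter = (X * Y).sum(1)
union = (X * X).sum(1) + (Y * Y).sum(1)
dice = 2 * inter / (union + 2e-5)
print(1 - dice)   # ~[0.029]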
@@ -284,6 +284,11 @@ class HungarianMatcher(nn.Layer):
"""
bs, num_queries = outputs["pred_logits"].shape[:2]
if sum(len(v["labels"]) for v in targets) == 0:
return [(paddle.to_tensor(
[], dtype=paddle.int64), paddle.to_tensor(
[], dtype=paddle.int64)) for _ in range(bs)]
# We flatten to compute the cost matrices in a batch
out_prob = F.sigmoid(outputs["pred_logits"].flatten(
start_axis=0, stop_axis=1))
......
@@ -206,31 +206,6 @@ class MaskPostProcess(object):
self.export_onnx = export_onnx
self.assign_on_cpu = assign_on_cpu
def paste_mask(self, masks, boxes, im_h, im_w):
"""
Paste the mask prediction to the original image.
"""
x0_int, y0_int = 0, 0
x1_int, y1_int = im_w, im_h
x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
N = masks.shape[0]
img_y = paddle.arange(y0_int, y1_int) + 0.5
img_x = paddle.arange(x0_int, x1_int) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h)
if self.assign_on_cpu:
paddle.set_device('cpu')
gx = img_x[:, None, :].expand(
[N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
gy = img_y[:, :, None].expand(
[N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
grid = paddle.stack([gx, gy], axis=3)
img_masks = F.grid_sample(masks, grid, align_corners=False)
return img_masks[:, 0]
def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
"""
Decode the mask_out and paste the mask to the origin image.
@@ -253,8 +228,8 @@ class MaskPostProcess(object):
if self.export_onnx:
h, w = origin_shape[0][0], origin_shape[0][1]
mask_onnx = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
h, w)
mask_onnx = paste_mask(mask_out[:, None, :, :], bboxes[:, 2:], h, w,
self.assign_on_cpu)
mask_onnx = mask_onnx >= self.binary_thresh
pred_result = paddle.cast(mask_onnx, 'int32')
@@ -270,9 +245,9 @@ class MaskPostProcess(object):
mask_out_i = mask_out[id_start:id_start + bbox_num[i], :, :]
im_h = origin_shape[i, 0]
im_w = origin_shape[i, 1]
bbox_num_i = bbox_num[id_start]
pred_mask = self.paste_mask(mask_out_i[:, None, :, :],
bboxes_i[:, 2:], im_h, im_w)
pred_mask = paste_mask(mask_out_i[:, None, :, :],
bboxes_i[:, 2:], im_h, im_w,
self.assign_on_cpu)
pred_mask = paddle.cast(pred_mask >= self.binary_thresh,
'int32')
pred_result[id_start:id_start + bbox_num[i], :im_h, :
@@ -542,89 +517,110 @@ class DETRBBoxPostProcess(object):
@register
class SparsePostProcess(object):
__shared__ = ['num_classes']
__shared__ = ['num_classes', 'assign_on_cpu']
def __init__(self, num_proposals, num_classes=80):
def __init__(self,
num_proposals,
num_classes=80,
binary_thresh=0.5,
assign_on_cpu=False):
super(SparsePostProcess, self).__init__()
self.num_classes = num_classes
self.num_proposals = num_proposals
self.binary_thresh = binary_thresh
self.assign_on_cpu = assign_on_cpu
def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
"""
Arguments:
box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
The tensor predicts the classification probability for each proposal.
box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
The tensor predicts 4-vector (x,y,w,h) box
regression values for every proposal
scale_factor_wh (Tensor): tensors of shape [batch_size, 2] the scalor of per img
img_whwh (Tensor): tensors of shape [batch_size, 4]
Returns:
bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values:
[label, confidence, xmin, ymin, xmax, ymax]
bbox_num (Tensor): tensors of shape [batch_size] the number of RoIs in each image.
"""
assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)
img_wh = img_whwh[:, :2]
scores = F.sigmoid(box_cls)
labels = paddle.arange(0, self.num_classes). \
unsqueeze(0).tile([self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1)
classes_all = []
scores_all = []
boxes_all = []
for i, (scores_per_image,
box_pred_per_image) in enumerate(zip(scores, box_pred)):
scores_per_image, topk_indices = scores_per_image.flatten(
0, 1).topk(
self.num_proposals, sorted=False)
labels_per_image = paddle.gather(labels, topk_indices, axis=0)
box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile(
[1, self.num_classes, 1]).reshape([-1, 4])
box_pred_per_image = paddle.gather(
box_pred_per_image, topk_indices, axis=0)
classes_all.append(labels_per_image)
scores_all.append(scores_per_image)
boxes_all.append(box_pred_per_image)
bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
boxes_final = []
for i in range(len(scale_factor_wh)):
classes = classes_all[i]
boxes = boxes_all[i]
scores = scores_all[i]
boxes[:, 0::2] = paddle.clip(
boxes[:, 0::2], min=0, max=img_wh[i][0]) / scale_factor_wh[i][0]
boxes[:, 1::2] = paddle.clip(
boxes[:, 1::2], min=0, max=img_wh[i][1]) / scale_factor_wh[i][1]
boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), (
boxes[:, 3] - boxes[:, 1]).numpy()
keep = (boxes_w > 1.) & (boxes_h > 1.)
if (keep.sum() == 0):
bboxes = paddle.zeros([1, 6]).astype("float32")
def __call__(self, scores, bboxes, scale_factor, ori_shape, masks=None):
assert len(scores) == len(bboxes) == \
len(ori_shape) == len(scale_factor)
device = paddle.device.get_device()
batch_size = len(ori_shape)
scores = F.sigmoid(scores)
has_mask = masks is not None
if has_mask:
masks = F.sigmoid(masks)
masks = masks.reshape([batch_size, -1, *masks.shape[1:]])
bbox_pred = []
mask_pred = [] if has_mask else None
bbox_num = paddle.zeros([batch_size], dtype='int32')
for i in range(batch_size):
score = scores[i]
bbox = bboxes[i]
score, indices = score.flatten(0, 1).topk(
self.num_proposals, sorted=False)
label = indices % self.num_classes
if has_mask:
mask = masks[i]
mask = mask.flatten(0, 1)[indices]
H, W = ori_shape[i][0], ori_shape[i][1]
bbox = bbox[paddle.cast(indices / self.num_classes, indices.dtype)]
bbox /= scale_factor[i]
bbox[:, 0::2] = paddle.clip(bbox[:, 0::2], 0, W)
bbox[:, 1::2] = paddle.clip(bbox[:, 1::2], 0, H)
keep = ((bbox[:, 2] - bbox[:, 0]).numpy() > 1.) & \
((bbox[:, 3] - bbox[:, 1]).numpy() > 1.)
if keep.sum() == 0:
bbox = paddle.zeros([1, 6], dtype='float32')
if has_mask:
mask = paddle.zeros([1, H, W], dtype='uint8')
else:
boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32")
classes = paddle.to_tensor(classes.numpy()[keep]).astype(
"float32").unsqueeze(-1)
scores = paddle.to_tensor(scores.numpy()[keep]).astype(
"float32").unsqueeze(-1)
label = paddle.to_tensor(label.numpy()[keep]).astype(
'float32').unsqueeze(-1)
score = paddle.to_tensor(score.numpy()[keep]).astype(
'float32').unsqueeze(-1)
bbox = paddle.to_tensor(bbox.numpy()[keep]).astype('float32')
if has_mask:
mask = paddle.to_tensor(mask.numpy()[keep]).astype(
'float32').unsqueeze(1)
mask = paste_mask(mask, bbox, H, W, self.assign_on_cpu)
mask = paddle.cast(mask >= self.binary_thresh, 'uint8')
bbox = paddle.concat([label, score, bbox], axis=-1)
bbox_num[i] = bbox.shape[0]
bbox_pred.append(bbox)
if has_mask:
mask_pred.append(mask)
bbox_pred = paddle.concat(bbox_pred)
mask_pred = paddle.concat(mask_pred) if has_mask else None
bboxes = paddle.concat([classes, scores, boxes], axis=-1)
if self.assign_on_cpu:
paddle.set_device(device)
boxes_final.append(bboxes)
bbox_num[i] = bboxes.shape[0]
if has_mask:
return bbox_pred, bbox_num, mask_pred
else:
return bbox_pred, bbox_num
bbox_pred = paddle.concat(boxes_final)
return bbox_pred, bbox_num
def paste_mask(masks, boxes, im_h, im_w, assign_on_cpu=False):
"""
Paste the mask prediction to the original image.
"""
x0_int, y0_int = 0, 0
x1_int, y1_int = im_w, im_h
x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
N = masks.shape[0]
img_y = paddle.arange(y0_int, y1_int) + 0.5
img_x = paddle.arange(x0_int, x1_int) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h)
if assign_on_cpu:
paddle.set_device('cpu')
gx = img_x[:, None, :].expand(
[N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
gy = img_y[:, :, None].expand(
[N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
grid = paddle.stack([gx, gy], axis=3)
img_masks = F.grid_sample(masks, grid, align_corners=False)
return img_masks[:, 0]
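# Illustrative only: the normalization above maps absolute pixel centres into the
# [-1, 1] coordinate frame that F.grid_sample expects for each box. For a box
# spanning x0=100..x1=300 in a 400-pixel-wide image:
import numpy as np
x0, x1 = 100.0, 300.0
img_x = np.arange(0, 400) + 0.5
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
print(img_x[200])              # ~0.005 -> the box centre maps near 0
print(img_x[100], img_x[300])  # ~-1 and ~1 -> the box edges map to the grid borders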
def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'):
......
from . import rpn_head
from . import embedding_rpn_head
from .rpn_head import *
from .embedding_rpn_head import *
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is referenced from: https://github.com/open-mmlab/mmdetection
import paddle
from paddle import nn
from ppdet.core.workspace import register
__all__ = ['EmbeddingRPNHead']
@register
class EmbeddingRPNHead(nn.Layer):
__shared__ = ['proposal_embedding_dim']
def __init__(self, num_proposals, proposal_embedding_dim=256):
super(EmbeddingRPNHead, self).__init__()
self.num_proposals = num_proposals
self.proposal_embedding_dim = proposal_embedding_dim
self._init_layers()
self._init_weights()
def _init_layers(self):
self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
self.init_proposal_features = nn.Embedding(self.num_proposals,
self.proposal_embedding_dim)
def _init_weights(self):
init_bboxes = paddle.empty_like(self.init_proposal_bboxes.weight)
init_bboxes[:, :2] = 0.5
init_bboxes[:, 2:] = 1.0
self.init_proposal_bboxes.weight.set_value(init_bboxes)
@staticmethod
def bbox_cxcywh_to_xyxy(x):
cxcy, wh = paddle.split(x, 2, axis=-1)
return paddle.concat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], axis=-1)
def forward(self, img_whwh):
proposal_bboxes = self.init_proposal_bboxes.weight.clone()
proposal_bboxes = self.bbox_cxcywh_to_xyxy(proposal_bboxes)
proposal_bboxes = proposal_bboxes.unsqueeze(0) * img_whwh.unsqueeze(1)
proposal_features = self.init_proposal_features.weight.clone()
proposal_features = proposal_features.unsqueeze(0).tile(
[img_whwh.shape[0], 1, 1])
return proposal_bboxes, proposal_features
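# Illustrative check (numpy, not part of the head): the proposals are initialized to
# (cx, cy, w, h) = (0.5, 0.5, 1.0, 1.0) in normalized coordinates, so after
# bbox_cxcywh_to_xyxy and scaling by img_whwh every initial proposal covers the
# whole image, e.g. for an 800x600 input:
import numpy as np
proposal = np.array([0.5, 0.5, 1.0, 1.0], dtype=np.float32)       # (cx, cy, w, h)
img_whwh = np.array([800.0, 600.0, 800.0, 600.0], dtype=np.float32)
cxcy, wh = proposal[:2], proposal[2:]
xyxy = np.concatenate([cxcy - 0.5 * wh, cxcy + 0.5 * wh])
print(xyxy * img_whwh)   # [  0.   0. 800. 600.] -> the full image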