Unverified commit 841f2f4e authored by FL77N, committed by GitHub

add sparsercnn (#3623)

* add sparsercnn

* update sparsercnn
Parent bb846096
@@ -33,7 +33,7 @@ logger = setup_logger(__name__)
__all__ = [
'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
'Gt2TTFTarget', 'Gt2Solov2Target'
'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget'
]
@@ -746,3 +746,28 @@ class Gt2Solov2Target(BaseOperator):
data['grid_order{}'.format(idx)] = gt_grid_order
return samples
@register_op
class Gt2SparseRCNNTarget(BaseOperator):
'''
Generate SparseRCNN targets from ground-truth data
'''
def __init__(self):
super(Gt2SparseRCNNTarget, self).__init__()
def __call__(self, samples, context=None):
for sample in samples:
im = sample["image"]
h, w = im.shape[1:3]
img_whwh = np.array([w, h, w, h], dtype=np.int32)
sample["img_whwh"] = img_whwh
if "scale_factor" in sample:
sample["scale_factor_wh"] = np.array([sample["scale_factor"][1], sample["scale_factor"][0]],
dtype=np.float32)
sample.pop("scale_factor")
else:
sample["scale_factor_wh"] = np.array([1.0, 1.0], dtype=np.float32)
return samples
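A minimal numpy sketch of the fields this operator adds, assuming a CHW image array and an (h, w) scale_factor as produced by the resize ops; all values below are made up:

import numpy as np

sample = {"image": np.zeros((3, 480, 640), dtype=np.float32),        # CHW
          "scale_factor": np.array([0.75, 0.5], dtype=np.float32)}   # (scale_h, scale_w)
h, w = sample["image"].shape[1:3]
print(np.array([w, h, w, h], dtype=np.int32))   # img_whwh -> [640 480 640 480]
print(np.array([sample["scale_factor"][1], sample["scale_factor"][0]],
               dtype=np.float32))               # scale_factor_wh -> [0.5 0.75]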
@@ -22,6 +22,7 @@ from . import deepsort
from . import fairmot
from . import centernet
from . import detr
from . import sparse_rcnn
from .meta_arch import *
from .faster_rcnn import *
@@ -41,3 +42,4 @@ from .fairmot import *
from .centernet import *
from .blazeface import *
from .detr import *
from .sparse_rcnn import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
__all__ = ["SparseRCNN"]
@register
class SparseRCNN(BaseArch):
__category__ = 'architecture'
__inject__ = ["postprocess"]
def __init__(self,
backbone,
neck,
head="SparsercnnHead",
postprocess="SparsePostProcess"):
super(SparseRCNN, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.postprocess = postprocess
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
kwargs = {'roi_input_shape': neck.out_shape}
head = create(cfg['head'], **kwargs)
return {
'backbone': backbone,
'neck': neck,
"head": head,
}
def _forward(self):
body_feats = self.backbone(self.inputs)
fpn_feats = self.neck(body_feats)
head_outs = self.head(fpn_feats, self.inputs["img_whwh"])
if not self.training:
bboxes = self.postprocess(
head_outs["pred_logits"], head_outs["pred_boxes"],
self.inputs["scale_factor_wh"], self.inputs["img_whwh"])
return bboxes
else:
return head_outs
def get_loss(self):
batch_gt_class = self.inputs["gt_class"]
batch_gt_box = self.inputs["gt_bbox"]
batch_whwh = self.inputs["img_whwh"]
targets = []
for i in range(len(batch_gt_class)):
boxes = batch_gt_box[i]
labels = batch_gt_class[i].squeeze(-1)
img_whwh = batch_whwh[i]
img_whwh_tgt = img_whwh.unsqueeze(0).tile([int(boxes.shape[0]), 1])
targets.append({
"boxes": boxes,
"labels": labels,
"img_whwh": img_whwh,
"img_whwh_tgt": img_whwh_tgt
})
outputs = self._forward()
loss_dict = self.head.get_loss(outputs, targets)
acc = loss_dict["acc"]
loss_dict.pop("acc")
total_loss = sum(loss_dict.values())
loss_dict.update({"loss": total_loss, "acc": acc})
return loss_dict
def get_pred(self):
bbox_pred, bbox_num = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
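A minimal sketch (dummy tensors, illustrative values) of the per-image target dict that get_loss assembles for the head:

import paddle

gt_class = paddle.to_tensor([[3], [17]], dtype="int32")     # two GT objects
gt_bbox = paddle.to_tensor([[4., 4., 18., 18.], [28., 31., 59., 62.]])
img_whwh = paddle.to_tensor([640, 480, 640, 480], dtype="int32")
target = {
    "boxes": gt_bbox,
    "labels": gt_class.squeeze(-1),
    "img_whwh": img_whwh,
    "img_whwh_tgt": img_whwh.unsqueeze(0).tile([int(gt_bbox.shape[0]), 1]),
}
print(target["labels"].shape, target["img_whwh_tgt"].shape)  # [2] [2, 4]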
@@ -26,6 +26,7 @@ from . import s2anet_head
from . import keypoint_hrhrnet_head
from . import centernet_head
from . import detr_head
from . import sparsercnn_head
from .bbox_head import *
from .mask_head import *
@@ -41,3 +42,4 @@ from .s2anet_head import *
from .keypoint_hrhrnet_head import *
from .centernet_head import *
from .detr_head import *
from .sparsercnn_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import copy
import paddle
import paddle.nn as nn
import ppdet.modeling.initializer as init
from ppdet.core.workspace import register
from ppdet.modeling.heads.roi_extractor import RoIAlign
from ppdet.modeling.bbox_utils import delta2bbox
_DEFAULT_SCALE_CLAMP = math.log(100000. / 16)
class DynamicConv(nn.Layer):
def __init__(
self,
head_hidden_dim,
head_dim_dynamic,
head_num_dynamic, ):
super().__init__()
self.hidden_dim = head_hidden_dim
self.dim_dynamic = head_dim_dynamic
self.num_dynamic = head_num_dynamic
self.num_params = self.hidden_dim * self.dim_dynamic
self.dynamic_layer = nn.Linear(self.hidden_dim,
self.num_dynamic * self.num_params)
self.norm1 = nn.LayerNorm(self.dim_dynamic)
self.norm2 = nn.LayerNorm(self.hidden_dim)
self.activation = nn.ReLU()
pooler_resolution = 7
num_output = self.hidden_dim * pooler_resolution**2
self.out_layer = nn.Linear(num_output, self.hidden_dim)
self.norm3 = nn.LayerNorm(self.hidden_dim)
def forward(self, pro_features, roi_features):
'''
pro_features: (1, N * nr_boxes, self.d_model)
roi_features: (49, N * nr_boxes, self.d_model)
'''
features = roi_features.transpose(perm=[1, 0, 2])
parameters = self.dynamic_layer(pro_features).transpose(perm=[1, 0, 2])
param1 = parameters[:, :, :self.num_params].reshape(
[-1, self.hidden_dim, self.dim_dynamic])
param2 = parameters[:, :, self.num_params:].reshape(
[-1, self.dim_dynamic, self.hidden_dim])
features = paddle.bmm(features, param1)
features = self.norm1(features)
features = self.activation(features)
features = paddle.bmm(features, param2)
features = self.norm2(features)
features = self.activation(features)
features = features.flatten(1)
features = self.out_layer(features)
features = self.norm3(features)
features = self.activation(features)
return features
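A shape-check sketch for DynamicConv with illustrative sizes, assuming this file is importable as ppdet.modeling.heads.sparsercnn_head (the module added by this commit):

import paddle
from ppdet.modeling.heads.sparsercnn_head import DynamicConv

N, nr_boxes, d_model = 2, 100, 256
conv = DynamicConv(head_hidden_dim=d_model, head_dim_dynamic=64, head_num_dynamic=2)
pro_features = paddle.randn([1, N * nr_boxes, d_model])    # proposal features
roi_features = paddle.randn([49, N * nr_boxes, d_model])   # 7x7 RoI-aligned features
print(conv(pro_features, roi_features).shape)              # [200, 256]: one refined feature per proposal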
class RCNNHead(nn.Layer):
def __init__(
self,
d_model,
num_classes,
dim_feedforward,
nhead,
dropout,
head_cls,
head_reg,
head_dim_dynamic,
head_num_dynamic,
scale_clamp: float=_DEFAULT_SCALE_CLAMP,
bbox_weights=(2.0, 2.0, 1.0, 1.0), ):
super().__init__()
self.d_model = d_model
# dynamic.
self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
self.inst_interact = DynamicConv(d_model, head_dim_dynamic,
head_num_dynamic)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = nn.ReLU()
# cls.
num_cls = head_cls
cls_module = list()
for _ in range(num_cls):
cls_module.append(nn.Linear(d_model, d_model, bias_attr=False))
cls_module.append(nn.LayerNorm(d_model))
cls_module.append(nn.ReLU())
self.cls_module = nn.LayerList(cls_module)
# reg.
num_reg = head_reg
reg_module = list()
for _ in range(num_reg):
reg_module.append(nn.Linear(d_model, d_model, bias_attr=False))
reg_module.append(nn.LayerNorm(d_model))
reg_module.append(nn.ReLU())
self.reg_module = nn.LayerList(reg_module)
# pred.
self.class_logits = nn.Linear(d_model, num_classes)
self.bboxes_delta = nn.Linear(d_model, 4)
self.scale_clamp = scale_clamp
self.bbox_weights = bbox_weights
def forward(self, features, bboxes, pro_features, pooler):
"""
:param bboxes: (N, nr_boxes, 4)
:param pro_features: (N, nr_boxes, d_model)
"""
N, nr_boxes = bboxes.shape[:2]
proposal_boxes = list()
for b in range(N):
proposal_boxes.append(bboxes[b])
roi_num = paddle.full([N], nr_boxes).astype("int32")
roi_features = pooler(features, proposal_boxes, roi_num)
roi_features = roi_features.reshape(
[N * nr_boxes, self.d_model, -1]).transpose(perm=[2, 0, 1])
# self_att.
pro_features = pro_features.reshape([N, nr_boxes, self.d_model])
pro_features2 = self.self_attn(
pro_features, pro_features, value=pro_features)
pro_features = pro_features.transpose(perm=[1, 0, 2]) + self.dropout1(
pro_features2.transpose(perm=[1, 0, 2]))
pro_features = self.norm1(pro_features)
# inst_interact.
pro_features = pro_features.reshape(
[nr_boxes, N, self.d_model]).transpose(perm=[1, 0, 2]).reshape(
[1, N * nr_boxes, self.d_model])
pro_features2 = self.inst_interact(pro_features, roi_features)
pro_features = pro_features + self.dropout2(pro_features2)
obj_features = self.norm2(pro_features)
# obj_feature.
obj_features2 = self.linear2(
self.dropout(self.activation(self.linear1(obj_features))))
obj_features = obj_features + self.dropout3(obj_features2)
obj_features = self.norm3(obj_features)
fc_feature = obj_features.transpose(perm=[1, 0, 2]).reshape(
[N * nr_boxes, -1])
cls_feature = fc_feature.clone()
reg_feature = fc_feature.clone()
for cls_layer in self.cls_module:
cls_feature = cls_layer(cls_feature)
for reg_layer in self.reg_module:
reg_feature = reg_layer(reg_feature)
class_logits = self.class_logits(cls_feature)
bboxes_deltas = self.bboxes_delta(reg_feature)
pred_bboxes = delta2bbox(bboxes_deltas,
bboxes.reshape([-1, 4]), self.bbox_weights)
return class_logits.reshape([N, nr_boxes, -1]), pred_bboxes.reshape(
[N, nr_boxes, -1]), obj_features
@register
class SparseRCNNHead(nn.Layer):
'''
SparseRCNNHead
Args:
roi_input_shape (list[ShapeSpec]): The output shape of the FPN
num_classes (int): Number of classes
head_hidden_dim (int): The embedding dimension (d_model) shared by the attention, dynamic-conv and feed-forward layers
head_dim_feedforward (int): The hidden dimension of the feed-forward layer in each RCNNHead
nhead (int): The number of heads in MultiHeadAttention
head_dropout (float): The dropout probability
head_cls (int): The number of linear blocks in the classification branch
head_reg (int): The number of linear blocks in the regression branch
head_dim_dynamic (int): The hidden dimension of DynamicConv
head_num_dynamic (int): The number of dynamic parameter sets in DynamicConv
head_num_heads (int): The number of stacked RCNNHead stages
deep_supervision (bool): whether to supervise the intermediate results
num_proposals (int): the number of learnable proposal boxes and features
'''
__inject__ = ['loss_func']
__shared__ = ['num_classes']
def __init__(
self,
head_hidden_dim,
head_dim_feedforward,
nhead,
head_dropout,
head_cls,
head_reg,
head_dim_dynamic,
head_num_dynamic,
head_num_heads,
deep_supervision,
num_proposals,
num_classes=80,
loss_func="SparseRCNNLoss",
roi_input_shape=None, ):
super().__init__()
# Build RoI.
box_pooler = self._init_box_pooler(roi_input_shape)
self.box_pooler = box_pooler
# Build heads.
rcnn_head = RCNNHead(
head_hidden_dim,
num_classes,
head_dim_feedforward,
nhead,
head_dropout,
head_cls,
head_reg,
head_dim_dynamic,
head_num_dynamic, )
self.head_series = nn.LayerList(
[copy.deepcopy(rcnn_head) for i in range(head_num_heads)])
self.return_intermediate = deep_supervision
self.num_classes = num_classes
# build init proposal
self.init_proposal_features = nn.Embedding(num_proposals,
head_hidden_dim)
self.init_proposal_boxes = nn.Embedding(num_proposals, 4)
self.lossfunc = loss_func
# Init parameters.
init.reset_initialized_parameter(self)
self._reset_parameters()
def _reset_parameters(self):
# init all parameters.
prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
for m in self.sublayers():
if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight, reverse=True)
elif not isinstance(m, nn.Embedding) and hasattr(
m, "weight") and m.weight.dim() > 1:
init.xavier_normal_(m.weight, reverse=False)
if hasattr(m, "bias") and m.bias is not None and m.bias.shape[
-1] == self.num_classes:
init.constant_(m.bias, bias_value)
init_bboxes = paddle.empty_like(self.init_proposal_boxes.weight)
init_bboxes[:, :2] = 0.5
init_bboxes[:, 2:] = 1.0
self.init_proposal_boxes.weight.set_value(init_bboxes)
@staticmethod
def _init_box_pooler(input_shape):
pooler_resolution = 7
sampling_ratio = 2
if input_shape is not None:
pooler_scales = tuple(1.0 / input_shape[k].stride
for k in range(len(input_shape)))
in_channels = [
input_shape[f].channels for f in range(len(input_shape))
]
end_level = len(input_shape) - 1
# Check all channel counts are equal
assert len(set(in_channels)) == 1, in_channels
else:
pooler_scales = [1.0 / 4.0, 1.0 / 8.0, 1.0 / 16.0, 1.0 / 32.0]
end_level = 3
box_pooler = RoIAlign(
resolution=pooler_resolution,
spatial_scale=pooler_scales,
sampling_ratio=sampling_ratio,
end_level=end_level,
aligned=True)
return box_pooler
def forward(self, features, input_whwh):
bs = len(features[0])
bboxes = box_cxcywh_to_xyxy(self.init_proposal_boxes.weight.clone(
)).unsqueeze(0)
bboxes = bboxes * input_whwh.unsqueeze(-2)
init_features = self.init_proposal_features.weight.unsqueeze(0).tile(
[1, bs, 1])
proposal_features = init_features.clone()
inter_class_logits = []
inter_pred_bboxes = []
for rcnn_head in self.head_series:
class_logits, pred_bboxes, proposal_features = rcnn_head(
features, bboxes, proposal_features, self.box_pooler)
if self.return_intermediate:
inter_class_logits.append(class_logits)
inter_pred_bboxes.append(pred_bboxes)
bboxes = pred_bboxes.detach()
output = {
'pred_logits': inter_class_logits[-1],
'pred_boxes': inter_pred_bboxes[-1]
}
if self.return_intermediate:
output['aux_outputs'] = [{
'pred_logits': a,
'pred_boxes': b
} for a, b in zip(inter_class_logits[:-1], inter_pred_bboxes[:-1])]
return output
def get_loss(self, outputs, targets):
losses = self.lossfunc(outputs, targets)
weight_dict = self.lossfunc.weight_dict
for k in losses.keys():
if k in weight_dict:
losses[k] *= weight_dict[k]
return losses
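A construction sketch for SparseRCNNHead using the hyper-parameters of the Sparse R-CNN paper (6 iterative stages, 100 learnable proposals); the values are illustrative and the import path is assumed from this commit:

from ppdet.modeling.heads.sparsercnn_head import SparseRCNNHead

head = SparseRCNNHead(
    head_hidden_dim=256, head_dim_feedforward=2048, nhead=8, head_dropout=0.0,
    head_cls=1, head_reg=3, head_dim_dynamic=64, head_num_dynamic=2,
    head_num_heads=6, deep_supervision=True, num_proposals=100)
print(len(head.head_series))                   # 6 stacked RCNNHead stages
print(head.init_proposal_boxes.weight.shape)   # [100, 4], initialised to full-image (cx, cy, w, h)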
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return paddle.stack(b, axis=-1)
\ No newline at end of file
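The learnable proposal boxes live in normalized (cx, cy, w, h) space; the head's forward scales them to absolute pixel coordinates with img_whwh, as in this small sketch (import path assumed from this commit):

import paddle
from ppdet.modeling.heads.sparsercnn_head import box_cxcywh_to_xyxy

boxes = paddle.to_tensor([[0.5, 0.5, 1.0, 1.0]])          # one full-image proposal in (cx, cy, w, h)
img_whwh = paddle.to_tensor([[640., 480., 640., 480.]])   # hypothetical input size
print(box_cxcywh_to_xyxy(boxes) * img_whwh)               # [[0., 0., 640., 480.]]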
@@ -23,6 +23,7 @@ from . import keypoint_loss
from . import jde_loss
from . import fairmot_loss
from . import detr_loss
from . import sparsercnn_loss
from .yolo_loss import *
from .iou_aware_loss import *
@@ -35,3 +36,4 @@ from .keypoint_loss import *
from .jde_loss import *
from .fairmot_loss import *
from .detr_loss import *
from .sparsercnn_loss import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from scipy.optimize import linear_sum_assignment
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.metric import accuracy
from ppdet.core.workspace import register
from ppdet.modeling.losses.iou_loss import GIoULoss
__all__ = ["SparseRCNNLoss"]
@register
class SparseRCNNLoss(nn.Layer):
""" This class computes the loss for SparseRCNN.
The process happens in two steps:
1) we compute a Hungarian assignment between the ground-truth boxes and the outputs of the model
2) we supervise each pair of matched ground truth / prediction (class and box)
"""
__shared__ = ['num_classes']
def __init__(self,
losses,
focal_loss_alpha,
focal_loss_gamma,
num_classes=80,
class_weight=2.,
l1_weight=5.,
giou_weight=2.):
""" Create the criterion.
Parameters:
losses: list of all the losses to be applied. See get_loss for the list of available losses.
focal_loss_alpha / focal_loss_gamma: parameters of the sigmoid focal classification loss.
num_classes: number of object categories, omitting the special no-object category.
class_weight / l1_weight / giou_weight: relative weights of the classification, L1 box and GIoU
losses; they populate self.weight_dict and also parameterize the internal HungarianMatcher.
"""
super().__init__()
self.num_classes = num_classes
weight_dict = {
"loss_ce": class_weight,
"loss_bbox": l1_weight,
"loss_giou": giou_weight
}
self.weight_dict = weight_dict
self.losses = losses
self.giou_loss = GIoULoss(reduction="sum")
self.focal_loss_alpha = focal_loss_alpha
self.focal_loss_gamma = focal_loss_gamma
self.matcher = HungarianMatcher(focal_loss_alpha, focal_loss_gamma,
class_weight, l1_weight, giou_weight)
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
"""Classification loss (NLL)
targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
"""
assert 'pred_logits' in outputs
src_logits = outputs['pred_logits']
idx = self._get_src_permutation_idx(indices)
target_classes_o = paddle.concat([
paddle.gather(
t["labels"], J, axis=0) for t, (_, J) in zip(targets, indices)
])
target_classes = paddle.full(
src_logits.shape[:2], self.num_classes, dtype="int32")
for i, ind in enumerate(zip(idx[0], idx[1])):
target_classes[int(ind[0]), int(ind[1])] = target_classes_o[i]
target_classes.stop_gradient = True
src_logits = src_logits.flatten(start_axis=0, stop_axis=1)
# prepare one_hot target.
target_classes = target_classes.flatten(start_axis=0, stop_axis=1)
class_ids = paddle.arange(0, self.num_classes)
labels = (target_classes.unsqueeze(-1) == class_ids).astype("float32")
labels.stop_gradient = True
# comp focal loss.
class_loss = sigmoid_focal_loss(
src_logits,
labels,
alpha=self.focal_loss_alpha,
gamma=self.focal_loss_gamma,
reduction="sum", ) / num_boxes
losses = {'loss_ce': class_loss}
if log:
label_acc = target_classes_o.unsqueeze(-1)
src_idx = [src for (src, _) in indices]
pred_list = []
for i in range(outputs["pred_logits"].shape[0]):
pred_list.append(
paddle.gather(
outputs["pred_logits"][i], src_idx[i], axis=0))
pred = F.sigmoid(paddle.concat(pred_list, axis=0))
acc = accuracy(pred, label_acc.astype("int64"))
losses["acc"] = acc
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
The target boxes are expected in absolute (x1, y1, x2, y2) format; they are normalized by the image size only for the L1 term.
"""
assert 'pred_boxes' in outputs # [batch_size, num_proposals, 4]
src_idx = [src for (src, _) in indices]
src_boxes_list = []
for i in range(outputs["pred_boxes"].shape[0]):
src_boxes_list.append(
paddle.gather(
outputs["pred_boxes"][i], src_idx[i], axis=0))
src_boxes = paddle.concat(src_boxes_list, axis=0)
target_boxes = paddle.concat(
[
paddle.gather(
t['boxes'], I, axis=0)
for t, (_, I) in zip(targets, indices)
],
axis=0)
target_boxes.stop_gradient = True
losses = {}
losses['loss_giou'] = self.giou_loss(src_boxes,
target_boxes) / num_boxes
image_size = paddle.concat([v["img_whwh_tgt"] for v in targets])
src_boxes_ = src_boxes / image_size
target_boxes_ = target_boxes / image_size
loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction='sum')
losses['loss_bbox'] = loss_bbox / num_boxes
return losses
def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = paddle.concat(
[paddle.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = paddle.concat([src for (src, _) in indices])
return batch_idx, src_idx
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = paddle.concat(
[paddle.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = paddle.concat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
loss_map = {
'labels': self.loss_labels,
'boxes': self.loss_boxes,
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
def forward(self, outputs, targets):
""" This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
outputs_without_aux = {
k: v
for k, v in outputs.items() if k != 'aux_outputs'
}
# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = paddle.to_tensor(
[num_boxes],
dtype="float32",
place=next(iter(outputs.values())).place)
# Compute all the requested losses
losses = {}
for loss in self.losses:
losses.update(
self.get_loss(loss, outputs, targets, indices, num_boxes))
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if 'aux_outputs' in outputs:
for i, aux_outputs in enumerate(outputs['aux_outputs']):
indices = self.matcher(aux_outputs, targets)
for loss in self.losses:
kwargs = {}
if loss == 'labels':
# Logging is enabled only for the last layer
kwargs = {'log': False}
l_dict = self.get_loss(loss, aux_outputs, targets, indices,
num_boxes, **kwargs)
w_dict = {}
for k in l_dict.keys():
if k in self.weight_dict:
w_dict[k + f'_{i}'] = l_dict[k] * self.weight_dict[
k]
else:
w_dict[k + f'_{i}'] = l_dict[k]
losses.update(w_dict)
return losses
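A minimal end-to-end sketch of the criterion on dummy data (one image, three proposals, two ground-truth boxes in absolute x1y1x2y2), assuming the module is importable as ppdet.modeling.losses.sparsercnn_loss:

import paddle
from ppdet.modeling.losses.sparsercnn_loss import SparseRCNNLoss

criterion = SparseRCNNLoss(
    losses=["labels", "boxes"], focal_loss_alpha=0.25, focal_loss_gamma=2.0, num_classes=80)
outputs = {
    "pred_logits": paddle.randn([1, 3, 80]),
    "pred_boxes": paddle.to_tensor([[[0., 0., 10., 10.],
                                     [5., 5., 20., 20.],
                                     [30., 30., 60., 60.]]]),
}
whwh = paddle.to_tensor([100., 100., 100., 100.])
targets = [{
    "labels": paddle.to_tensor([3, 17], dtype="int32"),
    "boxes": paddle.to_tensor([[4., 4., 18., 18.], [28., 31., 59., 62.]]),
    "img_whwh": whwh,
    "img_whwh_tgt": whwh.unsqueeze(0).tile([2, 1]),
}]
losses = criterion(outputs, targets)
print(sorted(losses.keys()))   # ['acc', 'loss_bbox', 'loss_ce', 'loss_giou'] (unweighted here)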
class HungarianMatcher(nn.Layer):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
def __init__(self,
focal_loss_alpha,
focal_loss_gamma,
cost_class: float=1,
cost_bbox: float=1,
cost_giou: float=1):
"""Creates the matcher
Params:
cost_class: This is the relative weight of the classification error in the matching cost
cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
"""
super().__init__()
self.cost_class = cost_class
self.cost_bbox = cost_bbox
self.cost_giou = cost_giou
self.focal_loss_alpha = focal_loss_alpha
self.focal_loss_gamma = focal_loss_gamma
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
@paddle.no_grad()
def forward(self, outputs, targets):
""" Performs the matching
Args:
outputs: This is a dict that contains at least these entries:
"pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
eg. outputs = {"pred_logits": pred_logits, "pred_boxes": pred_boxes}
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
"labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
objects in the target) containing the class labels
"boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
eg. targets = [{"labels":labels, "boxes": boxes}, ...,{"labels":labels, "boxes": boxes}]
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected targets (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
"""
bs, num_queries = outputs["pred_logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
out_prob = F.sigmoid(outputs["pred_logits"].flatten(
start_axis=0, stop_axis=1))
out_bbox = outputs["pred_boxes"].flatten(start_axis=0, stop_axis=1)
# Also concat the target labels and boxes
tgt_ids = paddle.concat([v["labels"] for v in targets])
assert (tgt_ids > -1).all()
tgt_bbox = paddle.concat([v["boxes"] for v in targets])
# Compute the classification cost. Contrary to the loss, we use a focal-style cost built from
# the sigmoid probabilities; constant terms that do not change the matching are omitted.
alpha = self.focal_loss_alpha
gamma = self.focal_loss_gamma
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(
1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob)
**gamma) * (-(out_prob + 1e-8).log())
cost_class = paddle.gather(
pos_cost_class, tgt_ids, axis=1) - paddle.gather(
neg_cost_class, tgt_ids, axis=1)
# Compute the L1 cost between boxes
image_size_out = paddle.concat(
[v["img_whwh"].unsqueeze(0) for v in targets])
image_size_out = image_size_out.unsqueeze(1).tile(
[1, num_queries, 1]).flatten(
start_axis=0, stop_axis=1)
image_size_tgt = paddle.concat([v["img_whwh_tgt"] for v in targets])
out_bbox_ = out_bbox / image_size_out
tgt_bbox_ = tgt_bbox / image_size_tgt
cost_bbox = F.l1_loss(
out_bbox_.unsqueeze(-2), tgt_bbox_,
reduction='none').sum(-1) # [batch_size * num_queries, num_tgts]
# Compute the giou cost between boxes
cost_giou = -get_bboxes_giou(out_bbox, tgt_bbox)
# Final cost matrix
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.reshape([bs, num_queries, -1])
sizes = [len(v["boxes"]) for v in targets]
indices = [
linear_sum_assignment(c[i].numpy())
for i, c in enumerate(C.split(sizes, -1))
]
return [(paddle.to_tensor(
i, dtype="int32"), paddle.to_tensor(
j, dtype="int32")) for i, j in indices]
def box_area(boxes):
assert (boxes[:, 2:] >= boxes[:, :2]).all()
wh = boxes[:, 2:] - boxes[:, :2]
return wh[:, 0] * wh[:, 1]
def boxes_iou(boxes1, boxes2):
'''
Compute pairwise IoU
Args:
boxes1 (Tensor): shape (N, 4)
boxes2 (Tensor): shape (M, 4)
Return:
iou (Tensor): shape (N, M)
union (Tensor): shape (N, M)
'''
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = paddle.maximum(boxes1.unsqueeze(-2)[:, :, :2], boxes2[:, :2])
rb = paddle.minimum(boxes1.unsqueeze(-2)[:, :, 2:], boxes2[:, 2:])
wh = (rb - lt).astype("float32").clip(min=1e-9)
inter = wh[:, :, 0] * wh[:, :, 1]
union = area1.unsqueeze(-1) + area2 - inter + 1e-9
iou = inter / union
return iou, union
def get_bboxes_giou(boxes1, boxes2, eps=1e-9):
"""calculate the ious of boxes1 and boxes2
Args:
boxes1 (Tensor): shape [N, 4]
boxes2 (Tensor): shape [M, 4]
eps (float): epsilon to avoid divide by zero
Return:
giou (Tensor): GIoU of boxes1 and boxes2, with the shape [N, M]
"""
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = boxes_iou(boxes1, boxes2)
lt = paddle.minimum(boxes1.unsqueeze(-2)[:, :, :2], boxes2[:, :2])
rb = paddle.maximum(boxes1.unsqueeze(-2)[:, :, 2:], boxes2[:, 2:])
wh = (rb - lt).astype("float32").clip(min=eps)
enclose_area = wh[:, :, 0] * wh[:, :, 1]
giou = iou - (enclose_area - union) / enclose_area
return giou
def sigmoid_focal_loss(inputs, targets, alpha, gamma, reduction="sum"):
assert reduction in ["sum", "mean"], f'unsupported reduction: {reduction}'
p = F.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(
inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t)**gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss
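A quick numeric check of the GIoU helper (import path assumed from this commit):

import paddle
from ppdet.modeling.losses.sparsercnn_loss import get_bboxes_giou

b1 = paddle.to_tensor([[0., 0., 2., 2.]])
b2 = paddle.to_tensor([[1., 1., 3., 3.]])
print(get_bboxes_giou(b1, b2))   # ~[[-0.079]]: IoU 1/7 minus the enclosure penalty 2/9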
@@ -28,7 +28,7 @@ except Exception:
__all__ = [
'BBoxPostProcess', 'MaskPostProcess', 'FCOSPostProcess',
'S2ANetBBoxPostProcess', 'JDEBBoxPostProcess', 'CenterNetPostProcess',
'DETRBBoxPostProcess'
'DETRBBoxPostProcess', 'SparsePostProcess'
]
@@ -551,3 +551,90 @@ class DETRBBoxPostProcess(object):
bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
bbox_pred = bbox_pred.reshape([-1, 6])
return bbox_pred, bbox_num
@register
class SparsePostProcess(object):
__shared__ = ['num_classes']
def __init__(self, num_proposals, num_classes=80):
super(SparsePostProcess, self).__init__()
self.num_classes = num_classes
self.num_proposals = num_proposals
def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
"""
Arguments:
box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
The tensor holds the classification logits for each proposal.
box_pred (Tensor): tensor of shape (batch_size, num_proposals, 4).
The tensor holds the predicted (x1, y1, x2, y2) box for every proposal.
scale_factor_wh (Tensor): tensor of shape [batch_size, 2], the (w, h) scale factor of each image.
img_whwh (Tensor): tensor of shape [batch_size, 4], the (w, h, w, h) size of each input image.
Returns:
bbox_pred (Tensor): tensor of shape [num_boxes, 6]; each row is
[label, confidence, xmin, ymin, xmax, ymax].
bbox_num (Tensor): tensor of shape [batch_size], the number of RoIs in each image.
"""
assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)
img_wh = img_whwh[:, :2]
scores = F.sigmoid(box_cls)
labels = paddle.arange(0, self.num_classes). \
unsqueeze(0).tile([self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1)
classes_all = []
scores_all = []
boxes_all = []
for i, (scores_per_image,
box_pred_per_image) in enumerate(zip(scores, box_pred)):
scores_per_image, topk_indices = scores_per_image.flatten(
0, 1).topk(
self.num_proposals, sorted=False)
labels_per_image = paddle.gather(labels, topk_indices, axis=0)
box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile(
[1, self.num_classes, 1]).reshape([-1, 4])
box_pred_per_image = paddle.gather(
box_pred_per_image, topk_indices, axis=0)
classes_all.append(labels_per_image)
scores_all.append(scores_per_image)
boxes_all.append(box_pred_per_image)
bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
boxes_final = []
for i in range(len(scale_factor_wh)):
classes = classes_all[i]
boxes = boxes_all[i]
scores = scores_all[i]
boxes[:, 0::2] = paddle.clip(
boxes[:, 0::2], min=0, max=img_wh[i][0]) / scale_factor_wh[i][0]
boxes[:, 1::2] = paddle.clip(
boxes[:, 1::2], min=0, max=img_wh[i][1]) / scale_factor_wh[i][1]
boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), (
boxes[:, 3] - boxes[:, 1]).numpy()
keep = (boxes_w > 1.) & (boxes_h > 1.)
if (keep.sum() == 0):
bboxes = paddle.zeros([1, 6]).astype("float32")
else:
boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32")
classes = paddle.to_tensor(classes.numpy()[keep]).astype(
"float32").unsqueeze(-1)
scores = paddle.to_tensor(scores.numpy()[keep]).astype(
"float32").unsqueeze(-1)
bboxes = paddle.concat([classes, scores, boxes], axis=-1)
boxes_final.append(bboxes)
bbox_num[i] = bboxes.shape[0]
bbox_pred = paddle.concat(boxes_final)
return bbox_pred, bbox_num
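A small numpy sketch of how the two returned tensors are consumed downstream, with made-up detections:

import numpy as np

# bbox_pred rows are [label, confidence, xmin, ymin, xmax, ymax]; bbox_num[i] rows belong to image i.
bbox_pred = np.array([[2., 0.91, 10., 20., 110., 180.],
                      [7., 0.35, 40., 35., 90., 120.],
                      [0., 0.88, 5., 5., 60., 70.]], dtype=np.float32)
bbox_num = np.array([2, 1], dtype=np.int32)
per_image = np.split(bbox_pred, np.cumsum(bbox_num)[:-1])
print([d.shape for d in per_image])   # [(2, 6), (1, 6)]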