From f7df1eb937c9ba06a77eb54e8b818e9458baa5f1 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Fri, 5 Nov 2021 22:29:21 +0800 Subject: [PATCH] Update citation of some code (#4484) --- deploy/cpp/src/picodet_postprocess.cc | 5 +- deploy/lite/src/picodet_postprocess.cc | 5 +- ppdet/data/transform/atss_assigner.py | 3 + ppdet/data/transform/batch_operators.py | 2 + ppdet/modeling/assigners/simota_assigner.py | 223 +++++++++--------- ppdet/modeling/heads/gfl_head.py | 3 + ppdet/modeling/heads/simota_head.py | 192 +++++++-------- ppdet/modeling/heads/solov2_head.py | 8 +- ppdet/modeling/losses/gfocal_loss.py | 3 + ppdet/modeling/losses/varifocal_loss.py | 3 + ppdet/modeling/necks/__init__.py | 2 - ppdet/modeling/necks/csp_pan.py | 7 +- ppdet/modeling/necks/pan.py | 127 ---------- .../modeling/anchor_heads/solov2_head.py | 3 + .../modeling/mask_head/solo_mask_head.py | 3 + 15 files changed, 239 insertions(+), 350 deletions(-) delete mode 100644 ppdet/modeling/necks/pan.py diff --git a/deploy/cpp/src/picodet_postprocess.cc b/deploy/cpp/src/picodet_postprocess.cc index ba73c7d8c..cbe70d43f 100644 --- a/deploy/cpp/src/picodet_postprocess.cc +++ b/deploy/cpp/src/picodet_postprocess.cc @@ -11,6 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// The code is based on: +// https://github.com/RangiLyu/nanodet/blob/main/demo_mnn/nanodet_mnn.cpp #include "include/picodet_postprocess.h" @@ -124,4 +127,4 @@ void PicoDetPostProcess(std::vector* results, } } -} // namespace PaddleDetection \ No newline at end of file +} // namespace PaddleDetection diff --git a/deploy/lite/src/picodet_postprocess.cc b/deploy/lite/src/picodet_postprocess.cc index ba73c7d8c..cbe70d43f 100644 --- a/deploy/lite/src/picodet_postprocess.cc +++ b/deploy/lite/src/picodet_postprocess.cc @@ -11,6 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// The code is based on: +// https://github.com/RangiLyu/nanodet/blob/main/demo_mnn/nanodet_mnn.cpp #include "include/picodet_postprocess.h" @@ -124,4 +127,4 @@ void PicoDetPostProcess(std::vector* results, } } -} // namespace PaddleDetection \ No newline at end of file +} // namespace PaddleDetection diff --git a/ppdet/data/transform/atss_assigner.py b/ppdet/data/transform/atss_assigner.py index d41c85a7e..178d94fb6 100644 --- a/ppdet/data/transform/atss_assigner.py +++ b/ppdet/data/transform/atss_assigner.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The code is based on: +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/atss_assigner.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index edc25bead..e43fb7d20 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -707,6 +707,8 @@ class Gt2TTFTarget(BaseOperator): @register_op class Gt2Solov2Target(BaseOperator): """Assign mask target and labels in SOLOv2 network. + The code of this function is based on: + https://github.com/WXinlong/SOLO/blob/master/mmdet/models/anchor_heads/solov2_head.py#L271 Args: num_grids (list): The list of feature map grids size. scale_ranges (list): The list of mask boundary range. diff --git a/ppdet/modeling/assigners/simota_assigner.py b/ppdet/modeling/assigners/simota_assigner.py index ce22ad7e4..4b34027e3 100644 --- a/ppdet/modeling/assigners/simota_assigner.py +++ b/ppdet/modeling/assigners/simota_assigner.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -This code is refer from: -https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/sim_ota_assigner.py -""" + +# The code is based on: +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/sim_ota_assigner.py import paddle import numpy as np import paddle.nn.functional as F -import paddle.nn as nn from ppdet.modeling.losses.varifocal_loss import varifocal_loss from ppdet.modeling.bbox_utils import batch_bbox_overlaps @@ -29,7 +27,6 @@ from ppdet.core.workspace import register @register class SimOTAAssigner(object): """Computes matching between predictions and ground truth. - Args: center_radius (int | float, optional): Ground truth center size to judge whether a prior is in center. Default 2.5. @@ -58,86 +55,89 @@ class SimOTAAssigner(object): self.num_classes = num_classes self.use_vfl = use_vfl - def get_in_gt_and_in_center_info(self, priors, gt_bboxes): + def get_in_gt_and_in_center_info(self, flatten_center_and_stride, + gt_bboxes): num_gt = gt_bboxes.shape[0] - repeated_x = priors[:, 0].unsqueeze(1).tile([1, num_gt]) - repeated_y = priors[:, 1].unsqueeze(1).tile([1, num_gt]) - repeated_stride_x = priors[:, 2].unsqueeze(1).tile([1, num_gt]) - repeated_stride_y = priors[:, 3].unsqueeze(1).tile([1, num_gt]) - - # is prior centers in gt bboxes, shape: [n_prior, n_gt] - l_ = repeated_x - gt_bboxes[:, 0] - t_ = repeated_y - gt_bboxes[:, 1] - r_ = gt_bboxes[:, 2] - repeated_x - b_ = gt_bboxes[:, 3] - repeated_y + flatten_x = flatten_center_and_stride[:, 0].unsqueeze(1).tile( + [1, num_gt]) + flatten_y = flatten_center_and_stride[:, 1].unsqueeze(1).tile( + [1, num_gt]) + flatten_stride_x = flatten_center_and_stride[:, 2].unsqueeze(1).tile( + [1, num_gt]) + flatten_stride_y = flatten_center_and_stride[:, 3].unsqueeze(1).tile( + [1, num_gt]) + + # is prior centers in gt bboxes, shape: [n_center, n_gt] + l_ = flatten_x - gt_bboxes[:, 0] + t_ = flatten_y - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - flatten_x + b_ = gt_bboxes[:, 3] - flatten_y deltas = paddle.stack([l_, t_, r_, b_], axis=1) is_in_gts = deltas.min(axis=1) > 0 is_in_gts_all = is_in_gts.sum(axis=1) > 0 # is prior centers in gt centers - gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 - gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 - ct_box_l = gt_cxs - self.center_radius * repeated_stride_x - ct_box_t = gt_cys - self.center_radius * repeated_stride_y - ct_box_r = gt_cxs + self.center_radius * repeated_stride_x - ct_box_b = gt_cys + self.center_radius * repeated_stride_y - - cl_ = repeated_x - ct_box_l - ct_ = repeated_y - ct_box_t - cr_ = ct_box_r - repeated_x - cb_ = ct_box_b - repeated_y + gt_center_xs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0 + gt_center_ys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0 + ct_bound_l = gt_center_xs - self.center_radius * flatten_stride_x + ct_bound_t = gt_center_ys - self.center_radius * flatten_stride_y + ct_bound_r = gt_center_xs + self.center_radius * flatten_stride_x + ct_bound_b = gt_center_ys + self.center_radius * flatten_stride_y + + cl_ = flatten_x - ct_bound_l + ct_ = flatten_y - ct_bound_t + cr_ = ct_bound_r - flatten_x + cb_ = ct_bound_b - flatten_y ct_deltas = paddle.stack([cl_, ct_, cr_, cb_], axis=1) is_in_cts = ct_deltas.min(axis=1) > 0 is_in_cts_all = is_in_cts.sum(axis=1) > 0 - # in boxes or in centers, shape: [num_priors] - is_in_gts_or_centers = paddle.logical_or(is_in_gts_all, is_in_cts_all) + # in any of gts or gt centers, shape: [n_center] + is_in_gts_or_centers_all = paddle.logical_or(is_in_gts_all, + is_in_cts_all) - is_in_gts_or_centers_inds = paddle.nonzero( - is_in_gts_or_centers).squeeze(1) + is_in_gts_or_centers_all_inds = paddle.nonzero( + is_in_gts_or_centers_all).squeeze(1) - # both in boxes and centers, shape: [num_fg, num_gt] - is_in_boxes_and_centers = paddle.logical_and( + # both in gts and gt centers, shape: [num_fg, num_gt] + is_in_gts_and_centers = paddle.logical_and( paddle.gather( - is_in_gts.cast('int'), is_in_gts_or_centers_inds, + is_in_gts.cast('int'), is_in_gts_or_centers_all_inds, axis=0).cast('bool'), paddle.gather( - is_in_cts.cast('int'), is_in_gts_or_centers_inds, + is_in_cts.cast('int'), is_in_gts_or_centers_all_inds, axis=0).cast('bool')) - return is_in_gts_or_centers, is_in_boxes_and_centers + return is_in_gts_or_centers_all, is_in_gts_or_centers_all_inds, is_in_gts_and_centers - def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask): - matching_matrix = np.zeros_like(cost.numpy()) + def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt): + match_matrix = np.zeros_like(cost_matrix.numpy()) # select candidate topk ious for dynamic-k calculation topk_ious, _ = paddle.topk(pairwise_ious, self.candidate_topk, axis=0) # calculate dynamic k for each gt dynamic_ks = paddle.clip(topk_ious.sum(0).cast('int'), min=1) for gt_idx in range(num_gt): _, pos_idx = paddle.topk( - cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False) - matching_matrix[:, gt_idx][pos_idx.numpy()] = 1.0 + cost_matrix[:, gt_idx], k=dynamic_ks[gt_idx], largest=False) + match_matrix[:, gt_idx][pos_idx.numpy()] = 1.0 del topk_ious, dynamic_ks, pos_idx - prior_match_gt_mask = matching_matrix.sum(1) > 1 - if prior_match_gt_mask.sum() > 0: - cost = cost.numpy() - cost_argmin = np.argmin(cost[prior_match_gt_mask, :], axis=1) - matching_matrix[prior_match_gt_mask, :] *= 0.0 - matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0 - # get foreground mask inside box and center prior - fg_mask_inboxes = matching_matrix.sum(1) > 0.0 - valid_mask[valid_mask.copy()] = fg_mask_inboxes + # match points more than two gts + extra_match_gts_mask = match_matrix.sum(1) > 1 + if extra_match_gts_mask.sum() > 0: + cost_matrix = cost_matrix.numpy() + cost_argmin = np.argmin( + cost_matrix[extra_match_gts_mask, :], axis=1) + match_matrix[extra_match_gts_mask, :] *= 0.0 + match_matrix[extra_match_gts_mask, cost_argmin] = 1.0 + # get foreground mask + match_fg_mask_inmatrix = match_matrix.sum(1) > 0 + match_gt_inds_to_fg = match_matrix[match_fg_mask_inmatrix, :].argmax(1) - matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1) - - matched_gt_inds = paddle.to_tensor( - matched_gt_inds, place=pairwise_ious.place) - - return matched_gt_inds, valid_mask + return match_gt_inds_to_fg, match_fg_mask_inmatrix def get_sample(self, assign_gt_inds, gt_bboxes): pos_inds = np.unique(np.nonzero(assign_gt_inds > 0)[0]) @@ -155,9 +155,9 @@ class SimOTAAssigner(object): return pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds def __call__(self, - pred_scores, - priors, - decoded_bboxes, + flatten_cls_pred_scores, + flatten_center_and_stride, + flatten_bboxes, gt_bboxes, gt_labels, eps=1e-7): @@ -166,35 +166,31 @@ class SimOTAAssigner(object): Returns: assign_result: The assigned result. """ - - INF = 100000000 num_gt = gt_bboxes.shape[0] - num_bboxes = decoded_bboxes.shape[0] + num_bboxes = flatten_bboxes.shape[0] - # assign 0 by default - assigned_gt_inds = paddle.full( - (num_bboxes, ), 0, dtype=paddle.int64).numpy() if num_gt == 0 or num_bboxes == 0: - # No ground truth or boxes, return empty assignment - priors = priors.numpy() - labels = np.ones([num_bboxes], dtype=np.int64) * self.num_classes - label_weights = np.ones([num_bboxes], dtype=np.float32) - bbox_targets = np.zeros_like(priors) - return priors, labels, label_weights, bbox_targets, 0 - - valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( - priors, gt_bboxes) - - valid_mask_inds = paddle.nonzero(valid_mask).squeeze(1) - valid_decoded_bbox = decoded_bboxes[valid_mask_inds] - valid_pred_scores = pred_scores[valid_mask_inds] - num_valid = valid_decoded_bbox.shape[0] - - pairwise_ious = batch_bbox_overlaps(valid_decoded_bbox, gt_bboxes) + # No ground truth or boxes + label = np.ones([num_bboxes], dtype=np.int64) * self.num_classes + label_weight = np.ones([num_bboxes], dtype=np.float32) + bbox_target = np.zeros_like(flatten_center_and_stride) + return 0, label, label_weight, bbox_target + + is_in_gts_or_centers_all, is_in_gts_or_centers_all_inds, is_in_boxes_and_center = self.get_in_gt_and_in_center_info( + flatten_center_and_stride, gt_bboxes) + + # bboxes and scores to calculate matrix + valid_flatten_bboxes = flatten_bboxes[is_in_gts_or_centers_all_inds] + valid_cls_pred_scores = flatten_cls_pred_scores[ + is_in_gts_or_centers_all_inds] + num_valid_bboxes = valid_flatten_bboxes.shape[0] + + pairwise_ious = batch_bbox_overlaps(valid_flatten_bboxes, + gt_bboxes) # [num_points,num_gts] if self.use_vfl: gt_vfl_labels = gt_labels.squeeze(-1).unsqueeze(0).tile( - [num_valid, 1]).reshape([-1]) - valid_pred_scores = valid_pred_scores.unsqueeze(1).tile( + [num_valid_bboxes, 1]).reshape([-1]) + valid_pred_scores = valid_cls_pred_scores.unsqueeze(1).tile( [1, num_gt, 1]).reshape([-1, self.num_classes]) vfl_score = np.zeros(valid_pred_scores.shape) vfl_score[np.arange(0, vfl_score.shape[0]), gt_vfl_labels.numpy( @@ -202,64 +198,65 @@ class SimOTAAssigner(object): vfl_score = paddle.to_tensor(vfl_score) losses_vfl = varifocal_loss( valid_pred_scores, vfl_score, - use_sigmoid=False).reshape([num_valid, num_gt]) + use_sigmoid=False).reshape([num_valid_bboxes, num_gt]) losses_giou = batch_bbox_overlaps( - valid_decoded_bbox, gt_bboxes, mode='giou') + valid_flatten_bboxes, gt_bboxes, mode='giou') cost_matrix = ( losses_vfl * self.cls_weight + losses_giou * self.iou_weight + - paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF - ) + paddle.logical_not(is_in_boxes_and_center).cast('float32') * + 100000000) else: iou_cost = -paddle.log(pairwise_ious + eps) gt_onehot_label = (F.one_hot( gt_labels.squeeze(-1).cast(paddle.int64), - pred_scores.shape[-1]).cast('float32').unsqueeze(0).tile( - [num_valid, 1, 1])) + flatten_cls_pred_scores.shape[-1]).cast('float32').unsqueeze(0) + .tile([num_valid_bboxes, 1, 1])) - valid_pred_scores = valid_pred_scores.unsqueeze(1).tile( + valid_pred_scores = valid_cls_pred_scores.unsqueeze(1).tile( [1, num_gt, 1]) cls_cost = F.binary_cross_entropy( valid_pred_scores, gt_onehot_label, reduction='none').sum(-1) cost_matrix = ( cls_cost * self.cls_weight + iou_cost * self.iou_weight + - paddle.logical_not(is_in_boxes_and_center).cast('float32') * INF - ) + paddle.logical_not(is_in_boxes_and_center).cast('float32') * + 100000000) - matched_gt_inds, valid_mask = \ + match_gt_inds_to_fg, match_fg_mask_inmatrix = \ self.dynamic_k_matching( - cost_matrix, pairwise_ious, num_gt, valid_mask.numpy()) + cost_matrix, pairwise_ious, num_gt) - # assign results - gt_labels = gt_labels.numpy() - priors = priors.numpy() - matched_gt_inds = matched_gt_inds.numpy() - gt_bboxes = gt_bboxes.numpy() + # sample and assign results + assigned_gt_inds = np.zeros([num_bboxes], dtype=np.int64) + match_fg_mask_inall = np.zeros_like(assigned_gt_inds) + match_fg_mask_inall[is_in_gts_or_centers_all.numpy( + )] = match_fg_mask_inmatrix - assigned_gt_inds[valid_mask] = matched_gt_inds + 1 + assigned_gt_inds[match_fg_mask_inall.astype( + np.bool)] = match_gt_inds_to_fg + 1 pos_inds, neg_inds, pos_gt_bboxes, pos_assigned_gt_inds \ - = self.get_sample(assigned_gt_inds, gt_bboxes) + = self.get_sample(assigned_gt_inds, gt_bboxes.numpy()) - num_cells = priors.shape[0] - bbox_targets = np.zeros_like(priors) - bbox_weights = np.zeros_like(priors) - labels = np.ones([num_cells], dtype=np.int64) * self.num_classes - label_weights = np.zeros([num_cells], dtype=np.float32) + bbox_target = np.zeros_like(flatten_bboxes) + bbox_weight = np.zeros_like(flatten_bboxes) + label = np.ones([num_bboxes], dtype=np.int64) * self.num_classes + label_weight = np.zeros([num_bboxes], dtype=np.float32) if len(pos_inds) > 0: + gt_labels = gt_labels.numpy() pos_bbox_targets = pos_gt_bboxes - bbox_targets[pos_inds, :] = pos_bbox_targets - bbox_weights[pos_inds, :] = 1.0 + bbox_target[pos_inds, :] = pos_bbox_targets + bbox_weight[pos_inds, :] = 1.0 if not np.any(gt_labels): - labels[pos_inds] = 0 + label[pos_inds] = 0 else: - labels[pos_inds] = gt_labels.squeeze(-1)[pos_assigned_gt_inds] + label[pos_inds] = gt_labels.squeeze(-1)[pos_assigned_gt_inds] - label_weights[pos_inds] = 1.0 + label_weight[pos_inds] = 1.0 if len(neg_inds) > 0: - label_weights[neg_inds] = 1.0 + label_weight[neg_inds] = 1.0 pos_num = max(pos_inds.size, 1) - return priors, labels, label_weights, bbox_targets, pos_num + return pos_num, label, label_weight, bbox_target diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py index e5b0377a6..17e87a4ef 100644 --- a/ppdet/modeling/heads/gfl_head.py +++ b/ppdet/modeling/heads/gfl_head.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The code is based on: +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/gfl_head.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppdet/modeling/heads/simota_head.py b/ppdet/modeling/heads/simota_head.py index 2a870dc13..a1485f390 100644 --- a/ppdet/modeling/heads/simota_head.py +++ b/ppdet/modeling/heads/simota_head.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The code is based on: +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/yolox_head.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -30,29 +33,7 @@ from ppdet.core.workspace import register from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance from ppdet.data.transform.atss_assigner import bbox_overlaps -from .gfl_head import GFLHead, ScaleReg, Integral - - -def multi_apply(func, *args, **kwargs): - """Apply function to a list of arguments. - - Note: - This function applies the ``func`` to multiple inputs and - map the multiple outputs of the ``func`` into different - list. Each list contains the same type of outputs corresponding - to different inputs. - - Args: - func (Function): A function that will be applied to a list of - arguments - - Returns: - tuple(list): A tuple containing multiple list, each list contains \ - a kind of returned results by the function - """ - pfunc = partial(func, **kwargs) if kwargs else func - map_results = map(pfunc, *args) - return tuple(map(list, zip(*map_results))) +from .gfl_head import GFLHead @register @@ -123,14 +104,15 @@ class OTAHead(GFLHead): self.assigner = assigner - def _get_target_single(self, cls_preds, centors, decoded_bboxes, gt_bboxes, - gt_labels): + def _get_target_single(self, flatten_cls_pred, flatten_center_and_stride, + flatten_bbox, gt_bboxes, gt_labels): """Compute targets for priors in a single image. """ - centors, labels, label_weights, bbox_targets, pos_num = self.assigner( - F.sigmoid(cls_preds), centors, decoded_bboxes, gt_bboxes, gt_labels) + pos_num, label, label_weight, bbox_target = self.assigner( + F.sigmoid(flatten_cls_pred), flatten_center_and_stride, + flatten_bbox, gt_bboxes, gt_labels) - return (centors, labels, label_weights, bbox_targets, pos_num) + return (pos_num, label, label_weight, bbox_target) def get_loss(self, head_outs, gt_meta): cls_scores, bbox_preds = head_outs @@ -142,26 +124,25 @@ class OTAHead(GFLHead): for featmap in cls_scores] decode_bbox_preds = [] - mlvl_centors = [] - with_stride = True + center_and_strides = [] for featmap_size, stride, bbox_pred in zip(featmap_sizes, self.fpn_stride, bbox_preds): + # center in origin image yy, xx = self.get_single_level_center_point(featmap_size, stride, self.cell_offset) - if with_stride: - stride_w = paddle.full((len(xx), ), stride) - stride_h = paddle.full((len(yy), ), stride) - centers = paddle.stack([xx, yy, stride_w, stride_h], -1).tile( + + center_and_stride = paddle.stack([xx, yy, stride, stride], -1).tile( [num_imgs, 1, 1]) - mlvl_centors.append(centers) - centers_in_feature = centers.reshape([-1, 4])[:, :-2] / stride + center_and_strides.append(center_and_stride) + center_in_feature = center_and_stride.reshape( + [-1, 4])[:, :-2] / stride bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( [num_imgs, -1, 4 * (self.reg_max + 1)]) - pred_corners = self.distribution_project(bbox_pred) - decode_bbox_pred = distance2bbox( - centers_in_feature, pred_corners).reshape([num_imgs, -1, 4]) - decode_bbox_preds.append(decode_bbox_pred * stride) + pred_distances = self.distribution_project(bbox_pred) + decode_bbox_pred_wo_stride = distance2bbox( + center_in_feature, pred_distances).reshape([num_imgs, -1, 4]) + decode_bbox_preds.append(decode_bbox_pred_wo_stride * stride) flatten_cls_preds = [ cls_pred.transpose([0, 2, 3, 1]).reshape( @@ -170,27 +151,33 @@ class OTAHead(GFLHead): ] flatten_cls_preds = paddle.concat(flatten_cls_preds, axis=1) flatten_bboxes = paddle.concat(decode_bbox_preds, axis=1) - flatten_centors = paddle.concat(mlvl_centors, axis=1) - - gt_box, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class'] - (centors, labels, label_weights, bbox_targets, pos_num) = multi_apply( - self._get_target_single, - flatten_cls_preds.detach(), - flatten_centors.detach(), - flatten_bboxes.detach(), gt_box, gt_labels) - - centors = paddle.to_tensor(np.stack(centors, axis=0)) - labels = paddle.to_tensor(np.stack(labels, axis=0)) - label_weights = paddle.to_tensor(np.stack(label_weights, axis=0)) - bbox_targets = paddle.to_tensor(np.stack(bbox_targets, axis=0)) - - centors_list = self._images_to_levels(centors, num_level_anchors) + flatten_center_and_strides = paddle.concat(center_and_strides, axis=1) + + gt_boxes, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class'] + pos_num_l, label_l, label_weight_l, bbox_target_l = [], [], [], [] + for flatten_cls_pred,flatten_center_and_stride,flatten_bbox,gt_box, gt_label \ + in zip(flatten_cls_preds.detach(),flatten_center_and_strides.detach(), \ + flatten_bboxes.detach(),gt_boxes, gt_labels): + pos_num, label, label_weight, bbox_target = self._get_target_single( + flatten_cls_pred, flatten_center_and_stride, flatten_bbox, + gt_box, gt_label) + pos_num_l.append(pos_num) + label_l.append(label) + label_weight_l.append(label_weight) + bbox_target_l.append(bbox_target) + + labels = paddle.to_tensor(np.stack(label_l, axis=0)) + label_weights = paddle.to_tensor(np.stack(label_weight_l, axis=0)) + bbox_targets = paddle.to_tensor(np.stack(bbox_target_l, axis=0)) + + center_and_strides_list = self._images_to_levels( + flatten_center_and_strides, num_level_anchors) labels_list = self._images_to_levels(labels, num_level_anchors) label_weights_list = self._images_to_levels(label_weights, num_level_anchors) bbox_targets_list = self._images_to_levels(bbox_targets, num_level_anchors) - num_total_pos = sum(pos_num) + num_total_pos = sum(pos_num_l) try: num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( )) / paddle.distributed.get_world_size() @@ -198,10 +185,10 @@ class OTAHead(GFLHead): num_total_pos = max(num_total_pos, 1) loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], [] - for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip( - cls_scores, bbox_preds, centors_list, labels_list, + for cls_score, bbox_pred, center_and_strides, labels, label_weights, bbox_targets, stride in zip( + cls_scores, bbox_preds, center_and_strides_list, labels_list, label_weights_list, bbox_targets_list, self.fpn_stride): - grid_cells = grid_cells.reshape([-1, 4]) + center_and_strides = center_and_strides.reshape([-1, 4]) cls_score = cls_score.transpose([0, 2, 3, 1]).reshape( [-1, self.cls_out_channels]) bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( @@ -219,14 +206,14 @@ class OTAHead(GFLHead): if len(pos_inds) > 0: pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0) pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0) - pos_grid_cell_centers = paddle.gather( - grid_cells[:, :-2], pos_inds, axis=0) / stride + pos_centers = paddle.gather( + center_and_strides[:, :-2], pos_inds, axis=0) / stride weight_targets = F.sigmoid(cls_score.detach()) weight_targets = paddle.gather( weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0) pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred) - pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers, + pos_decode_bbox_pred = distance2bbox(pos_centers, pos_bbox_pred_corners) pos_decode_bbox_targets = pos_bbox_targets / stride bbox_iou = bbox_overlaps( @@ -236,7 +223,7 @@ class OTAHead(GFLHead): score[pos_inds.numpy()] = bbox_iou pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1]) - target_corners = bbox2distance(pos_grid_cell_centers, + target_corners = bbox2distance(pos_centers, pos_decode_bbox_targets, self.reg_max).reshape([-1]) # regression loss @@ -355,26 +342,24 @@ class OTAVFLHead(OTAHead): for featmap in cls_scores] decode_bbox_preds = [] - mlvl_centors = [] - with_stride = True + center_and_strides = [] for featmap_size, stride, bbox_pred in zip(featmap_sizes, self.fpn_stride, bbox_preds): - + # center in origin image yy, xx = self.get_single_level_center_point(featmap_size, stride, self.cell_offset) - if with_stride: - stride_w = paddle.full((len(xx), ), stride) - stride_h = paddle.full((len(yy), ), stride) - centers = paddle.stack([xx, yy, stride_w, stride_h], -1).tile( - [num_imgs, 1, 1]) - mlvl_centors.append(centers) - centers_in_feature = centers.reshape([-1, 4])[:, :-2] / stride + strides = paddle.full((len(xx), ), stride) + center_and_stride = paddle.stack([xx, yy, strides, strides], + -1).tile([num_imgs, 1, 1]) + center_and_strides.append(center_and_stride) + center_in_feature = center_and_stride.reshape( + [-1, 4])[:, :-2] / stride bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( [num_imgs, -1, 4 * (self.reg_max + 1)]) - pred_corners = self.distribution_project(bbox_pred) - decode_bbox_pred = distance2bbox( - centers_in_feature, pred_corners).reshape([num_imgs, -1, 4]) - decode_bbox_preds.append(decode_bbox_pred * stride) + pred_distances = self.distribution_project(bbox_pred) + decode_bbox_pred_wo_stride = distance2bbox( + center_in_feature, pred_distances).reshape([num_imgs, -1, 4]) + decode_bbox_preds.append(decode_bbox_pred_wo_stride * stride) flatten_cls_preds = [ cls_pred.transpose([0, 2, 3, 1]).reshape( @@ -383,27 +368,33 @@ class OTAVFLHead(OTAHead): ] flatten_cls_preds = paddle.concat(flatten_cls_preds, axis=1) flatten_bboxes = paddle.concat(decode_bbox_preds, axis=1) - flatten_centors = paddle.concat(mlvl_centors, axis=1) - - gt_box, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class'] - (centors, labels, label_weights, bbox_targets, pos_num) = multi_apply( - self._get_target_single, - flatten_cls_preds.detach(), - flatten_centors.detach(), - flatten_bboxes.detach(), gt_box, gt_labels) - - centors = paddle.to_tensor(np.stack(centors, axis=0)) - labels = paddle.to_tensor(np.stack(labels, axis=0)) - label_weights = paddle.to_tensor(np.stack(label_weights, axis=0)) - bbox_targets = paddle.to_tensor(np.stack(bbox_targets, axis=0)) - - centors_list = self._images_to_levels(centors, num_level_anchors) + flatten_center_and_strides = paddle.concat(center_and_strides, axis=1) + + gt_boxes, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class'] + pos_num_l, label_l, label_weight_l, bbox_target_l = [], [], [], [] + for flatten_cls_pred, flatten_center_and_stride, flatten_bbox,gt_box,gt_label \ + in zip(flatten_cls_preds.detach(), flatten_center_and_strides.detach(), \ + flatten_bboxes.detach(),gt_boxes,gt_labels): + pos_num, label, label_weight, bbox_target = self._get_target_single( + flatten_cls_pred, flatten_center_and_stride, flatten_bbox, + gt_box, gt_label) + pos_num_l.append(pos_num) + label_l.append(label) + label_weight_l.append(label_weight) + bbox_target_l.append(bbox_target) + + labels = paddle.to_tensor(np.stack(label_l, axis=0)) + label_weights = paddle.to_tensor(np.stack(label_weight_l, axis=0)) + bbox_targets = paddle.to_tensor(np.stack(bbox_target_l, axis=0)) + + center_and_strides_list = self._images_to_levels( + flatten_center_and_strides, num_level_anchors) labels_list = self._images_to_levels(labels, num_level_anchors) label_weights_list = self._images_to_levels(label_weights, num_level_anchors) bbox_targets_list = self._images_to_levels(bbox_targets, num_level_anchors) - num_total_pos = sum(pos_num) + num_total_pos = sum(pos_num_l) try: num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( )) / paddle.distributed.get_world_size() @@ -411,17 +402,16 @@ class OTAVFLHead(OTAHead): num_total_pos = max(num_total_pos, 1) loss_bbox_list, loss_dfl_list, loss_vfl_list, avg_factor = [], [], [], [] - for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip( - cls_scores, bbox_preds, centors_list, labels_list, + for cls_score, bbox_pred, center_and_strides, labels, label_weights, bbox_targets, stride in zip( + cls_scores, bbox_preds, center_and_strides_list, labels_list, label_weights_list, bbox_targets_list, self.fpn_stride): - grid_cells = grid_cells.reshape([-1, 4]) + center_and_strides = center_and_strides.reshape([-1, 4]) cls_score = cls_score.transpose([0, 2, 3, 1]).reshape( [-1, self.cls_out_channels]) bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( [-1, 4 * (self.reg_max + 1)]) bbox_targets = bbox_targets.reshape([-1, 4]) labels = labels.reshape([-1]) - label_weights = label_weights.reshape([-1]) bg_class_ind = self.num_classes pos_inds = paddle.nonzero( @@ -433,14 +423,14 @@ class OTAVFLHead(OTAHead): if len(pos_inds) > 0: pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0) pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0) - pos_grid_cell_centers = paddle.gather( - grid_cells[:, :-2], pos_inds, axis=0) / stride + pos_centers = paddle.gather( + center_and_strides[:, :-2], pos_inds, axis=0) / stride weight_targets = F.sigmoid(cls_score.detach()) weight_targets = paddle.gather( weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0) pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred) - pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers, + pos_decode_bbox_pred = distance2bbox(pos_centers, pos_bbox_pred_corners) pos_decode_bbox_targets = pos_bbox_targets / stride bbox_iou = bbox_overlaps( @@ -453,7 +443,7 @@ class OTAVFLHead(OTAHead): vfl_score[pos_inds.numpy(), pos_labels] = bbox_iou pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1]) - target_corners = bbox2distance(pos_grid_cell_centers, + target_corners = bbox2distance(pos_centers, pos_decode_bbox_targets, self.reg_max).reshape([-1]) # regression loss diff --git a/ppdet/modeling/heads/solov2_head.py b/ppdet/modeling/heads/solov2_head.py index 355d57e64..6989abb3a 100644 --- a/ppdet/modeling/heads/solov2_head.py +++ b/ppdet/modeling/heads/solov2_head.py @@ -34,7 +34,9 @@ __all__ = ['SOLOv2Head'] @register class SOLOv2MaskHead(nn.Layer): """ - MaskHead of SOLOv2 + MaskHead of SOLOv2. + The code of this function is based on: + https://github.com/WXinlong/SOLO/blob/master/mmdet/models/mask_heads/mask_feat_head.py Args: in_channels (int): The channel number of input Tensor. @@ -452,6 +454,10 @@ class SOLOv2Head(nn.Layer): def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size, im_shape, scale_factor): + """ + The code of this function is based on: + https://github.com/WXinlong/SOLO/blob/master/mmdet/models/anchor_heads/solov2_head.py#L385 + """ h = paddle.cast(im_shape[0], 'int32')[0] w = paddle.cast(im_shape[1], 'int32')[0] upsampled_size_out = [featmap_size[0] * 4, featmap_size[1] * 4] diff --git a/ppdet/modeling/losses/gfocal_loss.py b/ppdet/modeling/losses/gfocal_loss.py index 149d30bf8..37e27f084 100644 --- a/ppdet/modeling/losses/gfocal_loss.py +++ b/ppdet/modeling/losses/gfocal_loss.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The code is based on: +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/gfocal_loss.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppdet/modeling/losses/varifocal_loss.py b/ppdet/modeling/losses/varifocal_loss.py index 07716a016..42d18a659 100644 --- a/ppdet/modeling/losses/varifocal_loss.py +++ b/ppdet/modeling/losses/varifocal_loss.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The code is based on: +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/varifocal_loss.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py index 79deaf94c..d66697caf 100644 --- a/ppdet/modeling/necks/__init__.py +++ b/ppdet/modeling/necks/__init__.py @@ -17,7 +17,6 @@ from . import yolo_fpn from . import hrfpn from . import ttf_fpn from . import centernet_fpn -from . import pan from . import bifpn from . import csp_pan @@ -27,6 +26,5 @@ from .hrfpn import * from .ttf_fpn import * from .centernet_fpn import * from .blazeface_fpn import * -from .pan import * from .bifpn import * from .csp_pan import * diff --git a/ppdet/modeling/necks/csp_pan.py b/ppdet/modeling/necks/csp_pan.py index 6efbb1546..7417c46ab 100644 --- a/ppdet/modeling/necks/csp_pan.py +++ b/ppdet/modeling/necks/csp_pan.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -This code is refer from: -https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/yolox_pafpn.py -""" + +# The code is based on: +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/yolox_pafpn.py import paddle import paddle.nn as nn diff --git a/ppdet/modeling/necks/pan.py b/ppdet/modeling/necks/pan.py deleted file mode 100644 index 8693f29bd..000000000 --- a/ppdet/modeling/necks/pan.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import paddle -import paddle.nn as nn -import paddle.nn.functional as F -from paddle import ParamAttr -from paddle.nn.initializer import XavierUniform -from paddle.regularizer import L2Decay -from ppdet.core.workspace import register, serializable -from ppdet.modeling.layers import ConvNormLayer -from ..shape_spec import ShapeSpec - -__all__ = ['PAN'] - - -@register -@serializable -class PAN(nn.Layer): - """ - Path Aggregation Network, see https://arxiv.org/abs/1803.01534 - - Args: - in_channels (list[int]): input channels of each level which can be - derived from the output shape of backbone by from_config - out_channel (list[int]): output channel of each level - spatial_scales (list[float]): the spatial scales between input feature - maps and original input image which can be derived from the output - shape of backbone by from_config - start_level (int): Index of the start input backbone level used to - build the feature pyramid. Default: 0. - end_level (int): Index of the end input backbone level (exclusive) to - build the feature pyramid. Default: -1, which means the last level. - norm_type (string|None): The normalization type in FPN module. If - norm_type is None, norm will not be used after conv and if - norm_type is string, bn, gn, sync_bn are available. default None - """ - - def __init__(self, - in_channels, - out_channel, - spatial_scales=[0.125, 0.0625, 0.03125], - start_level=0, - end_level=-1, - norm_type=None): - super(PAN, self).__init__() - self.out_channel = out_channel - self.num_ins = len(in_channels) - self.spatial_scales = spatial_scales - if end_level == -1: - self.end_level = self.num_ins - else: - # if end_level < inputs, no extra level is allowed - self.end_level = end_level - assert end_level <= len(in_channels) - self.start_level = start_level - self.norm_type = norm_type - self.lateral_convs = [] - - for i in range(self.start_level, self.end_level): - in_c = in_channels[i - self.start_level] - if self.norm_type is not None: - lateral = self.add_sublayer( - 'pan_lateral' + str(i), - ConvNormLayer( - ch_in=in_c, - ch_out=self.out_channel, - filter_size=1, - stride=1, - norm_type=self.norm_type, - norm_decay=self.norm_decay, - freeze_norm=self.freeze_norm, - initializer=XavierUniform(fan_out=in_c))) - else: - lateral = self.add_sublayer( - 'pan_lateral' + str(i), - nn.Conv2D( - in_channels=in_c, - out_channels=self.out_channel, - kernel_size=1, - weight_attr=ParamAttr( - initializer=XavierUniform(fan_out=in_c)))) - self.lateral_convs.append(lateral) - - @classmethod - def from_config(cls, cfg, input_shape): - return {'in_channels': [i.channels for i in input_shape], } - - def forward(self, body_feats): - laterals = [] - for i, lateral_conv in enumerate(self.lateral_convs): - laterals.append(lateral_conv(body_feats[i + self.start_level])) - num_levels = len(laterals) - for i in range(1, num_levels): - lvl = num_levels - i - upsample = F.interpolate( - laterals[lvl], - scale_factor=2., - mode='bilinear', ) - laterals[lvl - 1] += upsample - - outs = [laterals[i] for i in range(num_levels)] - for i in range(0, num_levels - 1): - outs[i + 1] += F.interpolate( - outs[i], scale_factor=0.5, mode='bilinear') - - return outs - - @property - def out_shape(self): - return [ - ShapeSpec( - channels=self.out_channel, stride=1. / s) - for s in self.spatial_scales - ] diff --git a/static/ppdet/modeling/anchor_heads/solov2_head.py b/static/ppdet/modeling/anchor_heads/solov2_head.py index 74a681f01..6e592695c 100644 --- a/static/ppdet/modeling/anchor_heads/solov2_head.py +++ b/static/ppdet/modeling/anchor_heads/solov2_head.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The code is based on: +# https://github.com/WXinlong/SOLO/blob/master/mmdet/models/anchor_heads/solov2_head.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/static/ppdet/modeling/mask_head/solo_mask_head.py b/static/ppdet/modeling/mask_head/solo_mask_head.py index 61e8e2175..0dc439ce1 100644 --- a/static/ppdet/modeling/mask_head/solo_mask_head.py +++ b/static/ppdet/modeling/mask_head/solo_mask_head.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The code is based on: +# https://github.com/WXinlong/SOLO/blob/master/mmdet/models/mask_heads/mask_feat_head.py + from __future__ import absolute_import from __future__ import division from __future__ import print_function -- GitLab