diff --git a/ppdet/modeling/assigners/atss_assigner.py b/ppdet/modeling/assigners/atss_assigner.py index 550962d8493a6e29de1b0601b994e64d4cd99cc2..2a641d4cc41928671d13dd14cddb73295d0f8386 100644 --- a/ppdet/modeling/assigners/atss_assigner.py +++ b/ppdet/modeling/assigners/atss_assigner.py @@ -51,7 +51,6 @@ class ATSSAssigner(nn.Layer): def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list, pad_gt_mask): - pad_gt_mask = pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool) gt2anchor_distances_list = paddle.split( gt2anchor_distances, num_anchors_list, axis=-1) num_anchors_index = np.cumsum(num_anchors_list).tolist() @@ -61,15 +60,12 @@ class ATSSAssigner(nn.Layer): for distances, anchors_index in zip(gt2anchor_distances_list, num_anchors_index): num_anchors = distances.shape[-1] - topk_metrics, topk_idxs = paddle.topk( + _, topk_idxs = paddle.topk( distances, self.topk, axis=-1, largest=False) topk_idxs_list.append(topk_idxs + anchors_index) - topk_idxs = paddle.where(pad_gt_mask, topk_idxs, - paddle.zeros_like(topk_idxs)) - is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2) - is_in_topk = paddle.where(is_in_topk > 1, - paddle.zeros_like(is_in_topk), is_in_topk) - is_in_topk_list.append(is_in_topk.astype(gt2anchor_distances.dtype)) + is_in_topk = F.one_hot(topk_idxs, num_anchors).sum( + axis=-2).astype(gt2anchor_distances.dtype) + is_in_topk_list.append(is_in_topk * pad_gt_mask) is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1) topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1) return is_in_topk_list, topk_idxs_list @@ -155,9 +151,8 @@ class ATSSAssigner(nn.Layer): iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1]) iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \ iou_threshold.std(axis=-1, keepdim=True) - is_in_topk = paddle.where( - iou_candidates > iou_threshold.tile([1, 1, num_anchors]), - is_in_topk, paddle.zeros_like(is_in_topk)) + is_in_topk = paddle.where(iou_candidates > iou_threshold, is_in_topk, + paddle.zeros_like(is_in_topk)) # 6. check the positive sample's center in gt, [B, n, L] is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes) diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py index b97923f6c7b3bbf87a83e0eab91ce3358c316069..1a82c15237a07d3993460629ccb8317466da87fb 100644 --- a/ppdet/modeling/assigners/task_aligned_assigner.py +++ b/ppdet/modeling/assigners/task_aligned_assigner.py @@ -112,9 +112,7 @@ class TaskAlignedAssigner(nn.Layer): # select topk largest alignment metrics pred bbox as candidates # for each gt, [B, n, L] is_in_topk = gather_topk_anchors( - alignment_metrics * is_in_gts, - self.topk, - topk_mask=pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool)) + alignment_metrics * is_in_gts, self.topk, topk_mask=pad_gt_mask) # select positive sample, [B, n, L] mask_positive = is_in_topk * is_in_gts * pad_gt_mask diff --git a/ppdet/modeling/assigners/utils.py b/ppdet/modeling/assigners/utils.py index f2d9be99509b62d585a9f4db999818b6d25a2d94..6a89593a316da8cc9221fee872f34a8452200751 100644 --- a/ppdet/modeling/assigners/utils.py +++ b/ppdet/modeling/assigners/utils.py @@ -88,7 +88,7 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9): largest (bool) : largest is a flag, if set to true, algorithm will sort by descending order, otherwise sort by ascending order. Default: True - topk_mask (Tensor, bool|None): shape[B, n, topk], ignore bbox mask, + topk_mask (Tensor, float32): shape[B, n, 1], ignore bbox mask, Default: None eps (float): Default: 1e-9 Returns: @@ -98,13 +98,11 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9): topk_metrics, topk_idxs = paddle.topk( metrics, topk, axis=-1, largest=largest) if topk_mask is None: - topk_mask = (topk_metrics.max(axis=-1, keepdim=True) > eps).tile( - [1, 1, topk]) - topk_idxs = paddle.where(topk_mask, topk_idxs, paddle.zeros_like(topk_idxs)) - is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2) - is_in_topk = paddle.where(is_in_topk > 1, - paddle.zeros_like(is_in_topk), is_in_topk) - return is_in_topk.astype(metrics.dtype) + topk_mask = ( + topk_metrics.max(axis=-1, keepdim=True) > eps).astype(metrics.dtype) + is_in_topk = F.one_hot(topk_idxs, num_anchors).sum( + axis=-2).astype(metrics.dtype) + return is_in_topk * topk_mask def check_points_inside_bboxes(points,