未验证 提交 2768e1ae 编写于 作者: S shangliang Xu 提交者: GitHub

[PPYOLOE] fix assigner bug (#6066)

上级 d2f86f6e
...@@ -51,7 +51,6 @@ class ATSSAssigner(nn.Layer): ...@@ -51,7 +51,6 @@ class ATSSAssigner(nn.Layer):
def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list, def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
pad_gt_mask): pad_gt_mask):
pad_gt_mask = pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool)
gt2anchor_distances_list = paddle.split( gt2anchor_distances_list = paddle.split(
gt2anchor_distances, num_anchors_list, axis=-1) gt2anchor_distances, num_anchors_list, axis=-1)
num_anchors_index = np.cumsum(num_anchors_list).tolist() num_anchors_index = np.cumsum(num_anchors_list).tolist()
...@@ -61,15 +60,12 @@ class ATSSAssigner(nn.Layer): ...@@ -61,15 +60,12 @@ class ATSSAssigner(nn.Layer):
for distances, anchors_index in zip(gt2anchor_distances_list, for distances, anchors_index in zip(gt2anchor_distances_list,
num_anchors_index): num_anchors_index):
num_anchors = distances.shape[-1] num_anchors = distances.shape[-1]
topk_metrics, topk_idxs = paddle.topk( _, topk_idxs = paddle.topk(
distances, self.topk, axis=-1, largest=False) distances, self.topk, axis=-1, largest=False)
topk_idxs_list.append(topk_idxs + anchors_index) topk_idxs_list.append(topk_idxs + anchors_index)
topk_idxs = paddle.where(pad_gt_mask, topk_idxs, is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(
paddle.zeros_like(topk_idxs)) axis=-2).astype(gt2anchor_distances.dtype)
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2) is_in_topk_list.append(is_in_topk * pad_gt_mask)
is_in_topk = paddle.where(is_in_topk > 1,
paddle.zeros_like(is_in_topk), is_in_topk)
is_in_topk_list.append(is_in_topk.astype(gt2anchor_distances.dtype))
is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1) is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1)
topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1) topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1)
return is_in_topk_list, topk_idxs_list return is_in_topk_list, topk_idxs_list
...@@ -155,9 +151,8 @@ class ATSSAssigner(nn.Layer): ...@@ -155,9 +151,8 @@ class ATSSAssigner(nn.Layer):
iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1]) iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \ iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
iou_threshold.std(axis=-1, keepdim=True) iou_threshold.std(axis=-1, keepdim=True)
is_in_topk = paddle.where( is_in_topk = paddle.where(iou_candidates > iou_threshold, is_in_topk,
iou_candidates > iou_threshold.tile([1, 1, num_anchors]), paddle.zeros_like(is_in_topk))
is_in_topk, paddle.zeros_like(is_in_topk))
# 6. check the positive sample's center in gt, [B, n, L] # 6. check the positive sample's center in gt, [B, n, L]
is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes) is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
......
...@@ -112,9 +112,7 @@ class TaskAlignedAssigner(nn.Layer): ...@@ -112,9 +112,7 @@ class TaskAlignedAssigner(nn.Layer):
# select topk largest alignment metrics pred bbox as candidates # select topk largest alignment metrics pred bbox as candidates
# for each gt, [B, n, L] # for each gt, [B, n, L]
is_in_topk = gather_topk_anchors( is_in_topk = gather_topk_anchors(
alignment_metrics * is_in_gts, alignment_metrics * is_in_gts, self.topk, topk_mask=pad_gt_mask)
self.topk,
topk_mask=pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool))
# select positive sample, [B, n, L] # select positive sample, [B, n, L]
mask_positive = is_in_topk * is_in_gts * pad_gt_mask mask_positive = is_in_topk * is_in_gts * pad_gt_mask
......
...@@ -88,7 +88,7 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9): ...@@ -88,7 +88,7 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
largest (bool) : largest is a flag, if set to true, largest (bool) : largest is a flag, if set to true,
algorithm will sort by descending order, otherwise sort by algorithm will sort by descending order, otherwise sort by
ascending order. Default: True ascending order. Default: True
topk_mask (Tensor, bool|None): shape[B, n, topk], ignore bbox mask, topk_mask (Tensor, float32): shape[B, n, 1], ignore bbox mask,
Default: None Default: None
eps (float): Default: 1e-9 eps (float): Default: 1e-9
Returns: Returns:
...@@ -98,13 +98,11 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9): ...@@ -98,13 +98,11 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
topk_metrics, topk_idxs = paddle.topk( topk_metrics, topk_idxs = paddle.topk(
metrics, topk, axis=-1, largest=largest) metrics, topk, axis=-1, largest=largest)
if topk_mask is None: if topk_mask is None:
topk_mask = (topk_metrics.max(axis=-1, keepdim=True) > eps).tile( topk_mask = (
[1, 1, topk]) topk_metrics.max(axis=-1, keepdim=True) > eps).astype(metrics.dtype)
topk_idxs = paddle.where(topk_mask, topk_idxs, paddle.zeros_like(topk_idxs)) is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2) axis=-2).astype(metrics.dtype)
is_in_topk = paddle.where(is_in_topk > 1, return is_in_topk * topk_mask
paddle.zeros_like(is_in_topk), is_in_topk)
return is_in_topk.astype(metrics.dtype)
def check_points_inside_bboxes(points, def check_points_inside_bboxes(points,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册