未验证 提交 2768e1ae 编写于 作者: S shangliang Xu 提交者: GitHub

[PPYOLOE] fix assigner bug (#6066)

上级 d2f86f6e
......@@ -51,7 +51,6 @@ class ATSSAssigner(nn.Layer):
def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list,
pad_gt_mask):
pad_gt_mask = pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool)
gt2anchor_distances_list = paddle.split(
gt2anchor_distances, num_anchors_list, axis=-1)
num_anchors_index = np.cumsum(num_anchors_list).tolist()
......@@ -61,15 +60,12 @@ class ATSSAssigner(nn.Layer):
for distances, anchors_index in zip(gt2anchor_distances_list,
num_anchors_index):
num_anchors = distances.shape[-1]
topk_metrics, topk_idxs = paddle.topk(
_, topk_idxs = paddle.topk(
distances, self.topk, axis=-1, largest=False)
topk_idxs_list.append(topk_idxs + anchors_index)
topk_idxs = paddle.where(pad_gt_mask, topk_idxs,
paddle.zeros_like(topk_idxs))
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
is_in_topk = paddle.where(is_in_topk > 1,
paddle.zeros_like(is_in_topk), is_in_topk)
is_in_topk_list.append(is_in_topk.astype(gt2anchor_distances.dtype))
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(
axis=-2).astype(gt2anchor_distances.dtype)
is_in_topk_list.append(is_in_topk * pad_gt_mask)
is_in_topk_list = paddle.concat(is_in_topk_list, axis=-1)
topk_idxs_list = paddle.concat(topk_idxs_list, axis=-1)
return is_in_topk_list, topk_idxs_list
......@@ -155,9 +151,8 @@ class ATSSAssigner(nn.Layer):
iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
iou_threshold = iou_threshold.mean(axis=-1, keepdim=True) + \
iou_threshold.std(axis=-1, keepdim=True)
is_in_topk = paddle.where(
iou_candidates > iou_threshold.tile([1, 1, num_anchors]),
is_in_topk, paddle.zeros_like(is_in_topk))
is_in_topk = paddle.where(iou_candidates > iou_threshold, is_in_topk,
paddle.zeros_like(is_in_topk))
# 6. check the positive sample's center in gt, [B, n, L]
is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
......
......@@ -112,9 +112,7 @@ class TaskAlignedAssigner(nn.Layer):
# select topk largest alignment metrics pred bbox as candidates
# for each gt, [B, n, L]
is_in_topk = gather_topk_anchors(
alignment_metrics * is_in_gts,
self.topk,
topk_mask=pad_gt_mask.tile([1, 1, self.topk]).astype(paddle.bool))
alignment_metrics * is_in_gts, self.topk, topk_mask=pad_gt_mask)
# select positive sample, [B, n, L]
mask_positive = is_in_topk * is_in_gts * pad_gt_mask
......
......@@ -88,7 +88,7 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
largest (bool) : largest is a flag, if set to true,
algorithm will sort by descending order, otherwise sort by
ascending order. Default: True
topk_mask (Tensor, bool|None): shape[B, n, topk], ignore bbox mask,
topk_mask (Tensor, float32): shape[B, n, 1], ignore bbox mask,
Default: None
eps (float): Default: 1e-9
Returns:
......@@ -98,13 +98,11 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
topk_metrics, topk_idxs = paddle.topk(
metrics, topk, axis=-1, largest=largest)
if topk_mask is None:
topk_mask = (topk_metrics.max(axis=-1, keepdim=True) > eps).tile(
[1, 1, topk])
topk_idxs = paddle.where(topk_mask, topk_idxs, paddle.zeros_like(topk_idxs))
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(axis=-2)
is_in_topk = paddle.where(is_in_topk > 1,
paddle.zeros_like(is_in_topk), is_in_topk)
return is_in_topk.astype(metrics.dtype)
topk_mask = (
topk_metrics.max(axis=-1, keepdim=True) > eps).astype(metrics.dtype)
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(
axis=-2).astype(metrics.dtype)
return is_in_topk * topk_mask
def check_points_inside_bboxes(points,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册