未验证 提交 26ae0bac 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] fix tood negative training (#4371)

上级 1fa998be
......@@ -114,6 +114,14 @@ class ATSSAssigner(nn.Layer):
num_anchors, _ = anchor_bboxes.shape
batch_size, num_max_boxes, _ = gt_bboxes.shape
# negative batch
if num_max_boxes == 0:
assigned_labels = paddle.full([batch_size, num_anchors], bg_index)
assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
assigned_scores = paddle.zeros(
[batch_size, num_anchors, self.num_classes])
return assigned_labels, assigned_bboxes, assigned_scores
# 1. compute iou between gt and anchor bbox, [B, n, L]
ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
ious = ious.reshape([batch_size, -1, num_anchors])
......
......@@ -78,6 +78,14 @@ class TaskAlignedAssigner(nn.Layer):
batch_size, num_anchors, num_classes = pred_scores.shape
_, num_max_boxes, _ = gt_bboxes.shape
# negative batch
if num_max_boxes == 0:
assigned_labels = paddle.full([batch_size, num_anchors], bg_index)
assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
assigned_scores = paddle.zeros(
[batch_size, num_anchors, num_classes])
return assigned_labels, assigned_bboxes, assigned_scores
# compute iou between gt and pred bbox, [B, n, L]
ious = iou_similarity(gt_bboxes, pred_bboxes)
# gather pred bboxes class score
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册