From 26ae0baca0ddbfd8ee4810b3a89cebe603e4f424 Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Tue, 26 Oct 2021 16:14:32 +0800 Subject: [PATCH] [dev] fix tood negative training (#4371) --- ppdet/modeling/assigners/atss_assigner.py | 8 ++++++++ ppdet/modeling/assigners/task_aligned_assigner.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/ppdet/modeling/assigners/atss_assigner.py b/ppdet/modeling/assigners/atss_assigner.py index 8d8e555ae..4835dd1d5 100644 --- a/ppdet/modeling/assigners/atss_assigner.py +++ b/ppdet/modeling/assigners/atss_assigner.py @@ -114,6 +114,14 @@ class ATSSAssigner(nn.Layer): num_anchors, _ = anchor_bboxes.shape batch_size, num_max_boxes, _ = gt_bboxes.shape + # negative batch + if num_max_boxes == 0: + assigned_labels = paddle.full([batch_size, num_anchors], bg_index) + assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4]) + assigned_scores = paddle.zeros( + [batch_size, num_anchors, self.num_classes]) + return assigned_labels, assigned_bboxes, assigned_scores + # 1. compute iou between gt and anchor bbox, [B, n, L] ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes) ious = ious.reshape([batch_size, -1, num_anchors]) diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py index ed43c4936..bae00cc7c 100644 --- a/ppdet/modeling/assigners/task_aligned_assigner.py +++ b/ppdet/modeling/assigners/task_aligned_assigner.py @@ -78,6 +78,14 @@ class TaskAlignedAssigner(nn.Layer): batch_size, num_anchors, num_classes = pred_scores.shape _, num_max_boxes, _ = gt_bboxes.shape + # negative batch + if num_max_boxes == 0: + assigned_labels = paddle.full([batch_size, num_anchors], bg_index) + assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4]) + assigned_scores = paddle.zeros( + [batch_size, num_anchors, num_classes]) + return assigned_labels, assigned_bboxes, assigned_scores + # compute iou between gt and pred bbox, [B, n, L] ious = iou_similarity(gt_bboxes, pred_bboxes) # gather pred bboxes class score -- GitLab