From 0f4d67a50f3dccc82e788a668e4be75019d62581 Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Mon, 13 Feb 2023 10:17:58 +0800 Subject: [PATCH] fix dn_match_indices bug in detr_loss.py (#7725) --- ppdet/modeling/losses/detr_loss.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/ppdet/modeling/losses/detr_loss.py b/ppdet/modeling/losses/detr_loss.py index 8e9b1294f..f4291a7a7 100644 --- a/ppdet/modeling/losses/detr_loss.py +++ b/ppdet/modeling/losses/detr_loss.py @@ -163,9 +163,9 @@ class DETRLoss(nn.Layer): gt_class, bg_index, num_gts, - match_indices=None, + dn_match_indices=None, postfix=""): - if boxes is None and logits is None: + if boxes is None or logits is None: return { "loss_class_aux" + postfix: paddle.paddle.zeros([1]), "loss_bbox_aux" + postfix: paddle.paddle.zeros([1]), @@ -175,9 +175,11 @@ class DETRLoss(nn.Layer): loss_bbox = [] loss_giou = [] for aux_boxes, aux_logits in zip(boxes, logits): - if match_indices is None: + if dn_match_indices is None: match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox, gt_class) + else: + match_indices = dn_match_indices loss_class.append( self._get_loss_class(aux_logits, gt_class, match_indices, bg_index, num_gts, postfix)['loss_class' + @@ -237,11 +239,13 @@ class DETRLoss(nn.Layer): gt_mask (List(Tensor), optional): list[[n, H, W]] postfix (str): postfix of loss name """ - if "match_indices" in kwargs: - match_indices = kwargs["match_indices"] - else: + dn_match_indices = kwargs.get("dn_match_indices", None) + if dn_match_indices is None and (boxes is not None and + logits is not None): match_indices = self.matcher(boxes[-1].detach(), logits[-1].detach(), gt_bbox, gt_class) + else: + match_indices = dn_match_indices num_gts = sum(len(a) for a in gt_bbox) num_gts = paddle.to_tensor([num_gts], dtype="float32") @@ -264,13 +268,11 @@ class DETRLoss(nn.Layer): gt_mask, match_indices, num_gts, postfix)) if self.aux_loss: - if "match_indices" not in kwargs: - match_indices = None total_loss.update( self._get_loss_aux( boxes[:-1] if boxes is not None else None, logits[:-1] if logits is not None else None, gt_bbox, gt_class, - self.num_classes, num_gts, match_indices, postfix)) + self.num_classes, num_gts, dn_match_indices, postfix)) return total_loss @@ -292,7 +294,6 @@ class DINOLoss(DETRLoss): total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox, gt_class) - # denoising training loss if dn_meta is not None: dn_positive_idx, dn_num_group = \ dn_meta["dn_positive_idx"], dn_meta["dn_num_group"] @@ -315,13 +316,14 @@ class DINOLoss(DETRLoss): else: dn_match_indices, dn_num_group = None, 1. + # compute denoising training loss dn_loss = super(DINOLoss, self).forward( dn_out_bboxes, dn_out_logits, gt_bbox, gt_class, postfix="_dn", - match_indices=dn_match_indices, + dn_match_indices=dn_match_indices, dn_num_group=dn_num_group) total_loss.update(dn_loss) -- GitLab