未验证 提交 0f4d67a5 编写于 作者: S shangliang Xu 提交者: GitHub

fix dn_match_indices bug in detr_loss.py (#7725)

上级 dd2d5795
...@@ -163,9 +163,9 @@ class DETRLoss(nn.Layer): ...@@ -163,9 +163,9 @@ class DETRLoss(nn.Layer):
gt_class, gt_class,
bg_index, bg_index,
num_gts, num_gts,
match_indices=None, dn_match_indices=None,
postfix=""): postfix=""):
if boxes is None and logits is None: if boxes is None or logits is None:
return { return {
"loss_class_aux" + postfix: paddle.paddle.zeros([1]), "loss_class_aux" + postfix: paddle.paddle.zeros([1]),
"loss_bbox_aux" + postfix: paddle.paddle.zeros([1]), "loss_bbox_aux" + postfix: paddle.paddle.zeros([1]),
...@@ -175,9 +175,11 @@ class DETRLoss(nn.Layer): ...@@ -175,9 +175,11 @@ class DETRLoss(nn.Layer):
loss_bbox = [] loss_bbox = []
loss_giou = [] loss_giou = []
for aux_boxes, aux_logits in zip(boxes, logits): for aux_boxes, aux_logits in zip(boxes, logits):
if match_indices is None: if dn_match_indices is None:
match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox, match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
gt_class) gt_class)
else:
match_indices = dn_match_indices
loss_class.append( loss_class.append(
self._get_loss_class(aux_logits, gt_class, match_indices, self._get_loss_class(aux_logits, gt_class, match_indices,
bg_index, num_gts, postfix)['loss_class' + bg_index, num_gts, postfix)['loss_class' +
...@@ -237,11 +239,13 @@ class DETRLoss(nn.Layer): ...@@ -237,11 +239,13 @@ class DETRLoss(nn.Layer):
gt_mask (List(Tensor), optional): list[[n, H, W]] gt_mask (List(Tensor), optional): list[[n, H, W]]
postfix (str): postfix of loss name postfix (str): postfix of loss name
""" """
if "match_indices" in kwargs: dn_match_indices = kwargs.get("dn_match_indices", None)
match_indices = kwargs["match_indices"] if dn_match_indices is None and (boxes is not None and
else: logits is not None):
match_indices = self.matcher(boxes[-1].detach(), match_indices = self.matcher(boxes[-1].detach(),
logits[-1].detach(), gt_bbox, gt_class) logits[-1].detach(), gt_bbox, gt_class)
else:
match_indices = dn_match_indices
num_gts = sum(len(a) for a in gt_bbox) num_gts = sum(len(a) for a in gt_bbox)
num_gts = paddle.to_tensor([num_gts], dtype="float32") num_gts = paddle.to_tensor([num_gts], dtype="float32")
...@@ -264,13 +268,11 @@ class DETRLoss(nn.Layer): ...@@ -264,13 +268,11 @@ class DETRLoss(nn.Layer):
gt_mask, match_indices, num_gts, postfix)) gt_mask, match_indices, num_gts, postfix))
if self.aux_loss: if self.aux_loss:
if "match_indices" not in kwargs:
match_indices = None
total_loss.update( total_loss.update(
self._get_loss_aux( self._get_loss_aux(
boxes[:-1] if boxes is not None else None, logits[:-1] boxes[:-1] if boxes is not None else None, logits[:-1]
if logits is not None else None, gt_bbox, gt_class, if logits is not None else None, gt_bbox, gt_class,
self.num_classes, num_gts, match_indices, postfix)) self.num_classes, num_gts, dn_match_indices, postfix))
return total_loss return total_loss
...@@ -292,7 +294,6 @@ class DINOLoss(DETRLoss): ...@@ -292,7 +294,6 @@ class DINOLoss(DETRLoss):
total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox, total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox,
gt_class) gt_class)
# denoising training loss
if dn_meta is not None: if dn_meta is not None:
dn_positive_idx, dn_num_group = \ dn_positive_idx, dn_num_group = \
dn_meta["dn_positive_idx"], dn_meta["dn_num_group"] dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
...@@ -315,13 +316,14 @@ class DINOLoss(DETRLoss): ...@@ -315,13 +316,14 @@ class DINOLoss(DETRLoss):
else: else:
dn_match_indices, dn_num_group = None, 1. dn_match_indices, dn_num_group = None, 1.
# compute denoising training loss
dn_loss = super(DINOLoss, self).forward( dn_loss = super(DINOLoss, self).forward(
dn_out_bboxes, dn_out_bboxes,
dn_out_logits, dn_out_logits,
gt_bbox, gt_bbox,
gt_class, gt_class,
postfix="_dn", postfix="_dn",
match_indices=dn_match_indices, dn_match_indices=dn_match_indices,
dn_num_group=dn_num_group) dn_num_group=dn_num_group)
total_loss.update(dn_loss) total_loss.update(dn_loss)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册