未验证 提交 0f4d67a5 编写于 作者: S shangliang Xu 提交者: GitHub

fix dn_match_indices bug in detr_loss.py (#7725)

上级 dd2d5795
......@@ -163,9 +163,9 @@ class DETRLoss(nn.Layer):
gt_class,
bg_index,
num_gts,
match_indices=None,
dn_match_indices=None,
postfix=""):
if boxes is None and logits is None:
if boxes is None or logits is None:
return {
"loss_class_aux" + postfix: paddle.paddle.zeros([1]),
"loss_bbox_aux" + postfix: paddle.paddle.zeros([1]),
......@@ -175,9 +175,11 @@ class DETRLoss(nn.Layer):
loss_bbox = []
loss_giou = []
for aux_boxes, aux_logits in zip(boxes, logits):
if match_indices is None:
if dn_match_indices is None:
match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
gt_class)
else:
match_indices = dn_match_indices
loss_class.append(
self._get_loss_class(aux_logits, gt_class, match_indices,
bg_index, num_gts, postfix)['loss_class' +
......@@ -237,11 +239,13 @@ class DETRLoss(nn.Layer):
gt_mask (List(Tensor), optional): list[[n, H, W]]
postfix (str): postfix of loss name
"""
if "match_indices" in kwargs:
match_indices = kwargs["match_indices"]
else:
dn_match_indices = kwargs.get("dn_match_indices", None)
if dn_match_indices is None and (boxes is not None and
logits is not None):
match_indices = self.matcher(boxes[-1].detach(),
logits[-1].detach(), gt_bbox, gt_class)
else:
match_indices = dn_match_indices
num_gts = sum(len(a) for a in gt_bbox)
num_gts = paddle.to_tensor([num_gts], dtype="float32")
......@@ -264,13 +268,11 @@ class DETRLoss(nn.Layer):
gt_mask, match_indices, num_gts, postfix))
if self.aux_loss:
if "match_indices" not in kwargs:
match_indices = None
total_loss.update(
self._get_loss_aux(
boxes[:-1] if boxes is not None else None, logits[:-1]
if logits is not None else None, gt_bbox, gt_class,
self.num_classes, num_gts, match_indices, postfix))
self.num_classes, num_gts, dn_match_indices, postfix))
return total_loss
......@@ -292,7 +294,6 @@ class DINOLoss(DETRLoss):
total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox,
gt_class)
# denoising training loss
if dn_meta is not None:
dn_positive_idx, dn_num_group = \
dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
......@@ -315,13 +316,14 @@ class DINOLoss(DETRLoss):
else:
dn_match_indices, dn_num_group = None, 1.
# compute denoising training loss
dn_loss = super(DINOLoss, self).forward(
dn_out_bboxes,
dn_out_logits,
gt_bbox,
gt_class,
postfix="_dn",
match_indices=dn_match_indices,
dn_match_indices=dn_match_indices,
dn_num_group=dn_num_group)
total_loss.update(dn_loss)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册