未验证 提交 2cd40bdd 编写于 作者: S shangliang Xu 提交者: GitHub

fix one_hot error (#5532)

上级 eb8e1541
......@@ -199,7 +199,11 @@ class ATSSAssigner(nn.Layer):
gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
assigned_scores = F.one_hot(assigned_labels, self.num_classes)
assigned_scores = F.one_hot(assigned_labels, self.num_classes + 1)
ind = list(range(self.num_classes + 1))
ind.remove(bg_index)
assigned_scores = paddle.index_select(
assigned_scores, paddle.to_tensor(ind), axis=-1)
if pred_bboxes is not None:
# assigned iou
ious = batch_iou_similarity(gt_bboxes, pred_bboxes) * mask_positive
......
......@@ -143,7 +143,11 @@ class TaskAlignedAssigner(nn.Layer):
gt_bboxes.reshape([-1, 4]), assigned_gt_index.flatten(), axis=0)
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
assigned_scores = F.one_hot(assigned_labels, num_classes)
assigned_scores = F.one_hot(assigned_labels, num_classes + 1)
ind = list(range(num_classes + 1))
ind.remove(bg_index)
assigned_scores = paddle.index_select(
assigned_scores, paddle.to_tensor(ind), axis=-1)
# rescale alignment metrics
alignment_metrics *= mask_positive
max_metrics_per_instance = alignment_metrics.max(axis=-1, keepdim=True)
......
......@@ -331,7 +331,8 @@ class PPYOLOEHead(nn.Layer):
assigned_bboxes /= stride_tensor
# cls loss
if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels, self.num_classes)
one_hot_label = F.one_hot(assigned_labels,
self.num_classes + 1)[..., :-1]
loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
one_hot_label)
else:
......
......@@ -80,7 +80,7 @@ class DETRLoss(nn.Layer):
target_label = target_label.reshape([bs, num_query_objects])
if self.use_focal_loss:
target_label = F.one_hot(target_label,
self.num_classes + 1)[:, :, :-1]
self.num_classes + 1)[..., :-1]
return {
'loss_class': self.loss_coeff['class'] * sigmoid_focal_loss(
logits, target_label, num_gts / num_query_objects)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册