Unverified commit b40f00cb, authored by shangliang Xu, committed via GitHub

fix deprecated api in ssd loss (#3592)

Parent: befec463
@@ -67,18 +67,15 @@ class SSDLoss(nn.Layer):
         ious = iou_similarity(gt_bbox.reshape((-1, 4)), prior_boxes).reshape(
             (batch_size, -1, num_priors))
-        # Calculate the number of object per sample.
-        num_object = (ious.sum(axis=-1) > 0).astype('int64').sum(axis=-1)
         # For each prior box, get the max IoU of all GTs.
         prior_max_iou, prior_argmax_iou = ious.max(axis=1), ious.argmax(axis=1)
         # For each GT, get the max IoU of all prior boxes.
         gt_max_iou, gt_argmax_iou = ious.max(axis=2), ious.argmax(axis=2)
         # Gather target bbox and label according to 'prior_argmax_iou' index.
-        batch_ind = paddle.arange(
-            0, batch_size, dtype='int64').unsqueeze(-1).tile([1, num_priors])
-        prior_argmax_iou = paddle.stack([batch_ind, prior_argmax_iou], axis=-1)
+        batch_ind = paddle.arange(end=batch_size, dtype='int64').unsqueeze(-1)
+        prior_argmax_iou = paddle.stack(
+            [batch_ind.tile([1, num_priors]), prior_argmax_iou], axis=-1)
         targets_bbox = paddle.gather_nd(gt_bbox, prior_argmax_iou)
         targets_label = paddle.gather_nd(gt_label, prior_argmax_iou)
         # Assign negative
@@ -89,14 +86,14 @@ class SSDLoss(nn.Layer):
             bg_index_tensor, targets_label)
         # Ensure each GT can match the max IoU prior box.
-        for i in range(batch_size):
-            if num_object[i] > 0:
-                targets_bbox[i] = paddle.scatter(
-                    targets_bbox[i], gt_argmax_iou[i, :int(num_object[i])],
-                    gt_bbox[i, :int(num_object[i])])
-                targets_label[i] = paddle.scatter(
-                    targets_label[i], gt_argmax_iou[i, :int(num_object[i])],
-                    gt_label[i, :int(num_object[i])])
+        batch_ind = (batch_ind * num_priors + gt_argmax_iou).flatten()
+        targets_bbox = paddle.scatter(
+            targets_bbox.reshape([-1, 4]), batch_ind,
+            gt_bbox.reshape([-1, 4])).reshape([batch_size, -1, 4])
+        targets_label = paddle.scatter(
+            targets_label.reshape([-1, 1]), batch_ind,
+            gt_label.reshape([-1, 1])).reshape([batch_size, -1, 1])
+        targets_label[:, :1] = bg_index
         # Encode box
         prior_boxes = prior_boxes.unsqueeze(0).tile([batch_size, 1, 1])
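Review note: this hunk vectorizes the force-match step. Instead of looping over samples and calling `paddle.scatter` once per image, `batch_ind * num_priors + gt_argmax_iou` converts every (sample, best-prior) pair into a row index of the flattened `[batch_size * num_priors, ...]` tensor, so a single scatter covers the whole batch. Padded GT rows have zero IoU everywhere and therefore argmax 0, which appears to be why the trailing `targets_label[:, :1] = bg_index` resets the first prior's label back to background. A NumPy sketch of the flat-index arithmetic (toy shapes assumed):

```python
import numpy as np

batch_size, num_gts, num_priors = 2, 3, 5          # toy sizes (assumed)
gt_bbox = np.random.rand(batch_size, num_gts, 4).astype('float32')
gt_argmax_iou = np.random.randint(0, num_priors, (batch_size, num_gts))
targets_bbox = np.zeros((batch_size, num_priors, 4), dtype='float32')

# The [batch_size, 1] column broadcasts against [batch_size, num_gts]:
# every GT gets one flat row index into the [batch_size * num_priors] view.
batch_ind = np.arange(batch_size, dtype='int64')[:, None]
flat_ind = (batch_ind * num_priors + gt_argmax_iou).flatten()

# One write replaces the old per-sample loop of paddle.scatter calls.
flat = targets_bbox.reshape(-1, 4)
flat[flat_ind] = gt_bbox.reshape(-1, 4)
targets_bbox = flat.reshape(batch_size, num_priors, 4)
```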
@@ -107,12 +104,16 @@ class SSDLoss(nn.Layer):
         return targets_bbox, targets_label
 
-    def _mine_hard_example(self, conf_loss, targets_label, bg_index):
+    def _mine_hard_example(self,
+                           conf_loss,
+                           targets_label,
+                           bg_index,
+                           mine_neg_ratio=0.01):
         pos = (targets_label != bg_index).astype(conf_loss.dtype)
         num_pos = pos.sum(axis=1, keepdim=True)
         neg = (targets_label == bg_index).astype(conf_loss.dtype)
-        conf_loss = conf_loss.clone() * neg
+        conf_loss = conf_loss.detach() * neg
         loss_idx = conf_loss.argsort(axis=1, descending=True)
         idx_rank = loss_idx.argsort(axis=1)
         num_negs = []
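Side note on `clone()` → `detach()`: the copy here exists only to rank negatives, so detaching it keeps that ranking arithmetic out of the autograd graph, whereas a clone would still receive gradients. A tiny Paddle sketch (made-up values):

```python
import paddle

x = paddle.to_tensor([0.5, 1.5], stop_gradient=False)
conf = x * 2.0
print(conf.stop_gradient)           # False: still part of the autograd graph
print(conf.detach().stop_gradient)  # True: ranking math adds no gradients
```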
@@ -120,9 +121,11 @@ class SSDLoss(nn.Layer):
             cur_num_pos = num_pos[i]
             num_neg = paddle.clip(
                 cur_num_pos * self.neg_pos_ratio, max=pos.shape[1])
+            num_neg = num_neg if num_neg > 0 else paddle.to_tensor(
+                [pos.shape[1] * mine_neg_ratio])
             num_negs.append(num_neg)
-        num_neg = paddle.stack(num_negs).expand_as(idx_rank)
-        neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype)
+        num_negs = paddle.stack(num_negs).expand_as(idx_rank)
+        neg_mask = (idx_rank < num_negs).astype(conf_loss.dtype)
         return (neg_mask + pos).astype('bool')
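Review note: `_mine_hard_example` ranks background priors by their (detached) confidence loss via the double-argsort trick — after two argsorts, `idx_rank[i]` is the descending-order rank of element `i` — and keeps at most `neg_pos_ratio` negatives per positive. The new `mine_neg_ratio=0.01` fallback guarantees a few negatives even for samples with no positives at all. A NumPy sketch of the ranking trick on made-up losses:

```python
import numpy as np

conf_loss = np.array([0.2, 0.9, 0.1, 0.5])   # per-prior losses (made up)
loss_idx = np.argsort(-conf_loss)    # descending order:   [1, 3, 0, 2]
idx_rank = np.argsort(loss_idx)      # rank of each prior: [2, 0, 3, 1]

num_neg = 2                          # e.g. neg_pos_ratio * num_pos
neg_mask = idx_rank < num_neg        # keeps the two largest losses
print(neg_mask)                      # [False  True False  True]
```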
@@ -141,22 +144,26 @@ class SSDLoss(nn.Layer):
         # Compute regression loss.
         # Select positive samples.
-        bbox_mask = (targets_label != bg_index).astype(boxes.dtype)
-        loc_loss = bbox_mask * F.smooth_l1_loss(
-            boxes, targets_bbox, reduction='none')
-        loc_loss = loc_loss.sum() * self.loc_loss_weight
+        bbox_mask = paddle.tile(targets_label != bg_index, [1, 1, 4])
+        if bbox_mask.astype(boxes.dtype).sum() > 0:
+            location = paddle.masked_select(boxes, bbox_mask)
+            targets_bbox = paddle.masked_select(targets_bbox, bbox_mask)
+            loc_loss = F.smooth_l1_loss(location, targets_bbox, reduction='sum')
+            loc_loss = loc_loss * self.loc_loss_weight
+        else:
+            loc_loss = paddle.zeros([1])
 
         # Compute confidence loss.
-        conf_loss = F.softmax_with_cross_entropy(scores, targets_label)
+        conf_loss = F.cross_entropy(scores, targets_label, reduction="none")
         # Mining hard examples.
         label_mask = self._mine_hard_example(
             conf_loss.squeeze(-1), targets_label.squeeze(-1), bg_index)
-        conf_loss = conf_loss * label_mask.unsqueeze(-1).astype(conf_loss.dtype)
+        conf_loss = paddle.masked_select(conf_loss, label_mask.unsqueeze(-1))
         conf_loss = conf_loss.sum() * self.conf_loss_weight
 
         # Compute overall weighted loss.
         normalizer = (targets_label != bg_index).astype('float32').sum().clip(
             min=1)
-        loss = (conf_loss + loc_loss) / (normalizer + 1e-9)
+        loss = (conf_loss + loc_loss) / normalizer
         return loss