diff --git a/ppdet/modeling/losses/ssd_loss.py b/ppdet/modeling/losses/ssd_loss.py
index 0b68f317f15e736f1c741535363652c3c5e8a5e7..62aecc1f33a104531edc2a77015e27847bb92506 100644
--- a/ppdet/modeling/losses/ssd_loss.py
+++ b/ppdet/modeling/losses/ssd_loss.py
@@ -67,18 +67,15 @@ class SSDLoss(nn.Layer):
         ious = iou_similarity(gt_bbox.reshape((-1, 4)), prior_boxes).reshape(
             (batch_size, -1, num_priors))
 
-        # Calculate the number of object per sample.
-        num_object = (ious.sum(axis=-1) > 0).astype('int64').sum(axis=-1)
-
         # For each prior box, get the max IoU of all GTs.
         prior_max_iou, prior_argmax_iou = ious.max(axis=1), ious.argmax(axis=1)
         # For each GT, get the max IoU of all prior boxes.
         gt_max_iou, gt_argmax_iou = ious.max(axis=2), ious.argmax(axis=2)
 
         # Gather target bbox and label according to 'prior_argmax_iou' index.
-        batch_ind = paddle.arange(
-            0, batch_size, dtype='int64').unsqueeze(-1).tile([1, num_priors])
-        prior_argmax_iou = paddle.stack([batch_ind, prior_argmax_iou], axis=-1)
+        batch_ind = paddle.arange(end=batch_size, dtype='int64').unsqueeze(-1)
+        prior_argmax_iou = paddle.stack(
+            [batch_ind.tile([1, num_priors]), prior_argmax_iou], axis=-1)
         targets_bbox = paddle.gather_nd(gt_bbox, prior_argmax_iou)
         targets_label = paddle.gather_nd(gt_label, prior_argmax_iou)
         # Assign negative
@@ -89,14 +86,14 @@ class SSDLoss(nn.Layer):
             bg_index_tensor, targets_label)
 
         # Ensure each GT can match the max IoU prior box.
-        for i in range(batch_size):
-            if num_object[i] > 0:
-                targets_bbox[i] = paddle.scatter(
-                    targets_bbox[i], gt_argmax_iou[i, :int(num_object[i])],
-                    gt_bbox[i, :int(num_object[i])])
-                targets_label[i] = paddle.scatter(
-                    targets_label[i], gt_argmax_iou[i, :int(num_object[i])],
-                    gt_label[i, :int(num_object[i])])
+        batch_ind = (batch_ind * num_priors + gt_argmax_iou).flatten()
+        targets_bbox = paddle.scatter(
+            targets_bbox.reshape([-1, 4]), batch_ind,
+            gt_bbox.reshape([-1, 4])).reshape([batch_size, -1, 4])
+        targets_label = paddle.scatter(
+            targets_label.reshape([-1, 1]), batch_ind,
+            gt_label.reshape([-1, 1])).reshape([batch_size, -1, 1])
+        targets_label[:, :1] = bg_index
 
         # Encode box
         prior_boxes = prior_boxes.unsqueeze(0).tile([batch_size, 1, 1])
@@ -107,12 +104,16 @@ class SSDLoss(nn.Layer):
 
         return targets_bbox, targets_label
 
-    def _mine_hard_example(self, conf_loss, targets_label, bg_index):
+    def _mine_hard_example(self,
+                           conf_loss,
+                           targets_label,
+                           bg_index,
+                           mine_neg_ratio=0.01):
         pos = (targets_label != bg_index).astype(conf_loss.dtype)
         num_pos = pos.sum(axis=1, keepdim=True)
         neg = (targets_label == bg_index).astype(conf_loss.dtype)
 
-        conf_loss = conf_loss.clone() * neg
+        conf_loss = conf_loss.detach() * neg
         loss_idx = conf_loss.argsort(axis=1, descending=True)
         idx_rank = loss_idx.argsort(axis=1)
         num_negs = []
@@ -120,9 +121,11 @@ class SSDLoss(nn.Layer):
             cur_num_pos = num_pos[i]
             num_neg = paddle.clip(
                 cur_num_pos * self.neg_pos_ratio, max=pos.shape[1])
+            num_neg = num_neg if num_neg > 0 else paddle.to_tensor(
+                [pos.shape[1] * mine_neg_ratio])
             num_negs.append(num_neg)
-        num_neg = paddle.stack(num_negs).expand_as(idx_rank)
-        neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype)
+        num_negs = paddle.stack(num_negs).expand_as(idx_rank)
+        neg_mask = (idx_rank < num_negs).astype(conf_loss.dtype)
 
         return (neg_mask + pos).astype('bool')
 
@@ -141,22 +144,26 @@ class SSDLoss(nn.Layer):
 
         # Compute regression loss.
         # Select positive samples.
-        bbox_mask = (targets_label != bg_index).astype(boxes.dtype)
-        loc_loss = bbox_mask * F.smooth_l1_loss(
-            boxes, targets_bbox, reduction='none')
-        loc_loss = loc_loss.sum() * self.loc_loss_weight
+        bbox_mask = paddle.tile(targets_label != bg_index, [1, 1, 4])
+        if bbox_mask.astype(boxes.dtype).sum() > 0:
+            location = paddle.masked_select(boxes, bbox_mask)
+            targets_bbox = paddle.masked_select(targets_bbox, bbox_mask)
+            loc_loss = F.smooth_l1_loss(location, targets_bbox, reduction='sum')
+            loc_loss = loc_loss * self.loc_loss_weight
+        else:
+            loc_loss = paddle.zeros([1])
 
         # Compute confidence loss.
-        conf_loss = F.softmax_with_cross_entropy(scores, targets_label)
+        conf_loss = F.cross_entropy(scores, targets_label, reduction="none")
         # Mining hard examples.
         label_mask = self._mine_hard_example(
             conf_loss.squeeze(-1), targets_label.squeeze(-1), bg_index)
-        conf_loss = conf_loss * label_mask.unsqueeze(-1).astype(conf_loss.dtype)
+        conf_loss = paddle.masked_select(conf_loss, label_mask.unsqueeze(-1))
        conf_loss = conf_loss.sum() * self.conf_loss_weight
 
         # Compute overall weighted loss.
         normalizer = (targets_label != bg_index).astype('float32').sum().clip(
             min=1)
-        loss = (conf_loss + loc_loss) / (normalizer + 1e-9)
+        loss = (conf_loss + loc_loss) / normalizer
 
         return loss
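Note (illustration, not part of the patch): the matching code replaces the per-image `paddle.scatter` loop with one scatter over a flattened `[batch_size * num_priors, ...]` view, offsetting each image's prior indices by `i * num_priors`. A minimal sketch of the equivalence, with invented shapes and duplicate-free index values:

```python
import paddle

batch_size, num_priors, num_gts = 2, 5, 3
targets = paddle.zeros([batch_size, num_priors, 4])
gt_bbox = paddle.rand([batch_size, num_gts, 4])
# Best-matching prior index for each GT (hypothetical values).
gt_argmax_iou = paddle.to_tensor([[0, 2, 4], [1, 3, 0]], dtype='int64')

# Removed code path: one scatter per image.
loop_out = targets.clone()
for i in range(batch_size):
    loop_out[i] = paddle.scatter(loop_out[i], gt_argmax_iou[i], gt_bbox[i])

# New code path: offset each image's indices into a flat [B * P, 4] view.
batch_ind = paddle.arange(end=batch_size, dtype='int64').unsqueeze(-1)
flat_ind = (batch_ind * num_priors + gt_argmax_iou).flatten()
vec_out = paddle.scatter(
    targets.reshape([-1, 4]), flat_ind,
    gt_bbox.reshape([-1, 4])).reshape([batch_size, num_priors, 4])

assert bool((loop_out == vec_out).all())
```

Besides avoiding a Python-level loop over the batch, this drops the data-dependent `num_object[i] > 0` branch, since scattering all (possibly padded) GT rows is harmless when the first prior's label is reset to `bg_index` afterwards.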
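Note (illustration, not part of the patch): `_mine_hard_example` ranks negatives with a double `argsort`. The first sorts priors by descending confidence loss; applying `argsort` again converts those sorted positions back into a per-prior rank, so `idx_rank < num_negs` keeps exactly the hardest negatives. A small sketch with made-up values:

```python
import paddle

# Per-prior confidence losses for one image (made-up values).
conf_loss = paddle.to_tensor([[0.2, 0.9, 0.1, 0.5]])
loss_idx = conf_loss.argsort(axis=1, descending=True)  # [[1, 3, 0, 2]]
idx_rank = loss_idx.argsort(axis=1)                    # [[2, 0, 3, 1]]
# Prior 1 has the largest loss, so its rank is 0; `idx_rank < k`
# therefore selects the k hardest negatives per image.
print(idx_rank.numpy())  # [[2 0 3 1]]
```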