from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..ops import iou_similarity
from ..bbox_utils import bbox2delta

__all__ = ['SSDLoss']


# NOTE(review): the decorator line lies outside the visible hunks; it is
# inferred from the otherwise-unused `register` import -- confirm against the
# full file.
@register
class SSDLoss(nn.Layer):
    """
    SSDLoss

    Args:
        overlap_threshold (float32, optional): IoU threshold for negative bboxes
            and positive bboxes, 0.5 by default.
        neg_pos_ratio (float): The ratio of negative samples / positive samples.
        loc_loss_weight (float): The weight of loc_loss.
        conf_loss_weight (float): The weight of conf_loss.
        prior_box_var (list): Variances corresponding to prior box coord, [0.1,
            0.1, 0.2, 0.2] by default.
    """

    def __init__(self,
                 overlap_threshold=0.5,
                 neg_pos_ratio=3.0,
                 loc_loss_weight=1.0,
                 conf_loss_weight=1.0,
                 prior_box_var=[0.1, 0.1, 0.2, 0.2]):
        super(SSDLoss, self).__init__()
        self.overlap_threshold = overlap_threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.loc_loss_weight = loc_loss_weight
        self.conf_loss_weight = conf_loss_weight
        # The default list is never mutated: only the reciprocals are stored,
        # in a fresh list. They are later handed to `bbox2delta` as its third
        # argument (presumably per-coordinate weights -- confirm in
        # bbox_utils).
        self.prior_box_var = [1. / a for a in prior_box_var]

    def _bipartite_match_for_batch(self, gt_bbox, gt_label, prior_boxes,
                                   bg_index):
        """
        Match every prior box to a ground-truth box and build the targets.

        Args:
            gt_bbox (Tensor): [B, N, 4]
            gt_label (Tensor): [B, N, 1]
            prior_boxes (Tensor): [A, 4]
            bg_index (int): Background class index

        Returns:
            tuple: `(targets_bbox, targets_label)`. `targets_bbox` is the
                output of `bbox2delta` reshaped to [B, A, 4]; `targets_label`
                is [B, A, 1] and holds `bg_index` for priors whose best IoU
                is below `overlap_threshold` (unless the prior is the best
                match of some ground-truth box).
        """
        batch_size, num_priors = gt_bbox.shape[0], prior_boxes.shape[0]
        ious = iou_similarity(gt_bbox.reshape((-1, 4)), prior_boxes).reshape(
            (batch_size, -1, num_priors))

        # Calculate the number of objects per sample: GT rows whose IoU with
        # at least one prior is positive.
        # NOTE(review): the `[:count]` slices below assume such rows come
        # first, i.e. padded GT rows (presumably all-zero boxes) sit at the
        # tail -- confirm against the data pipeline.
        num_object = (ious.sum(axis=-1) > 0).astype('int64').sum(axis=-1)

        # For each prior box, get the max IoU of all GTs.
        prior_max_iou, prior_argmax_iou = ious.max(axis=1), ious.argmax(axis=1)
        # For each GT, get the argmax IoU over all prior boxes.
        gt_argmax_iou = ious.argmax(axis=2)

        # Gather target bbox and label according to 'prior_argmax_iou' index.
        batch_ind = paddle.arange(
            0, batch_size, dtype='int64').unsqueeze(-1).tile([1, num_priors])
        prior_argmax_iou = paddle.stack([batch_ind, prior_argmax_iou], axis=-1)
        targets_bbox = paddle.gather_nd(gt_bbox, prior_argmax_iou)
        targets_label = paddle.gather_nd(gt_label, prior_argmax_iou)
        # Assign negative
        bg_index_tensor = paddle.full([batch_size, num_priors, 1], bg_index,
                                      'int64')
        targets_label = paddle.where(
            prior_max_iou.unsqueeze(-1) < self.overlap_threshold,
            bg_index_tensor, targets_label)

        # Ensure each GT can match the max IoU prior box.
        for i in range(batch_size):
            # Convert once per sample instead of three times.
            count = int(num_object[i])
            if count > 0:
                best_priors = gt_argmax_iou[i, :count]
                targets_bbox[i] = paddle.scatter(
                    targets_bbox[i], best_priors, gt_bbox[i, :count])
                targets_label[i] = paddle.scatter(
                    targets_label[i], best_priors, gt_label[i, :count])

        # Encode box
        prior_boxes = prior_boxes.unsqueeze(0).tile([batch_size, 1, 1])
        targets_bbox = bbox2delta(
            prior_boxes.reshape([-1, 4]),
            targets_bbox.reshape([-1, 4]), self.prior_box_var)
        targets_bbox = targets_bbox.reshape([batch_size, -1, 4])

        return targets_bbox, targets_label

    def _mine_hard_example(self, conf_loss, targets_label, bg_index):
        """
        Select the priors that contribute to the confidence loss.

        Args:
            conf_loss (Tensor): per-prior confidence loss; `forward` passes
                it with the trailing axis squeezed.
            targets_label (Tensor): per-prior target label, same layout as
                `conf_loss`.
            bg_index (int): Background class index

        Returns:
            Tensor: bool mask that is True for every positive prior and for
                the highest-loss negatives, at most `neg_pos_ratio` times the
                number of positives per sample.
        """
        pos = (targets_label != bg_index).astype(conf_loss.dtype)
        num_pos = pos.sum(axis=1, keepdim=True)
        neg = (targets_label == bg_index).astype(conf_loss.dtype)

        # Zero the loss of positives so that only negatives compete for the
        # top ranks. The product is a new tensor, so no clone is needed.
        neg_loss = conf_loss * neg
        loss_idx = neg_loss.argsort(axis=1, descending=True)
        # Rank of every prior inside the descending-loss ordering.
        idx_rank = loss_idx.argsort(axis=1)

        # `num_pos` keeps its reduced axis, so one clip yields the per-sample
        # budget of negatives without a Python loop.
        num_neg = paddle.clip(num_pos * self.neg_pos_ratio, max=pos.shape[1])
        num_neg = num_neg.expand_as(idx_rank)
        neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype)

        return (neg_mask + pos).astype('bool')

    def forward(self, boxes, scores, gt_bbox, gt_label, prior_boxes):
        """
        Compute the SSD loss.

        Args:
            boxes (list[Tensor]): box predictions, concatenated on axis 1.
            scores (list[Tensor]): class scores, concatenated on axis 1; the
                last class index is used as the background index.
            gt_bbox (Tensor): ground-truth boxes, see
                `_bipartite_match_for_batch`.
            gt_label (Tensor): ground-truth labels; a trailing axis is added
                and the dtype is cast to int64 here.
            prior_boxes (list[Tensor]): prior boxes, concatenated on axis 0.

        Returns:
            Tensor: weighted sum of confidence and localization loss divided
                by the number of positive priors (at least 1).
        """
        boxes = paddle.concat(boxes, axis=1)
        scores = paddle.concat(scores, axis=1)
        gt_label = gt_label.unsqueeze(-1).astype('int64')
        prior_boxes = paddle.concat(prior_boxes, axis=0)
        bg_index = scores.shape[-1] - 1

        # Match bbox and get targets.
        targets_bbox, targets_label = self._bipartite_match_for_batch(
            gt_bbox, gt_label, prior_boxes, bg_index)
        targets_bbox.stop_gradient = True
        targets_label.stop_gradient = True

        # Positive priors; used for the regression mask and the normalizer.
        pos_mask = targets_label != bg_index

        # Compute regression loss on positive samples only.
        loc_loss = pos_mask.astype(boxes.dtype) * F.smooth_l1_loss(
            boxes, targets_bbox, reduction='none')
        loc_loss = loc_loss.sum() * self.loc_loss_weight

        # Compute confidence loss.
        conf_loss = F.softmax_with_cross_entropy(scores, targets_label)
        # Mining hard examples.
        label_mask = self._mine_hard_example(
            conf_loss.squeeze(-1), targets_label.squeeze(-1), bg_index)
        conf_loss = conf_loss * label_mask.unsqueeze(-1).astype(conf_loss.dtype)
        conf_loss = conf_loss.sum() * self.conf_loss_weight

        # Compute overall weighted loss. The clip keeps the normalizer >= 1,
        # so no extra epsilon is required to avoid division by zero.
        normalizer = pos_mask.astype('float32').sum().clip(min=1)
        loss = (conf_loss + loc_loss) / normalizer

        return loss