From 3a806a51e77660ca5be8029eafb9149773e879a5 Mon Sep 17 00:00:00 2001 From: Zhi Tian Date: Wed, 14 Aug 2019 17:25:31 +0930 Subject: [PATCH] add center sampling --- fcos_core/modeling/rpn/fcos/loss.py | 43 ++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/fcos_core/modeling/rpn/fcos/loss.py b/fcos_core/modeling/rpn/fcos/loss.py index 2af735c..bdbffd1 100644 --- a/fcos_core/modeling/rpn/fcos/loss.py +++ b/fcos_core/modeling/rpn/fcos/loss.py @@ -29,11 +29,44 @@ class FCOSLossComputation(object): cfg.MODEL.FCOS.LOSS_GAMMA, cfg.MODEL.FCOS.LOSS_ALPHA ) + self.strides = cfg.MODEL.FCOS.FPN_STRIDES # we make use of IOU Loss for bounding boxes regression, # but we found that L1 in log scale can yield a similar performance self.box_reg_loss_func = IOULoss() self.centerness_loss_func = nn.BCEWithLogitsLoss() + def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1): + num_gts = gt.shape[0] + K = len(gt_xs) + gt = gt[None].expand(K, num_gts, 4) + center_x = (gt[..., 0] + gt[..., 2]) / 2 + center_y = (gt[..., 1] + gt[..., 3]) / 2 + center_gt = gt.new_zeros(gt.shape) + # no gt + if center_x[..., 0].sum() == 0: + return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8) + beg = 0 + for level, n_p in enumerate(num_points_per): + end = beg + n_p + stride = strides[level] * radius + xmin = center_x[beg:end] - stride + ymin = center_y[beg:end] - stride + xmax = center_x[beg:end] + stride + ymax = center_y[beg:end] + stride + # limit sample region in gt + center_gt[beg:end, :, 0] = torch.where(xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0]) + center_gt[beg:end, :, 1] = torch.where(ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1]) + center_gt[beg:end, :, 2] = torch.where(xmax > gt[beg:end, :, 2], gt[beg:end, :, 2], xmax) + center_gt[beg:end, :, 3] = torch.where(ymax > gt[beg:end, :, 3], gt[beg:end, :, 3], ymax) + beg = end + left = gt_xs[:, None] - center_gt[..., 0] + right = center_gt[..., 2] - gt_xs[:, None] + top = gt_ys[:, None] - center_gt[..., 1] + bottom = center_gt[..., 3] - gt_ys[:, None] + center_bbox = torch.stack((left, top, right, bottom), -1) + inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0 + return inside_gt_bbox_mask + def prepare_targets(self, points, targets): object_sizes_of_interest = [ [-1, 64], @@ -52,6 +85,7 @@ class FCOSLossComputation(object): expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0) num_points_per_level = [len(points_per_level) for points_per_level in points] + self.num_points_per_level = num_points_per_level points_all_level = torch.cat(points, dim=0) labels, reg_targets = self.compute_targets_for_locations( points_all_level, targets, expanded_object_sizes_of_interest @@ -91,7 +125,14 @@ class FCOSLossComputation(object): b = bboxes[:, 3][None] - ys[:, None] reg_targets_per_im = torch.stack([l, t, r, b], dim=2) - is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0 + is_in_boxes = self.get_sample_region( + bboxes, + self.strides, + self.num_points_per_level, + xs, + ys, + radius=1.5 + ) max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0] # limit the regression range for each location -- GitLab