提交 3a806a51 编写于 作者: Z Zhi Tian

add center sampling

上级 91790b86
......@@ -29,11 +29,44 @@ class FCOSLossComputation(object):
cfg.MODEL.FCOS.LOSS_GAMMA,
cfg.MODEL.FCOS.LOSS_ALPHA
)
self.strides = cfg.MODEL.FCOS.FPN_STRIDES
# we make use of IOU Loss for bounding boxes regression,
# but we found that L1 in log scale can yield a similar performance
self.box_reg_loss_func = IOULoss()
self.centerness_loss_func = nn.BCEWithLogitsLoss()
def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1):
num_gts = gt.shape[0]
K = len(gt_xs)
gt = gt[None].expand(K, num_gts, 4)
center_x = (gt[..., 0] + gt[..., 2]) / 2
center_y = (gt[..., 1] + gt[..., 3]) / 2
center_gt = gt.new_zeros(gt.shape)
# no gt
if center_x[..., 0].sum() == 0:
return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8)
beg = 0
for level, n_p in enumerate(num_points_per):
end = beg + n_p
stride = strides[level] * radius
xmin = center_x[beg:end] - stride
ymin = center_y[beg:end] - stride
xmax = center_x[beg:end] + stride
ymax = center_y[beg:end] + stride
# limit sample region in gt
center_gt[beg:end, :, 0] = torch.where(xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0])
center_gt[beg:end, :, 1] = torch.where(ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1])
center_gt[beg:end, :, 2] = torch.where(xmax > gt[beg:end, :, 2], gt[beg:end, :, 2], xmax)
center_gt[beg:end, :, 3] = torch.where(ymax > gt[beg:end, :, 3], gt[beg:end, :, 3], ymax)
beg = end
left = gt_xs[:, None] - center_gt[..., 0]
right = center_gt[..., 2] - gt_xs[:, None]
top = gt_ys[:, None] - center_gt[..., 1]
bottom = center_gt[..., 3] - gt_ys[:, None]
center_bbox = torch.stack((left, top, right, bottom), -1)
inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
return inside_gt_bbox_mask
def prepare_targets(self, points, targets):
object_sizes_of_interest = [
[-1, 64],
......@@ -52,6 +85,7 @@ class FCOSLossComputation(object):
expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0)
num_points_per_level = [len(points_per_level) for points_per_level in points]
self.num_points_per_level = num_points_per_level
points_all_level = torch.cat(points, dim=0)
labels, reg_targets = self.compute_targets_for_locations(
points_all_level, targets, expanded_object_sizes_of_interest
......@@ -91,7 +125,14 @@ class FCOSLossComputation(object):
b = bboxes[:, 3][None] - ys[:, None]
reg_targets_per_im = torch.stack([l, t, r, b], dim=2)
is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0
is_in_boxes = self.get_sample_region(
bboxes,
self.strides,
self.num_points_per_level,
xs,
ys,
radius=1.5
)
max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
# limit the regression range for each location
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册