# target.py: training-target assignment for RPN anchors, RoI proposals and
# mask heads (recovered from a blame view; last committed by FDInSky).
import six
import math
import numpy as np
from numba import jit
from .bbox import *
from .mask import *


@jit
def generate_rpn_anchor_target(anchors,
                               gt_boxes,
                               is_crowd,
                               im_info,
                               rpn_straddle_thresh,
                               rpn_batch_size_per_im,
                               rpn_positive_overlap,
                               rpn_negative_overlap,
                               rpn_fg_fraction,
                               use_random=True,
                               anchor_reg_weights=None):
    """Build RPN classification and regression targets for a batch.

    For every image in the batch this:
      1. optionally drops anchors that cross the image border by more than
         ``rpn_straddle_thresh``;
      2. scales the image's gt boxes by ``im_info[i][2]`` and drops the ones
         whose ``is_crowd`` flag is non-zero;
      3. matches anchors to gt boxes (``label_anchor``) and samples
         foreground/background anchors (``sample_anchor``);
      4. encodes regression deltas for the sampled anchors (``bbox2delta``).

    Args:
        anchors: anchor boxes shared by all images; columns 0..3 are compared
            against the image bounds as (x1, y1, x2, y2).
        gt_boxes: per-image gt boxes; ``gt_boxes.shape[0]`` is the batch size.
        is_crowd: per-image crowd flags aligned with ``gt_boxes``.
        im_info: per-image rows read as (height, width, scale).
        rpn_straddle_thresh: border tolerance; a negative value disables the
            border filter and keeps every anchor.
        rpn_batch_size_per_im: number of anchors sampled per image.
        rpn_positive_overlap: IoU at or above which an anchor is foreground.
        rpn_negative_overlap: IoU below which an anchor may be background.
        rpn_fg_fraction: maximum fraction of the sample that is foreground.
        use_random: sample randomly instead of taking the first candidates.
        anchor_reg_weights: weights forwarded to ``bbox2delta``; ``None``
            (the default) means ``[1., 1., 1., 1.]``.

    Returns:
        loc_indexes: anchor indexes used for regression, offset by
            ``i * anchors.shape[0]`` so they index the flattened batch; each
            image's "fake" placeholders come first.
        cls_indexes: anchor indexes used for classification (same offset).
        tgt_labels: float32 labels for ``cls_indexes`` (1 fg, 0 bg).
        tgt_deltas: float32 regression targets aligned with ``loc_indexes``.
        anchor_inside_weights: float32 (N, 4) weights aligned with
            ``loc_indexes``; rows for fake placeholders are zero.
    """
    # Build a fresh list per call rather than sharing a mutable default.
    if anchor_reg_weights is None:
        anchor_reg_weights = [1., 1., 1., 1.]

    anchor_num = anchors.shape[0]
    batch_size = gt_boxes.shape[0]

    loc_indexes = []
    cls_indexes = []
    tgt_labels = []
    tgt_deltas = []
    anchor_inside_weights = []

    for i in range(batch_size):
        # TODO: move anchor filter into anchor generator
        im_height = im_info[i][0]
        im_width = im_info[i][1]
        im_scale = im_info[i][2]
        if rpn_straddle_thresh >= 0:
            # Keep anchors that lie inside the image, up to the tolerance.
            anchor_inds = np.where(
                (anchors[:, 0] >= -rpn_straddle_thresh) &
                (anchors[:, 1] >= -rpn_straddle_thresh) &
                (anchors[:, 2] < im_width + rpn_straddle_thresh) &
                (anchors[:, 3] < im_height + rpn_straddle_thresh))[0]
            anchor = anchors[anchor_inds, :]
        else:
            anchor_inds = np.arange(anchors.shape[0])
            anchor = anchors

        gt_bbox = gt_boxes[i] * im_scale
        is_crowd_slice = is_crowd[i]
        not_crowd_inds = np.where(is_crowd_slice == 0)[0]
        gt_bbox = gt_bbox[not_crowd_inds]

        # Step1: match anchor and gt_bbox
        anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = label_anchor(anchor,
                                                                       gt_bbox)

        # Step2: sample anchor (``labels`` is updated in place to 1/0/-1)
        fg_inds, bg_inds, fg_fake_inds, fake_num = sample_anchor(
            anchor_gt_bbox_iou, labels, rpn_positive_overlap,
            rpn_negative_overlap, rpn_batch_size_per_im, rpn_fg_fraction,
            use_random)

        # Step3: make output
        # Regression indexes: fake placeholders first, then real foreground.
        loc_inds = np.hstack([fg_fake_inds, fg_inds])
        cls_inds = np.hstack([fg_inds, bg_inds])

        sampled_labels = labels[cls_inds]

        sampled_anchors = anchor[loc_inds]
        sampled_gt_boxes = gt_bbox[anchor_gt_bbox_inds[loc_inds]]
        sampled_deltas = bbox2delta(sampled_anchors, sampled_gt_boxes,
                                    anchor_reg_weights)

        # Zero weight for the leading fake rows so they add no regression loss.
        anchor_inside_weight = np.zeros((len(loc_inds), 4), dtype=np.float32)
        anchor_inside_weight[fake_num:, :] = 1

        # Map per-image (filtered) anchor positions back to flat batch indexes.
        loc_indexes.append(anchor_inds[loc_inds] + i * anchor_num)
        cls_indexes.append(anchor_inds[cls_inds] + i * anchor_num)
        tgt_labels.append(sampled_labels)
        tgt_deltas.append(sampled_deltas)
        anchor_inside_weights.append(anchor_inside_weight)

    loc_indexes = np.concatenate(loc_indexes)
    cls_indexes = np.concatenate(cls_indexes)
    tgt_labels = np.concatenate(tgt_labels).astype('float32')
    tgt_deltas = np.vstack(tgt_deltas).astype('float32')
    anchor_inside_weights = np.vstack(anchor_inside_weights)

    return loc_indexes, cls_indexes, tgt_labels, tgt_deltas, anchor_inside_weights


@jit
def label_anchor(anchors, gt_boxes):
    """Match anchors to ground-truth boxes by IoU.

    Every anchor that reaches the highest IoU of some gt box (ties included)
    is labelled 1. All other anchors start as -1 and are resolved later by
    ``sample_anchor``.

    Args:
        anchors: candidate anchor boxes for one image.
        gt_boxes: ground-truth boxes for the same image.

    Returns:
        anchor_gt_bbox_inds: for each anchor, the index of its best gt box.
        anchor_gt_bbox_iou: for each anchor, its IoU with that best gt box.
        labels: int32 array, 1 for per-gt best anchors and -1 otherwise.
    """
    overlaps = bbox_overlaps(anchors, gt_boxes)

    # Column-wise maximum is the best IoU each gt box reaches. Comparing the
    # whole matrix against it keeps ties, so one gt may claim several anchors.
    # NOTE(review): if a gt's best IoU is 0, every anchor with IoU 0 to that gt
    # matches here and becomes positive; confirm upstream rules this out.
    best_iou_per_gt = overlaps.max(axis=0)
    winners = np.nonzero(overlaps == best_iou_per_gt)[0]

    # Row-wise best match for each anchor.
    anchor_gt_bbox_inds = overlaps.argmax(axis=1)
    anchor_gt_bbox_iou = overlaps.max(axis=1)

    labels = np.full(overlaps.shape[0], -1, dtype=np.int32)
    labels[winners] = 1
    return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels

@jit
def sample_anchor(anchor_gt_bbox_iou,
                  labels,
                  rpn_positive_overlap,
                  rpn_negative_overlap,
                  rpn_batch_size_per_im,
                  rpn_fg_fraction,
                  use_random=True):
    """Sample foreground and background anchors for RPN training.

    ``labels`` is modified in place. Afterwards 1 means foreground, 0 means
    background and -1 means ignored.

    Args:
        anchor_gt_bbox_iou: best IoU of each anchor with any gt box.
        labels: int label array from ``label_anchor`` (updated in place).
        rpn_positive_overlap: IoU at or above which an anchor is foreground.
        rpn_negative_overlap: IoU below which an anchor is a background
            candidate.
        rpn_batch_size_per_im: total number of anchors to sample.
        rpn_fg_fraction: maximum fraction of the sample that is foreground.
        use_random: sample randomly instead of taking the first candidates.

    Returns:
        fg_inds: indices whose final label is 1.
        bg_inds: indices whose final label is 0.
        fg_fake_inds: int32 array of length ``fake_num``. Every entry is the
            first foreground index (taken before background relabelling).
        fake_num: number of sampled background anchors that had been
            foreground and were relabelled to 0.
    """
    labels[anchor_gt_bbox_iou >= rpn_positive_overlap] = 1

    # Cap the number of foreground anchors.
    num_fg = int(rpn_fg_fraction * rpn_batch_size_per_im)
    fg_inds = np.where(labels == 1)[0]
    if len(fg_inds) > num_fg and use_random:
        disable_inds = np.random.choice(
            fg_inds, size=(len(fg_inds) - num_fg), replace=False)
    else:
        disable_inds = fg_inds[num_fg:]
    labels[disable_inds] = -1
    fg_inds = np.where(labels == 1)[0]

    # Fill the rest of the batch with background candidates.
    num_bg = rpn_batch_size_per_im - np.sum(labels == 1)
    bg_inds = np.where(anchor_gt_bbox_iou < rpn_negative_overlap)[0]
    if len(bg_inds) > num_bg and use_random:
        # Without replacement, like the foreground branch. Sampling with
        # replacement yields duplicates, which give fewer distinct backgrounds
        # and double-count fake entries.
        enable_inds = np.random.choice(bg_inds, size=num_bg, replace=False)
    else:
        enable_inds = bg_inds[:num_bg]

    # A sampled background anchor can currently be foreground: a per-gt best
    # match can have IoU below the negative threshold. Each such anchor gets
    # one placeholder pointing at the first foreground index. fg_inds[:1] is
    # empty when there is no foreground, and then fake_num is 0 as well, so
    # no IndexError is raised.
    fake_num = int(np.sum(np.isin(enable_inds, fg_inds)))
    fg_fake_inds = np.repeat(fg_inds[:1], fake_num).astype(np.int32)
    labels[enable_inds] = 0

    fg_inds = np.where(labels == 1)[0]
    bg_inds = np.where(labels == 0)[0]

    return fg_inds, bg_inds, fg_fake_inds, fake_num


@jit
def generate_proposal_target(rpn_rois,
                             rpn_rois_num,
                             gt_classes,
                             is_crowd,
                             gt_boxes,
                             im_info,
                             batch_size_per_im,
                             fg_fraction,
                             fg_thresh,
                             bg_thresh_hi,
                             bg_thresh_lo,
                             bbox_reg_weights,
                             class_nums=81,
                             use_random=True,
                             is_cls_agnostic=False,
                             is_cascade_rcnn=False):
    """Sample RoIs and build second-stage classification/regression targets.

    ``rpn_rois`` holds the RoIs of all images back to back; image ``i`` owns
    the next ``rpn_rois_num[i]`` rows. For each image the RoIs are divided by
    ``im_info[i][2]``. The image's gt boxes are stacked in front of them, and
    every candidate is labelled (``label_bbox``) and sampled
    (``sample_bbox``). Regression targets are then built with
    ``compute_bbox_targets`` and ``expand_bbox_targets``.

    Args:
        rpn_rois: concatenated RoIs of the whole batch.
        rpn_rois_num: number of RoIs per image.
        gt_classes, is_crowd, gt_boxes: per-image gt annotations.
        im_info: per-image rows whose third entry is the image scale.
        batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
            bg_thresh_lo: sampling parameters forwarded to ``sample_bbox``.
        bbox_reg_weights: weights used to encode regression targets.
        class_nums: number of classes (width of the per-class IoU table and
            of the expanded regression targets).
        use_random: sample randomly instead of taking the first candidates.
        is_cls_agnostic: forwarded to ``sample_bbox``/``expand_bbox_targets``.
        is_cascade_rcnn: in cascade mode, the first ``len(gt)`` RoIs are
            stripped (assumed to be the gt boxes prepended by a previous
            stage; TODO confirm), degenerate RoIs are dropped, and
            ``sample_bbox`` keeps all candidates.

    Returns:
        rois: float32 sampled boxes, multiplied back by the image scale.
        tgt_labels: int32 (N, 1) class labels; sampled background rows are 0.
        tgt_deltas: float32 expanded regression targets.
        rois_inside_weights: float32 weights from ``expand_bbox_targets``.
        rois_outside_weights: float32, 1 wherever the inside weight is > 0.
        new_rois_num: int32 number of sampled RoIs per image.
    """
    rois = []
    tgt_labels = []
    tgt_deltas = []
    rois_inside_weights = []
    rois_outside_weights = []
    new_rois_num = []
    st_num = 0
    end_num = 0
    for im_i in range(len(rpn_rois_num)):
        length = rpn_rois_num[im_i]
        end_num += length

        rpn_roi = rpn_rois[st_num:end_num]
        im_scale = im_info[im_i][2]
        rpn_roi = rpn_roi / im_scale
        gt_bbox = gt_boxes[im_i]

        if is_cascade_rcnn:
            rpn_roi = rpn_roi[gt_bbox.shape[0]:, :]
            # Drop degenerate RoIs BEFORE labelling. Filtering after
            # label_bbox leaves labels, IoUs and gt indices misaligned with
            # the remaining boxes. Only RPN RoIs are filtered, so the gt boxes
            # still occupy the first rows, which label_bbox's crowd handling
            # relies on.
            ws = rpn_roi[:, 2] - rpn_roi[:, 0] + 1
            hs = rpn_roi[:, 3] - rpn_roi[:, 1] + 1
            keep = np.where((ws > 0) & (hs > 0))[0]
            rpn_roi = rpn_roi[keep]
        bbox = np.vstack([gt_bbox, rpn_roi])

        # Step1: label bbox (forward class_nums instead of relying on the
        # default of 81, which breaks for datasets with more classes)
        roi_gt_bbox_inds, roi_gt_bbox_iou, labels = label_bbox(
            bbox,
            gt_bbox,
            gt_classes[im_i],
            is_crowd[im_i],
            class_nums=class_nums,
            is_cascade_rcnn=is_cascade_rcnn)

        # Step2: sample bbox
        fg_inds, bg_inds, fg_nums = sample_bbox(
            roi_gt_bbox_iou, batch_size_per_im, fg_fraction, fg_thresh,
            bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums,
            use_random, is_cls_agnostic, is_cascade_rcnn)

        # Step3: make output
        sampled_inds = np.append(fg_inds, bg_inds)

        # Fancy indexing copies, so zeroing background labels leaves
        # ``labels`` untouched.
        sampled_labels = labels[sampled_inds]
        sampled_labels[fg_nums:] = 0

        sampled_boxes = bbox[sampled_inds]
        sampled_gt_boxes = gt_bbox[roi_gt_bbox_inds[sampled_inds]]
        # Background rows get an arbitrary (first) gt box as regression target.
        sampled_gt_boxes[fg_nums:, :] = gt_bbox[0]
        sampled_deltas = compute_bbox_targets(sampled_boxes, sampled_gt_boxes,
                                              sampled_labels, bbox_reg_weights)
        sampled_deltas, bbox_inside_weights = expand_bbox_targets(
            sampled_deltas, class_nums, is_cls_agnostic)
        bbox_outside_weights = np.array(
            bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)

        roi = sampled_boxes * im_scale
        st_num += length

        rois.append(roi)
        new_rois_num.append(roi.shape[0])
        tgt_labels.append(sampled_labels)
        tgt_deltas.append(sampled_deltas)
        rois_inside_weights.append(bbox_inside_weights)
        rois_outside_weights.append(bbox_outside_weights)

    rois = np.concatenate(rois, axis=0).astype(np.float32)
    tgt_labels = np.concatenate(
        tgt_labels, axis=0).astype(np.int32).reshape(-1, 1)
    tgt_deltas = np.concatenate(tgt_deltas, axis=0).astype(np.float32)
    rois_inside_weights = np.concatenate(
        rois_inside_weights, axis=0).astype(np.float32)
    rois_outside_weights = np.concatenate(
        rois_outside_weights, axis=0).astype(np.float32)
    new_rois_num = np.asarray(new_rois_num, np.int32)
    return rois, tgt_labels, tgt_deltas, rois_inside_weights, rois_outside_weights, new_rois_num


@jit
def label_bbox(boxes,
               gt_boxes,
               gt_classes,
               is_crowd,
               class_nums=81,
               is_cascade_rcnn=False):
    """Match candidate boxes to gt boxes and build a per-class IoU table.

    Args:
        boxes: candidate boxes. The crowd handling below assumes ``boxes``
            starts with ``gt_boxes`` in the same order, as
            ``generate_proposal_target`` builds it.
        gt_boxes: ground-truth boxes for one image (may be empty).
        gt_classes: class id of each gt box.
        is_crowd: crowd flag of each gt box.
        class_nums: number of columns of the per-class IoU table.
        is_cascade_rcnn: accepted for interface compatibility; unused here.

    Returns:
        roi_gt_bbox_inds: int32, index of each box's best-overlapping gt box
            (0 when the box overlaps no gt).
        roi_gt_bbox_iou: (num_boxes, class_nums) table holding each box's best
            IoU in the column of the matched gt's class. Rows at crowd
            indexes are set to -1.
        labels: per-box argmax over the class columns (0 when all are 0).
    """
    num_boxes = boxes.shape[0]
    # every roi's gt box's index
    roi_gt_bbox_inds = np.zeros((num_boxes), dtype=np.int32)
    roi_gt_bbox_iou = np.zeros((num_boxes, class_nums))

    # Without gt boxes, argmax over an empty axis would raise. In that case
    # every box keeps IoU 0 and gt index 0, and so ends up with label 0.
    if len(gt_boxes) > 0:
        iou = bbox_overlaps(boxes, gt_boxes)
        iou_argmax = iou.argmax(axis=1)
        iou_max = iou.max(axis=1)
        overlapped_boxes_ind = np.where(iou_max > 0)[0].astype('int32')
        roi_gt_bbox_inds[overlapped_boxes_ind] = iou_argmax[
            overlapped_boxes_ind]
        overlapped_boxes_gt_classes = gt_classes[iou_argmax[
            overlapped_boxes_ind]].astype('int32')
        roi_gt_bbox_iou[overlapped_boxes_ind,
                        overlapped_boxes_gt_classes] = iou_max[
                            overlapped_boxes_ind]

    # is_crowd is indexed per gt box, but these indices address rows of
    # ``boxes``. This is correct only because the gt boxes come first there.
    crowd_ind = np.where(is_crowd)[0]
    roi_gt_bbox_iou[crowd_ind] = -1

    labels = roi_gt_bbox_iou.argmax(axis=1)

    return roi_gt_bbox_inds, roi_gt_bbox_iou, labels


@jit
def sample_bbox(roi_gt_bbox_iou,
                batch_size_per_im,
                fg_fraction,
                fg_thresh,
                bg_thresh_hi,
                bg_thresh_lo,
                bbox_reg_weights,
                class_nums,
                use_random=True,
                is_cls_agnostic=False,
                is_cascade_rcnn=False):
    """Choose foreground and background RoI indices.

    A RoI is a foreground candidate when its best per-class IoU is at least
    ``fg_thresh``. It is a background candidate when that IoU lies in
    ``[bg_thresh_lo, bg_thresh_hi)``. In cascade mode every candidate is
    kept. Otherwise at most ``round(fg_fraction * batch_size_per_im)``
    foregrounds are kept, and backgrounds fill the rest of
    ``batch_size_per_im``.

    ``bbox_reg_weights``, ``class_nums`` and ``is_cls_agnostic`` are not used
    here; they are kept for interface compatibility.

    Returns:
        fg_inds: selected foreground indices.
        bg_inds: selected background indices.
        fg_nums: number of selected foregrounds.
    """
    best_iou = roi_gt_bbox_iou.max(axis=1)
    quota_total = int(batch_size_per_im)
    quota_fg = int(np.round(fg_fraction * quota_total))

    fg_pool = np.where(best_iou >= fg_thresh)[0]
    bg_pool = np.where((best_iou < bg_thresh_hi) & (best_iou >= bg_thresh_lo))[0]

    if is_cascade_rcnn:
        # Cascade stages keep every qualifying RoI without subsampling.
        return fg_pool, bg_pool, fg_pool.shape[0]

    # Foreground: at most quota_fg, randomly chosen if there are more.
    fg_nums = np.minimum(quota_fg, fg_pool.shape[0])
    if use_random and fg_pool.shape[0] > fg_nums:
        fg_pool = np.random.choice(fg_pool, size=fg_nums, replace=False)
    fg_inds = fg_pool[:fg_nums]

    # Background: fill the remaining slots of the per-image budget.
    bg_nums = np.minimum(quota_total - fg_nums, bg_pool.shape[0])
    if use_random and bg_pool.shape[0] > bg_nums:
        bg_pool = np.random.choice(bg_pool, size=bg_nums, replace=False)
    bg_inds = bg_pool[:bg_nums]

    return fg_inds, bg_inds, fg_nums


def _strip_poly_padding(gt_polys):
    """Remove padding points from one image's gt polygons.

    A point is kept when at least one of its coordinates is >= 0, which also
    drops the (-1, -1) padding marker. Polygons left without points are
    dropped. Every instance still contributes a (possibly empty) list.
    """
    cleaned = []
    for instance in gt_polys:
        parts = []
        for poly in instance:
            points = [[x, y] for x, y in poly if x >= 0 or y >= 0]
            if points:
                parts.append(points)
        cleaned.append(parts)
    return cleaned


@jit
def generate_mask_target(im_info, gt_classes, is_crowd, gt_segms, rois,
                         rois_num, labels_int32, num_classes, resolution):
    """Build mask-head targets for a batch.

    ``rois`` and ``labels_int32`` hold the whole batch back to back; image
    ``k`` owns the next ``rois_num[k]`` rows. For each image, padding is
    stripped from ``gt_segms[k]``, the RoIs are divided by ``im_info[k][2]``,
    and ``sample_mask`` picks RoIs and rasterises their masks. The chosen
    RoIs are then multiplied back by the image scale.

    Returns:
        mask_rois: float32 selected RoIs of the whole batch.
        mask_rois_num: int32 number of selected RoIs per image.
        rois_has_mask_int32: int32 concatenation of ``sample_mask``'s
            ``bbox_has_mask`` outputs.
        mask_int32: int32 concatenated mask targets.
    """
    mask_rois = []
    mask_rois_num = []
    rois_has_mask_int32 = []
    mask_int32 = []

    start = 0
    for k, length in enumerate(rois_num):
        stop = start + length
        im_scale = im_info[k][2]
        # Divide by the image scale for matching; scale back for the output.
        bbox_fg, bbox_has_mask, masks = sample_mask(
            rois[start:stop] / im_scale, _strip_poly_padding(gt_segms[k]),
            labels_int32[start:stop], gt_classes[k], is_crowd[k],
            num_classes, resolution)
        start = stop

        mask_rois.append(bbox_fg * im_scale)
        mask_rois_num.append(len(bbox_fg))
        rois_has_mask_int32.append(bbox_has_mask)
        mask_int32.append(masks)

    return (np.concatenate(mask_rois, axis=0).astype(np.float32),
            np.array(mask_rois_num).astype(np.int32),
            np.concatenate(rois_has_mask_int32, axis=0).astype(np.int32),
            np.concatenate(mask_int32, axis=0).astype(np.int32))


@jit
def sample_mask(boxes, gt_polys, label_int32, gt_classes, is_crowd, num_classes,
                resolution):
    """Select foreground RoIs of one image and rasterise their mask targets.

    Only gt instances with class > 0 and ``is_crowd == 0`` are used. Each
    foreground RoI (label > 0) is matched to the gt polygon whose box overlaps
    it most, and that polygon is rasterised inside the RoI at
    ``resolution x resolution``.

    If there is no foreground RoI, the first background RoI (label 0) is
    returned with an all -1 mask and label 0.
    NOTE(review): in that case ``bbox_has_mask`` gets index 0 appended, not
    the background RoI's index; confirm downstream expects this.

    Returns:
        bbox_fg: selected RoIs.
        bbox_has_mask: indices of the selected RoIs, as described above.
        masks: output of ``expand_mask_targets`` for the selected RoIs.
    """
    valid_gt = np.where((gt_classes > 0) & (is_crowd == 0))[0]
    polys_kept = [gt_polys[j] for j in valid_gt]
    poly_boxes = polys_to_boxes(polys_kept)

    fg_inds = np.where(label_int32 > 0)[0]
    bbox_has_mask = fg_inds.copy()

    if fg_inds.shape[0] == 0:
        # No foreground: emit one placeholder background RoI.
        first_bg = np.where(label_int32 == 0)[0][0]
        bbox_fg = boxes[first_bg].reshape((1, -1))
        masks_fg = -np.ones((1, resolution**2), dtype=np.int32)
        labels_fg = np.zeros((1, ))
        bbox_has_mask = np.append(bbox_has_mask, 0)
    else:
        labels_fg = label_int32[fg_inds]
        bbox_fg = boxes[fg_inds]
        masks_fg = np.zeros((fg_inds.shape[0], resolution**2), dtype=np.int32)
        matched = np.argmax(bbox_overlaps_mask(bbox_fg, poly_boxes), axis=1)
        for row, (poly_idx, roi) in enumerate(zip(matched, bbox_fg)):
            raster = polys_to_mask_wrt_box(polys_kept[poly_idx], roi,
                                           resolution)
            binary = np.array(raster > 0, dtype=np.int32)
            masks_fg[row, :] = np.reshape(binary, resolution**2)

    masks = expand_mask_targets(masks_fg, labels_fg, resolution, num_classes)
    return bbox_fg, bbox_has_mask, masks