# Target generation utilities for RPN anchors, RoI proposals and mask heads.
import six
import math
import numpy as np
from numba import jit
from .bbox import *
from .mask import *


@jit
def generate_rpn_anchor_target(anchors,
                               gt_boxes,
                               is_crowd,
                               im_info,
                               rpn_straddle_thresh,
                               rpn_batch_size_per_im,
                               rpn_positive_overlap,
                               rpn_negative_overlap,
                               rpn_fg_fraction,
                               use_random=True,
                               anchor_reg_weights=None):
    """Build RPN classification and regression targets for a batch of images.

    For every image the anchors are (optionally) filtered against the image
    border, matched to the non-crowd ground-truth boxes with `label_anchor`,
    sampled with `sample_anchor`, and turned into regression deltas with
    `bbox2delta`.

    Args:
        anchors: 2-D array of anchors shared by all images. Columns 0/2 are
            compared against the image width and columns 1/3 against the
            image height (presumably x1, y1, x2, y2 -- confirm).
        gt_boxes: ground-truth boxes, indexed per image along axis 0.
        is_crowd: per-image crowd flags; only entries equal to 0 are kept.
        im_info: per-image triple read as (height, width, scale).
        rpn_straddle_thresh: anchors may exceed the image border by at most
            this amount; a negative value disables the border filter.
        rpn_batch_size_per_im: forwarded to `sample_anchor`.
        rpn_positive_overlap: forwarded to `sample_anchor`.
        rpn_negative_overlap: forwarded to `sample_anchor`.
        rpn_fg_fraction: forwarded to `sample_anchor`.
        use_random: forwarded to `sample_anchor`.
        anchor_reg_weights: weights forwarded to `bbox2delta`. Defaults to
            [1., 1., 1., 1.] when omitted.

    Returns:
        Tuple `(loc_indexes, cls_indexes, tgt_labels, tgt_deltas,
        anchor_inside_weights)`. The two index arrays address the flattened
        (image, anchor) grid: each per-image index is offset by
        `i * anchors.shape[0]`.
    """
    # A None sentinel avoids sharing one mutable list between all calls.
    if anchor_reg_weights is None:
        anchor_reg_weights = [1., 1., 1., 1.]

    anchor_num = anchors.shape[0]
    batch_size = gt_boxes.shape[0]

    loc_indexes = []
    cls_indexes = []
    tgt_labels = []
    tgt_deltas = []
    anchor_inside_weights = []

    for i in range(batch_size):

        # TODO: move anchor filter into anchor generator
        im_height = im_info[i][0]
        im_width = im_info[i][1]
        im_scale = im_info[i][2]
        if rpn_straddle_thresh >= 0:
            # Keep only anchors that stay within the (padded) image border.
            anchor_inds = np.where((anchors[:, 0] >= -rpn_straddle_thresh) & (
                anchors[:, 1] >= -rpn_straddle_thresh) & (
                    anchors[:, 2] < im_width + rpn_straddle_thresh) & (
                        anchors[:, 3] < im_height + rpn_straddle_thresh))[0]
            anchor = anchors[anchor_inds, :]
        else:
            anchor_inds = np.arange(anchors.shape[0])
            anchor = anchors

        # Scale the ground truth by im_scale and drop crowd boxes.
        gt_bbox = gt_boxes[i] * im_scale
        is_crowd_slice = is_crowd[i]
        not_crowd_inds = np.where(is_crowd_slice == 0)[0]
        gt_bbox = gt_bbox[not_crowd_inds]

        # Step1: match anchor and gt_bbox
        anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = label_anchor(anchor,
                                                                       gt_bbox)

        # Step2: sample anchor
        fg_inds, bg_inds, fg_fake_inds, fake_num = sample_anchor(
            anchor_gt_bbox_iou, labels, rpn_positive_overlap,
            rpn_negative_overlap, rpn_batch_size_per_im, rpn_fg_fraction,
            use_random)

        # Step3: make output
        # The fake foreground entries come first in loc_inds, so the first
        # `fake_num` rows of the inside weights stay zero below.
        loc_inds = np.hstack([fg_fake_inds, fg_inds])
        cls_inds = np.hstack([fg_inds, bg_inds])

        sampled_labels = labels[cls_inds]

        sampled_anchors = anchor[loc_inds]
        sampled_gt_boxes = gt_bbox[anchor_gt_bbox_inds[loc_inds]]
        sampled_deltas = bbox2delta(sampled_anchors, sampled_gt_boxes,
                                    anchor_reg_weights)

        anchor_inside_weight = np.zeros((len(loc_inds), 4), dtype=np.float32)
        anchor_inside_weight[fake_num:, :] = 1

        # Map indices of the filtered anchors back to the full anchor set and
        # shift them into this image's slot of the flattened batch.
        loc_indexes.append(anchor_inds[loc_inds] + i * anchor_num)
        cls_indexes.append(anchor_inds[cls_inds] + i * anchor_num)
        tgt_labels.append(sampled_labels)
        tgt_deltas.append(sampled_deltas)
        anchor_inside_weights.append(anchor_inside_weight)

    loc_indexes = np.concatenate(loc_indexes)
    cls_indexes = np.concatenate(cls_indexes)
    tgt_labels = np.concatenate(tgt_labels).astype('float32')
    tgt_deltas = np.vstack(tgt_deltas).astype('float32')
    anchor_inside_weights = np.vstack(anchor_inside_weights)

    return loc_indexes, cls_indexes, tgt_labels, tgt_deltas, anchor_inside_weights


@jit
def label_anchor(anchors, gt_boxes):
    """Match anchors to ground-truth boxes by IoU.

    Args:
        anchors: anchor boxes passed to `bbox_overlaps` as the first operand.
        gt_boxes: ground-truth boxes passed as the second operand.

    Returns:
        Tuple `(anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels)` where, for
        each anchor, the first two entries give the index of and the IoU with
        its best-overlapping ground-truth box. `labels` is an int32 array
        initialised to -1 in which every anchor that attains the best IoU of
        some ground-truth box (ties included) is set to 1.
    """
    # iou is indexed as [anchor, gt] below.
    iou = bbox_overlaps(anchors, gt_boxes)

    # Best IoU reached for every gt box, and all anchors that reach it.
    gt_bbox_anchor_iou = iou.max(axis=0)
    gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0]

    # Best-matching gt box (index and IoU) for every anchor.
    anchor_gt_bbox_inds = iou.argmax(axis=1)
    anchor_gt_bbox_iou = iou.max(axis=1)

    labels = np.ones((iou.shape[0], ), dtype=np.int32) * -1
    labels[gt_bbox_anchor_iou_inds] = 1

    return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels

@jit
def sample_anchor(anchor_gt_bbox_iou,
                  labels,
                  rpn_positive_overlap,
                  rpn_negative_overlap,
                  rpn_batch_size_per_im,
                  rpn_fg_fraction,
                  use_random=True):
    """Sample foreground and background anchors for one image.

    `labels` is modified in place: 1 marks foreground, 0 background and -1
    ignored anchors.

    Args:
        anchor_gt_bbox_iou: per-anchor IoU with its best ground-truth box.
        labels: per-anchor label array, updated in place.
        rpn_positive_overlap: anchors with IoU >= this value become
            foreground.
        rpn_negative_overlap: anchors with IoU < this value are background
            candidates.
        rpn_batch_size_per_im: total number of anchors to sample.
        rpn_fg_fraction: fraction of the batch reserved for foreground.
        use_random: sample randomly when there are more candidates than
            slots; otherwise keep the leading candidates.

    Returns:
        Tuple `(fg_inds, bg_inds, fg_fake_inds, fake_num)`. `fake_num` counts
        the sampled background entries that overwrote a foreground anchor;
        `fg_fake_inds` holds that many copies of the first foreground index
        (int32), used by the caller as zero-weighted placeholders.
    """
    labels[anchor_gt_bbox_iou >= rpn_positive_overlap] = 1

    # Cap the number of foreground anchors.
    num_fg = int(rpn_fg_fraction * rpn_batch_size_per_im)
    fg_inds = np.where(labels == 1)[0]
    if len(fg_inds) > num_fg and use_random:
        disable_inds = np.random.choice(
            fg_inds, size=(len(fg_inds) - num_fg), replace=False)
    else:
        disable_inds = fg_inds[num_fg:]
    labels[disable_inds] = -1
    fg_inds = np.where(labels == 1)[0]

    # Fill the rest of the batch with background anchors. Note that the
    # random path draws with replacement, so enable_inds may hold duplicates.
    num_bg = rpn_batch_size_per_im - np.sum(labels == 1)
    bg_inds = np.where(anchor_gt_bbox_iou < rpn_negative_overlap)[0]
    if len(bg_inds) > num_bg and use_random:
        enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
    else:
        enable_inds = bg_inds[:num_bg]

    # Count sampled background entries (duplicates included) that collide
    # with a foreground anchor. fg_inds[0] is only read when a collision
    # exists, so an empty foreground set no longer raises IndexError.
    fake_num = int(np.isin(enable_inds, fg_inds).sum())
    if fake_num > 0:
        fg_fake_inds = np.full((fake_num, ), fg_inds[0], dtype=np.int32)
    else:
        fg_fake_inds = np.array([], np.int32)
    labels[enable_inds] = 0

    fg_inds = np.where(labels == 1)[0]
    bg_inds = np.where(labels == 0)[0]

    return fg_inds, bg_inds, fg_fake_inds, fake_num


@jit
def filter_roi(rois, max_overlap):
    """Drop degenerate RoIs and RoIs whose overlap is not below 1.

    An RoI is kept when both its width and height (computed with the +1
    convention) are positive and its `max_overlap` entry is < 1. When nothing
    survives, a single all-zero float32 box of shape (1, 4) is returned.
    """
    widths = rois[:, 2] - rois[:, 0] + 1
    heights = rois[:, 3] - rois[:, 1] + 1
    valid = (widths > 0) & (heights > 0) & (max_overlap < 1)
    keep = np.where(valid)[0]
    if len(keep) == 0:
        # Placeholder box so downstream code never sees an empty RoI set.
        return np.zeros((1, 4)).astype('float32')
    return rois[keep, :]


@jit
def generate_proposal_target(rpn_rois,
                             rpn_rois_num,
                             gt_classes,
                             is_crowd,
                             gt_boxes,
                             im_info,
                             batch_size_per_im,
                             fg_fraction,
                             fg_thresh,
                             bg_thresh_hi,
                             bg_thresh_lo,
                             bbox_reg_weights,
                             class_nums=81,
                             use_random=True,
                             is_cls_agnostic=False,
                             is_cascade_rcnn=False,
                             max_overlaps=None):
    """Sample RoIs and build classification / box-regression targets.

    `rpn_rois` (and `max_overlaps`, when used) hold the RoIs of all images
    back to back; `rpn_rois_num[i]` is the number of rows that belong to
    image `i`.

    Args:
        rpn_rois: concatenated proposals of the whole batch.
        rpn_rois_num: number of proposals per image.
        gt_classes: per-image ground-truth class ids.
        is_crowd: per-image crowd flags.
        gt_boxes: per-image ground-truth boxes.
        im_info: per-image info; index 2 is read as the image scale.
        batch_size_per_im: forwarded to `sample_bbox`.
        fg_fraction: forwarded to `sample_bbox`.
        fg_thresh: forwarded to `sample_bbox`.
        bg_thresh_hi: forwarded to `sample_bbox`.
        bg_thresh_lo: forwarded to `sample_bbox`.
        bbox_reg_weights: forwarded to `sample_bbox` and
            `compute_bbox_targets`.
        class_nums: forwarded to `sample_bbox` and `expand_bbox_targets`.
        use_random: forwarded to `sample_bbox`.
        is_cls_agnostic: forwarded to `sample_bbox` and
            `expand_bbox_targets`.
        is_cascade_rcnn: when true, proposals are first passed through
            `filter_roi` using `max_overlaps`.
        max_overlaps: per-RoI overlaps aligned with `rpn_rois`; required
            when `is_cascade_rcnn` is true.

    Returns:
        Tuple `(rois, tgt_labels, tgt_deltas, rois_inside_weights,
        rois_outside_weights, new_rois_num, sampled_max_overlaps)`, each
        concatenated over the batch. `tgt_labels` is int32 with shape
        (-1, 1); `new_rois_num` is the int32 count of sampled RoIs per image.

    Raises:
        ValueError: if `is_cascade_rcnn` is true but `max_overlaps` is None.
    """
    if is_cascade_rcnn and max_overlaps is None:
        raise ValueError(
            'max_overlaps must be provided when is_cascade_rcnn is True')

    rois = []
    tgt_labels = []
    tgt_deltas = []
    rois_inside_weights = []
    rois_outside_weights = []
    sampled_max_overlaps = []
    new_rois_num = []

    st_num = 0
    for im_i, length in enumerate(rpn_rois_num):
        end_num = st_num + length

        rpn_roi = rpn_rois[st_num:end_num]
        max_overlap = max_overlaps[st_num:end_num] if is_cascade_rcnn else None
        im_scale = im_info[im_i][2]
        # Divide by im_scale so proposals can be stacked with gt_boxes below.
        rpn_roi = rpn_roi / im_scale
        gt_bbox = gt_boxes[im_i]

        if is_cascade_rcnn:
            rpn_roi = filter_roi(rpn_roi, max_overlap)
        # Ground-truth boxes are prepended to the candidate set.
        bbox = np.vstack([gt_bbox, rpn_roi]).astype('float32')

        # Step1: label bbox
        # NOTE: label_bbox is called with its own default class_nums here,
        # not with this function's class_nums argument.
        roi_gt_bbox_inds, labels, max_overlap = label_bbox(
            bbox, gt_bbox, gt_classes[im_i], is_crowd[im_i])

        # Step2: sample bbox
        fg_inds, bg_inds, fg_nums = sample_bbox(
            max_overlap, batch_size_per_im, fg_fraction, fg_thresh,
            bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums,
            use_random, is_cls_agnostic, is_cascade_rcnn)

        # Step3: make output
        # Foreground samples come first; everything from fg_nums onwards is
        # background and gets label 0 and zeroed targets.
        sampled_inds = np.append(fg_inds, bg_inds)

        sampled_labels = labels[sampled_inds]
        sampled_labels[fg_nums:] = 0

        sampled_boxes = bbox[sampled_inds]
        sampled_max_overlap = max_overlap[sampled_inds]
        sampled_gt_boxes = gt_bbox[roi_gt_bbox_inds[sampled_inds]]
        sampled_gt_boxes[fg_nums:, :] = 0
        sampled_deltas = compute_bbox_targets(sampled_boxes, sampled_gt_boxes,
                                              sampled_labels, bbox_reg_weights)
        sampled_deltas[fg_nums:, :] = 0
        sampled_deltas, bbox_inside_weights = expand_bbox_targets(
            sampled_deltas, class_nums, is_cls_agnostic)
        bbox_outside_weights = np.array(
            bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)

        # Multiply by im_scale again to undo the division above.
        roi = sampled_boxes * im_scale
        st_num = end_num

        rois.append(roi)
        new_rois_num.append(roi.shape[0])
        tgt_labels.append(sampled_labels)
        tgt_deltas.append(sampled_deltas)
        rois_inside_weights.append(bbox_inside_weights)
        rois_outside_weights.append(bbox_outside_weights)
        sampled_max_overlaps.append(sampled_max_overlap)

    rois = np.concatenate(rois, axis=0).astype(np.float32)
    tgt_labels = np.concatenate(
        tgt_labels, axis=0).astype(np.int32).reshape(-1, 1)
    tgt_deltas = np.concatenate(tgt_deltas, axis=0).astype(np.float32)
    rois_inside_weights = np.concatenate(
        rois_inside_weights, axis=0).astype(np.float32)
    rois_outside_weights = np.concatenate(
        rois_outside_weights, axis=0).astype(np.float32)
    sampled_max_overlaps = np.concatenate(
        sampled_max_overlaps, axis=0).astype(np.float32)
    new_rois_num = np.asarray(new_rois_num, np.int32)
    return rois, tgt_labels, tgt_deltas, rois_inside_weights, rois_outside_weights, new_rois_num, sampled_max_overlaps


@jit
def label_bbox(boxes, gt_boxes, gt_classes, is_crowd, class_nums=81):
    """Assign a ground-truth box, a class label and a max overlap to each box.

    Args:
        boxes: candidate boxes passed to `bbox_overlaps` as first operand.
        gt_boxes: ground-truth boxes passed as second operand.
        gt_classes: class id of every ground-truth box; used as a column
            index into a `(len(boxes), class_nums)` table, so every id must
            be smaller than `class_nums`.
        is_crowd: crowd flags; rows at the positions where the flag is
            non-zero get an overlap of -1.
        class_nums: number of columns of the per-class overlap table.

    Returns:
        Tuple `(roi_gt_bbox_inds, labels, max_overlap)`: for each box the
        index of its best-overlapping ground-truth box (0 when it overlaps
        none), the class whose overlap is largest, and that overlap.
    """
    iou = bbox_overlaps(boxes, gt_boxes)

    # every roi's gt box's index
    roi_gt_bbox_inds = np.zeros((boxes.shape[0]), dtype=np.int32)
    # Per-class overlap table; only the column of the matched gt class is
    # filled for each box.
    roi_gt_bbox_iou = np.zeros((boxes.shape[0], class_nums), dtype=np.float32)

    iou_argmax = iou.argmax(axis=1)
    iou_max = iou.max(axis=1)
    overlapped_boxes_ind = np.where(iou_max > 0)[0].astype('int32')
    roi_gt_bbox_inds[overlapped_boxes_ind] = iou_argmax[overlapped_boxes_ind]
    overlapped_boxes_gt_classes = gt_classes[iou_argmax[
        overlapped_boxes_ind]].astype('int32')
    roi_gt_bbox_iou[overlapped_boxes_ind,
                    overlapped_boxes_gt_classes] = iou_max[overlapped_boxes_ind]

    # Rows selected by the crowd flags are forced to -1 in every class.
    crowd_ind = np.where(is_crowd)[0]
    roi_gt_bbox_iou[crowd_ind] = -1

    max_overlap = roi_gt_bbox_iou.max(axis=1)
    labels = roi_gt_bbox_iou.argmax(axis=1)

    return roi_gt_bbox_inds, labels, max_overlap


@jit
def sample_bbox(max_overlap,
                batch_size_per_im,
                fg_fraction,
                fg_thresh,
                bg_thresh_hi,
                bg_thresh_lo,
                bbox_reg_weights,
                class_nums,
                use_random=True,
                is_cls_agnostic=False,
                is_cascade_rcnn=False):
    """Select foreground and background RoI indices from their max overlaps.

    Foreground candidates have `max_overlap >= fg_thresh`; background
    candidates have `bg_thresh_lo <= max_overlap < bg_thresh_hi`.

    Args:
        max_overlap: per-RoI maximum overlap.
        batch_size_per_im: total number of RoIs to sample per image.
        fg_fraction: fraction of the batch reserved for foreground.
        fg_thresh: lower overlap bound for foreground.
        bg_thresh_hi: exclusive upper overlap bound for background.
        bg_thresh_lo: inclusive lower overlap bound for background.
        bbox_reg_weights: accepted but not used in this function.
        class_nums: accepted but not used in this function.
        use_random: subsample randomly (without replacement) when there are
            more candidates than slots; otherwise keep the leading ones.
        is_cls_agnostic: accepted but not used in this function.
        is_cascade_rcnn: when true, all candidates are returned without
            subsampling.

    Returns:
        Tuple `(fg_inds, bg_inds, fg_nums)` with `fg_nums == len(fg_inds)`.
    """
    rois_per_image = int(batch_size_per_im)
    fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))

    fg_inds = np.where(max_overlap >= fg_thresh)[0]
    bg_inds = np.where((max_overlap < bg_thresh_hi) & (max_overlap >=
                                                       bg_thresh_lo))[0]

    if is_cascade_rcnn:
        # Cascade stages keep every candidate.
        return fg_inds, bg_inds, fg_inds.shape[0]

    # sample fg
    fg_nums = np.minimum(fg_rois_per_im, fg_inds.shape[0])
    if (fg_inds.shape[0] > fg_nums) and use_random:
        fg_inds = np.random.choice(fg_inds, size=fg_nums, replace=False)
    fg_inds = fg_inds[:fg_nums]

    # sample bg: fill whatever the foreground left of the batch
    bg_nums = rois_per_image - fg_nums
    bg_nums = np.minimum(bg_nums, bg_inds.shape[0])
    if (bg_inds.shape[0] > bg_nums) and use_random:
        bg_inds = np.random.choice(bg_inds, size=bg_nums, replace=False)
    bg_inds = bg_inds[:bg_nums]

    return fg_inds, bg_inds, fg_nums


@jit
def generate_mask_target(im_info, gt_classes, is_crowd, gt_segms, rois,
                         rois_num, labels_int32, num_classes, resolution):
    """Build mask targets for the sampled RoIs of a batch.

    `rois` and `labels_int32` hold the RoIs of all images back to back;
    `rois_num[k]` is the number of rows that belong to image `k`.

    Args:
        im_info: per-image info; index 2 is read as the image scale.
        gt_classes: per-image ground-truth class ids.
        is_crowd: per-image crowd flags.
        gt_segms: per-image polygons, nested as instance -> polygon ->
            (x, y) point. Points equal to (-1, -1) are treated as padding.
        rois: concatenated RoIs of the whole batch.
        rois_num: number of RoIs per image.
        labels_int32: labels aligned with `rois`.
        num_classes: forwarded to `sample_mask`.
        resolution: forwarded to `sample_mask`.

    Returns:
        Tuple `(mask_rois, mask_rois_num, rois_has_mask_int32, mask_int32)`
        concatenated over the batch; `mask_rois` is float32, the others
        int32.
    """
    mask_rois = []
    mask_rois_num = []
    rois_has_mask_int32 = []
    mask_int32 = []

    st_num = 0
    for k, length in enumerate(rois_num):
        end_num = st_num + length

        # remove padding
        new_gt_polys = []
        for instance_polys in gt_segms[k]:
            gt_segs = []
            for polys in instance_polys:
                new_poly = []
                for x, y in polys:
                    if (x == -1 and y == -1):
                        continue
                    elif (x >= 0 or y >= 0):
                        new_poly.append([x, y])  # array, one poly
                if len(new_poly) > 0:
                    gt_segs.append(new_poly)
            # Instances are kept even when all of their polygons were
            # padding, so positions stay aligned with gt_classes/is_crowd.
            new_gt_polys.append(gt_segs)

        im_scale = im_info[k][2]
        boxes = rois[st_num:end_num] / im_scale

        bbox_fg, bbox_has_mask, masks = sample_mask(
            boxes, new_gt_polys, labels_int32[st_num:end_num], gt_classes[k],
            is_crowd[k], num_classes, resolution)

        st_num = end_num

        # Multiply by im_scale again to undo the division above.
        mask_rois.append(bbox_fg * im_scale)
        mask_rois_num.append(len(bbox_fg))
        rois_has_mask_int32.append(bbox_has_mask)
        mask_int32.append(masks)

    mask_rois = np.concatenate(mask_rois, axis=0).astype(np.float32)
    mask_rois_num = np.array(mask_rois_num).astype(np.int32)
    rois_has_mask_int32 = np.concatenate(
        rois_has_mask_int32, axis=0).astype(np.int32)
    mask_int32 = np.concatenate(mask_int32, axis=0).astype(np.int32)

    return mask_rois, mask_rois_num, rois_has_mask_int32, mask_int32


@jit
def sample_mask(boxes, gt_polys, label_int32, gt_classes, is_crowd, num_classes,
                resolution):
    """Rasterise ground-truth polygons into mask targets for foreground RoIs.

    Args:
        boxes: RoIs of one image.
        gt_polys: polygons of every ground-truth instance of the image.
        label_int32: label of every RoI; entries > 0 are foreground.
        gt_classes: class id of every ground-truth instance.
        is_crowd: crowd flag of every ground-truth instance.
        num_classes: forwarded to `expand_mask_targets`.
        resolution: side length of the square mask target.

    Returns:
        Tuple `(bbox_fg, bbox_has_mask, masks)`. With at least one foreground
        RoI these are the foreground boxes, their indices into `boxes`, and
        the expanded mask targets. Without any foreground RoI a single
        background box is returned with an all -1 mask and
        `bbox_has_mask == [0]`.
    """
    mask_size = resolution**2

    # Only non-crowd instances with a positive class provide masks.
    gt_polys_inds = np.where((gt_classes > 0) & (is_crowd == 0))[0]
    _gt_polys = [gt_polys[i] for i in gt_polys_inds]
    boxes_from_polys = polys_to_boxes(_gt_polys)

    fg_inds = np.where(label_int32 > 0)[0]
    bbox_has_mask = fg_inds.copy()

    if fg_inds.shape[0] > 0:
        labels_fg = label_int32[fg_inds]
        masks_fg = np.zeros((fg_inds.shape[0], mask_size), dtype=np.int32)
        bbox_fg = boxes[fg_inds]

        # Pair every foreground RoI with the instance whose polygon box
        # overlaps it most.
        iou = bbox_overlaps_mask(bbox_fg, boxes_from_polys)
        fg_polys_inds = np.argmax(iou, axis=1)

        for i in range(bbox_fg.shape[0]):
            poly_gt = _gt_polys[fg_polys_inds[i]]
            roi_fg = bbox_fg[i]

            mask = polys_to_mask_wrt_box(poly_gt, roi_fg, resolution)
            mask = np.array(mask > 0, dtype=np.int32)
            masks_fg[i, :] = np.reshape(mask, mask_size)
    else:
        # No foreground: emit one background RoI with an all -1 mask so the
        # output is never empty.
        bg_inds = np.where(label_int32 == 0)[0]
        bbox_fg = boxes[bg_inds[0]].reshape((1, -1))
        masks_fg = -np.ones((1, mask_size), dtype=np.int32)
        labels_fg = np.zeros((1, ))
        bbox_has_mask = np.append(bbox_has_mask, 0)
    masks = expand_mask_targets(masks_fg, labels_fg, resolution, num_classes)
    return bbox_fg, bbox_has_mask, masks