target.py 26.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from ..bbox_utils import bbox2delta, bbox_overlaps


def rpn_anchor_target(anchors,
                      gt_boxes,
                      rpn_batch_size_per_im,
                      rpn_positive_overlap,
                      rpn_negative_overlap,
                      rpn_fg_fraction,
                      use_random=True,
                      batch_size=1,
                      ignore_thresh=-1,
                      is_crowd=None,
                      weights=None,
                      assign_on_cpu=False):
    """
    Build the RPN classification and regression targets for each image.

    Args:
        anchors: anchor boxes; the same anchors are matched against every
            image's ground truth.
        gt_boxes: per-image ground-truth boxes, ``gt_boxes[i]`` is used for
            image ``i``.
        rpn_batch_size_per_im (int): number of anchors to sample per image.
        rpn_positive_overlap (float): IoU at or above which an anchor is
            labeled foreground.
        rpn_negative_overlap (float): IoU below which an anchor is labeled
            background.
        rpn_fg_fraction (float): upper bound on the foreground share of the
            sampled anchors.
        use_random (bool): sample randomly instead of taking the leading
            candidates.
        batch_size (int): number of images to process.
        ignore_thresh (float): forwarded to ``label_box``.
        is_crowd: optional per-image crowd flags; a falsy value disables
            crowd handling.
        weights (list[float] | None): weights forwarded to ``bbox2delta``.
            ``None`` means ``[1., 1., 1., 1.]``.
        assign_on_cpu (bool): forwarded to ``label_box``.

    Returns:
        tuple(list, list, list): per-image labels (1 = foreground,
        0 = background, -1 = ignored), matched ground-truth boxes and
        regression deltas.
    """
    # A fresh list per call avoids sharing one mutable default between calls.
    if weights is None:
        weights = [1., 1., 1., 1.]

    tgt_labels = []
    tgt_bboxes = []
    tgt_deltas = []
    for i in range(batch_size):
        gt_bbox = gt_boxes[i]
        is_crowd_i = is_crowd[i] if is_crowd else None
        # Step1: match anchor and gt_bbox
        matches, match_labels = label_box(
            anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
            ignore_thresh, is_crowd_i, assign_on_cpu)
        # Step2: sample anchor
        fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im,
                                            rpn_fg_fraction, 0, use_random)
        # Fill with the ignore label (-1), then set positive and negative labels
        labels = paddle.full(match_labels.shape, -1, dtype='int32')
        if bg_inds.shape[0] > 0:
            labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
        if fg_inds.shape[0] > 0:
            labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
        # Step3: make output
        if gt_bbox.shape[0] == 0:
            # No ground truth in this image: emit empty box/delta targets.
            matched_gt_boxes = paddle.zeros([0, 4])
            tgt_delta = paddle.zeros([0, 4])
        else:
            matched_gt_boxes = paddle.gather(gt_bbox, matches)
            tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
            matched_gt_boxes.stop_gradient = True
            tgt_delta.stop_gradient = True
        labels.stop_gradient = True
        tgt_labels.append(labels)
        tgt_bboxes.append(matched_gt_boxes)
        tgt_deltas.append(tgt_delta)

    return tgt_labels, tgt_bboxes, tgt_deltas


68 69 70 71 72 73
def label_box(anchors,
              gt_boxes,
              positive_overlap,
              negative_overlap,
              allow_low_quality,
              ignore_thresh,
              is_crowd=None,
              assign_on_cpu=False):
    """
    Match every anchor to a ground-truth box and label it.

    Args:
        anchors: candidate boxes to label.
        gt_boxes: ground-truth boxes of one image.
        positive_overlap (float): best IoU at or above this is foreground.
        negative_overlap (float): best IoU below this (and above -1) is
            background.
        allow_low_quality (bool): additionally mark as foreground the
            anchors that hold the highest positive IoU for some gt box.
        ignore_thresh (float): when > 0, anchors whose IoU with a crowd
            box exceeds it are ignored.
        is_crowd: optional crowd flags for the gt boxes.
        assign_on_cpu (bool): compute the IoU matrix on the CPU device.

    Returns:
        tuple: ``(matches, match_labels)``, both flattened. ``matches`` is
        the index of the best gt box per anchor; ``match_labels`` is 1 for
        foreground, 0 for background and -1 for ignored.
    """
    if assign_on_cpu:
        # Restore whatever device was active before, rather than assuming
        # "gpu": that would fail on CPU-only runs and drop the GPU index.
        prev_device = paddle.get_device()
        paddle.set_device("cpu")
        try:
            iou = bbox_overlaps(gt_boxes, anchors)
        finally:
            paddle.set_device(prev_device)
    else:
        iou = bbox_overlaps(gt_boxes, anchors)
    n_gt = gt_boxes.shape[0]
    if n_gt == 0 or is_crowd is None:
        n_gt_crowd = 0
    else:
        n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
    if iou.shape[0] == 0 or n_gt_crowd == n_gt:
        # No truth, assign everything to background
        default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
        default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
        return default_matches, default_match_labels
    # if ignore_thresh > 0, remove anchor if it is close to
    # one of the crowded ground-truth
    if n_gt_crowd > 0:
        N_a = anchors.shape[0]
        ones = paddle.ones([N_a])
        mask = is_crowd * ones

        if ignore_thresh > 0:
            crowd_iou = iou * mask
            valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
                                axis=0) > 0).cast('float32')
            # anchors too close to a crowd box get iou = -1 for every gt
            iou = iou * (1 - valid) - valid

        # ignore the iou between anchor and crowded ground-truth
        iou = iou * (1 - mask) - mask

    matched_vals, matches = paddle.topk(iou, k=1, axis=0)
    match_labels = paddle.full(matches.shape, -1, dtype='int32')
    # set ignored anchor with iou = -1
    neg_cond = paddle.logical_and(matched_vals > -1,
                                  matched_vals < negative_overlap)
    match_labels = paddle.where(neg_cond,
                                paddle.zeros_like(match_labels), match_labels)
    match_labels = paddle.where(matched_vals >= positive_overlap,
                                paddle.ones_like(match_labels), match_labels)
    if allow_low_quality:
        highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
        pred_inds_with_highest_quality = paddle.logical_and(
            iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
                0, keepdim=True)
        match_labels = paddle.where(pred_inds_with_highest_quality > 0,
                                    paddle.ones_like(match_labels),
                                    match_labels)

    matches = matches.flatten()
    match_labels = match_labels.flatten()

    return matches, match_labels
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144


def subsample_labels(labels,
                     num_samples,
                     fg_fraction,
                     bg_label=0,
                     use_random=True):
    """
    Pick foreground and background indices out of ``labels``.

    Foreground entries are those that are neither -1 nor ``bg_label``;
    background entries equal ``bg_label``.

    Args:
        labels: label tensor to sample from.
        num_samples (int): total number of indices wanted.
        fg_fraction (float): at most ``int(num_samples * fg_fraction)``
            foreground indices are returned.
        bg_label (int): value that marks background.
        use_random (bool): take a random subset; otherwise the leading
            candidates are taken.

    Returns:
        tuple: ``(fg_inds, bg_inds)`` as int32 index tensors.
    """
    is_foreground = paddle.logical_and(labels != -1, labels != bg_label)
    positive = paddle.nonzero(is_foreground)
    negative = paddle.nonzero(labels == bg_label)

    fg_num = min(positive.numel(), int(num_samples * fg_fraction))
    bg_num = min(negative.numel(), num_samples - fg_num)
    if fg_num == 0 and bg_num == 0:
        return (paddle.zeros([0], dtype='int32'),
                paddle.zeros([0], dtype='int32'))

    def _take(candidates, count):
        # Select ``count`` entries from the flattened int32 candidates.
        candidates = candidates.cast('int32').flatten()
        perm = paddle.randperm(candidates.numel(), dtype='int32')
        perm = paddle.slice(perm, axes=[0], starts=[0], ends=[count])
        if use_random:
            return paddle.gather(candidates, perm)
        return paddle.slice(candidates, axes=[0], starts=[0], ends=[count])

    # background is drawn first, then foreground
    bg_inds = _take(negative, bg_num)
    if fg_num == 0:
        return paddle.zeros([0], dtype='int32'), bg_inds

    fg_inds = _take(positive, fg_num)
    return fg_inds, bg_inds


def generate_proposal_target(rpn_rois,
                             gt_classes,
                             gt_boxes,
                             batch_size_per_im,
                             fg_fraction,
                             fg_thresh,
                             bg_thresh,
                             num_classes,
                             ignore_thresh=-1.,
                             is_crowd=None,
                             use_random=True,
                             is_cascade=False,
                             cascade_iou=0.5,
                             assign_on_cpu=False):
    """
    Label and sample the RoIs of every image for the box head.

    Args:
        rpn_rois: per-image proposals.
        gt_classes: per-image ground-truth classes (last axis is squeezed).
        gt_boxes: per-image ground-truth boxes.
        batch_size_per_im (int): number of RoIs to sample per image.
        fg_fraction (float): upper bound on the foreground share.
        fg_thresh (float): IoU threshold for foreground.
        bg_thresh (float): IoU threshold for background.
        num_classes (int): label value used for background.
        ignore_thresh (float): forwarded to ``label_box``.
        is_crowd: optional per-image crowd flags.
        use_random (bool): sample randomly.
        is_cascade (bool): cascade mode; every RoI is kept and
            ``cascade_iou`` replaces both thresholds.
        cascade_iou (float): threshold used in cascade mode.
        assign_on_cpu (bool): forwarded to ``label_box``.

    Returns:
        tuple: ``(rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds,
        new_rois_num)``; the first four are per-image lists, the last is
        the concatenated per-image RoI counts.
    """
    # In cascade rcnn, foreground and background share the cascade_iou
    # threshold
    if is_cascade:
        fg_thresh = cascade_iou
        bg_thresh = cascade_iou

    rois_with_gt, tgt_labels, tgt_bboxes = [], [], []
    tgt_gt_inds, new_rois_num = [], []

    for img_id, rpn_roi in enumerate(rpn_rois):
        gt_bbox = gt_boxes[img_id]
        if is_crowd:
            crowd_flags = is_crowd[img_id]
        else:
            crowd_flags = None
        gt_class = paddle.squeeze(gt_classes[img_id], axis=-1)

        # Append the gt boxes to the RoIs, except in cascade rcnn or when
        # the image has no gt
        if not is_cascade and gt_bbox.shape[0] > 0:
            bbox = paddle.concat([rpn_roi, gt_bbox])
        else:
            bbox = rpn_roi

        # Step1: label bbox
        matches, match_labels = label_box(
            bbox, gt_bbox, fg_thresh, bg_thresh, False, ignore_thresh,
            crowd_flags, assign_on_cpu)

        # Step2: sample bbox
        sampled_inds, sampled_gt_classes = sample_bbox(
            matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
            num_classes, use_random, is_cascade)

        # Step3: make output
        if is_cascade:
            rois_per_image = bbox
            sampled_gt_ind = matches
        else:
            rois_per_image = paddle.gather(bbox, sampled_inds)
            sampled_gt_ind = paddle.gather(matches, sampled_inds)

        if gt_bbox.shape[0] > 0:
            sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
        else:
            # no gt: zero boxes, one per kept RoI
            sampled_bbox = paddle.zeros(
                [rois_per_image.shape[0], 4], dtype='float32')

        for target in (rois_per_image, sampled_gt_ind, sampled_bbox):
            target.stop_gradient = True

        tgt_labels.append(sampled_gt_classes)
        tgt_bboxes.append(sampled_bbox)
        rois_with_gt.append(rois_per_image)
        tgt_gt_inds.append(sampled_gt_ind)
        new_rois_num.append(paddle.shape(sampled_inds)[0])

    new_rois_num = paddle.concat(new_rois_num)
    return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
240 241


W
wangguanzhong 已提交
242 243 244 245 246 247 248 249
def sample_bbox(matches,
                match_labels,
                gt_classes,
                batch_size_per_im,
                fg_fraction,
                num_classes,
                use_random=True,
                is_cascade=False):
    """
    Turn match results into class targets and sample RoI indices.

    Args:
        matches: index of the matched gt box for each RoI.
        match_labels: 1 / 0 / -1 label for each RoI.
        gt_classes: ground-truth classes of the image.
        batch_size_per_im (int): number of RoIs to sample.
        fg_fraction (float): upper bound on the foreground share.
        num_classes (int): class value assigned to background RoIs.
        use_random (bool): sample randomly.
        is_cascade (bool): keep every RoI instead of sampling.

    Returns:
        tuple: ``(sampled_inds, sampled_gt_classes)``. In cascade mode the
        indices cover all RoIs and the classes are not gathered.
    """
    if gt_classes.shape[0] == 0:
        # No truth, assign everything to background
        gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
    else:
        gt_classes = paddle.gather(gt_classes, matches)
        # background RoIs get the class ``num_classes``
        background_fill = paddle.ones_like(gt_classes) * num_classes
        gt_classes = paddle.where(match_labels == 0, background_fill,
                                  gt_classes)
        # ignored RoIs get the class -1
        ignore_fill = paddle.ones_like(gt_classes) * -1
        gt_classes = paddle.where(match_labels == -1, ignore_fill, gt_classes)

    if is_cascade:
        return paddle.arange(matches.shape[0]), gt_classes

    fg_inds, bg_inds = subsample_labels(gt_classes,
                                        int(batch_size_per_im), fg_fraction,
                                        num_classes, use_random)
    nothing_sampled = fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0
    if nothing_sampled:
        # fake output labeled with -1 when all boxes are neither
        # foreground nor background
        sampled_inds = paddle.zeros([1], dtype='int32')
    else:
        sampled_inds = paddle.concat([fg_inds, bg_inds])
    return sampled_inds, paddle.gather(gt_classes, sampled_inds)


def polygons_to_mask(polygons, height, width):
    """
    Convert the polygons to mask format

    Args:
        polygons (list[ndarray]): each array has shape (Nx2,)
        height (int): mask height
        width (int): mask width
    Returns:
        ndarray: a bool mask of shape (height, width)
    Raises:
        AssertionError: if ``polygons`` is empty.
    """
    # Validate first so empty input is rejected before pycocotools is needed.
    assert len(polygons) > 0, "COCOAPI does not support empty polygons"
    import pycocotools.mask as mask_util
    rles = mask_util.frPyObjects(polygons, height, width)
    rle = mask_util.merge(rles)
    # ``np.bool`` was only an alias of the builtin and was removed in
    # NumPy 1.24; the builtin ``bool`` yields the same dtype.
    return mask_util.decode(rle).astype(bool)


def rasterize_polygons_within_box(poly, box, resolution):
    """
    Rasterize polygons into a square mask relative to a box.

    Args:
        poly: iterable of flat polygons; even positions are shifted by
            ``box[0]`` and odd positions by ``box[1]``.
        box: box whose first four entries are used as ``x0, y0, x1, y1``.
        resolution (int): side length of the output mask.

    Returns:
        int32 paddle Tensor built from the ``resolution`` x ``resolution``
        mask returned by ``polygons_to_mask``.
    """
    w, h = box[2] - box[0], box[3] - box[1]
    # 1. Work on copies: ``np.asarray`` would alias float64 input, and the
    # in-place edits below would then corrupt the caller's polygons.
    polygons = [np.array(p, dtype=np.float64) for p in poly]

    # 2. Shift into the box frame and scale to the mask resolution.
    # Degenerate boxes are clamped to a side of 0.1.
    ratio_h = resolution / max(h, 0.1)
    ratio_w = resolution / max(w, 0.1)
    for p in polygons:
        p[0::2] = p[0::2] - box[0]
        p[1::2] = p[1::2] - box[1]
        p[0::2] *= ratio_w
        p[1::2] *= ratio_h

    # 3. Rasterize the polygons with coco api
    mask = polygons_to_mask(polygons, resolution, resolution)
    mask = paddle.to_tensor(mask, dtype='int32')
    return mask


def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
                         num_classes, resolution):
    """
    Build mask-head targets for the foreground RoIs of every image.

    Args:
        gt_segms: per-image ground-truth polygons, indexed by gt index.
        rois: per-image sampled RoIs.
        labels_int32: per-image class labels of the RoIs; -1 and
            ``num_classes`` are treated as non-foreground.
        sampled_gt_inds: per-image matched gt index of each RoI.
        num_classes (int): label value that marks background.
        resolution (int): side length of each mask target.

    Returns:
        tuple: ``(mask_rois, mask_rois_num, tgt_classes, tgt_masks,
        mask_index, tgt_weights)``. ``mask_rois`` is a per-image list; the
        others are concatenated over images.
    """
    mask_rois = []
    mask_rois_num = []
    tgt_masks = []
    tgt_classes = []
    mask_index = []
    tgt_weights = []
    for k in range(len(rois)):
        labels_per_im = labels_int32[k]
        # select rois labeled with foreground
        fg_inds = paddle.nonzero(
            paddle.logical_and(labels_per_im != -1, labels_per_im !=
                               num_classes))
        has_fg = True
        # generate fake roi if foreground is empty
        if fg_inds.numel() == 0:
            has_fg = False
            fg_inds = paddle.ones([1], dtype='int32')
        inds_per_im = sampled_gt_inds[k]
        inds_per_im = paddle.gather(inds_per_im, fg_inds)

        rois_per_im = rois[k]
        fg_rois = paddle.gather(rois_per_im, fg_inds)
        # Copy the foreground roi to cpu
        # to generate mask target with ground-truth
        boxes = fg_rois.numpy()
        gt_segms_per_im = gt_segms[k]

        new_segm = []
        inds_per_im = inds_per_im.numpy()
        if len(gt_segms_per_im) > 0:
            for i in inds_per_im:
                new_segm.append(gt_segms_per_im[i])
        results = []
        if len(gt_segms_per_im) > 0:
            # ``new_segm`` and ``boxes`` were both gathered with fg_inds, so
            # they are already in foreground order and must be walked by
            # position. Indexing them with the roi indices stored in
            # fg_inds picks the wrong row or runs out of range.
            for segm, box in zip(new_segm, boxes):
                results.append(
                    rasterize_polygons_within_box(segm, box, resolution))
        else:
            results.append(paddle.ones([resolution, resolution], dtype='int32'))

        fg_classes = paddle.gather(labels_per_im, fg_inds)
        weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
        if not has_fg:
            # now all sampled classes are background
            # which will cause error in loss calculation,
            # make fake classes with weight of 0.
            fg_classes = paddle.zeros([1], dtype='int32')
            weight = weight - 1
        tgt_mask = paddle.stack(results)
        tgt_mask.stop_gradient = True
        fg_rois.stop_gradient = True

        mask_index.append(fg_inds)
        mask_rois.append(fg_rois)
        mask_rois_num.append(paddle.shape(fg_rois)[0])
        tgt_classes.append(fg_classes)
        tgt_masks.append(tgt_mask)
        tgt_weights.append(weight)

    mask_index = paddle.concat(mask_index)
    mask_rois_num = paddle.concat(mask_rois_num)
    tgt_classes = paddle.concat(tgt_classes, axis=0)
    tgt_masks = paddle.concat(tgt_masks, axis=0)
    tgt_weights = paddle.concat(tgt_weights, axis=0)

    return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
G
Guanghua Yu 已提交
392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675


def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
    """
    Sample positive indices so that each matched class/gt is represented.

    Args:
        max_overlaps: unused here; kept for interface compatibility.
        max_classes (ndarray): per-box value used to group the positives.
        pos_inds (ndarray): indices of the positive boxes.
        num_expected (int): number of positives wanted.

    Returns:
        ``pos_inds`` itself when it already has at most ``num_expected``
        entries; otherwise a paddle Tensor with the sampled indices.
    """
    if len(pos_inds) <= num_expected:
        return pos_inds

    unique_gt_inds = np.unique(max_classes[pos_inds])
    num_gts = len(unique_gt_inds)
    num_per_gt = int(round(num_expected / float(num_gts)) + 1)
    # loop-invariant, so build it once
    pos_set = set(pos_inds)

    sampled_inds = []
    for i in unique_gt_inds:
        inds = np.nonzero(max_classes == i)[0]
        inds = list(set(inds) & pos_set)
        if len(inds) > num_per_gt:
            inds = np.random.choice(inds, size=num_per_gt, replace=False)
        sampled_inds.extend(list(inds))  # combine as a new sampler

    if len(sampled_inds) < num_expected:
        # top up from the positives that were not drawn yet
        num_extra = num_expected - len(sampled_inds)
        extra_inds = np.array(list(pos_set - set(sampled_inds)))
        assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
            "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
                len(sampled_inds), len(extra_inds), len(pos_inds))
        if len(extra_inds) > num_extra:
            extra_inds = np.random.choice(
                extra_inds, size=num_extra, replace=False)
        sampled_inds.extend(extra_inds.tolist())
    elif len(sampled_inds) > num_expected:
        sampled_inds = np.random.choice(
            sampled_inds, size=num_expected, replace=False)
    return paddle.to_tensor(sampled_inds)


def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
                              num_bins, bg_thresh):
    """
    Sample indices evenly across IoU bins.

    The range ``[floor_thr, max_overlaps.max())`` is split into
    ``num_bins`` equal intervals and up to ``int(num_expected / num_bins)``
    members of ``full_set`` are drawn from each one. Any shortfall is
    topped up from the remaining members of ``full_set``.

    Args:
        max_overlaps (ndarray): IoU value per box.
        full_set (set): candidate indices.
        num_expected (int): number of indices wanted.
        floor_thr (float): lower edge of the first bin.
        num_bins (int): number of IoU bins.
        bg_thresh: unused here; kept for interface compatibility.

    Returns:
        ndarray: integer array of sampled indices.
    """
    max_iou = max_overlaps.max()
    iou_interval = (max_iou - floor_thr) / num_bins
    per_num_expected = int(num_expected / num_bins)

    sampled_inds = []
    for i in range(num_bins):
        start_iou = floor_thr + i * iou_interval
        end_iou = floor_thr + (i + 1) * iou_interval

        in_bin = np.logical_and(max_overlaps >= start_iou,
                                max_overlaps < end_iou)
        tmp_inds = list(set(np.where(in_bin)[0]) & full_set)

        if len(tmp_inds) > per_num_expected:
            tmp_sampled_set = np.random.choice(
                tmp_inds, size=per_num_expected, replace=False)
        else:
            # ``np.int`` was removed in NumPy 1.24; it aliased builtin int
            tmp_sampled_set = np.array(tmp_inds, dtype=int)
        sampled_inds.append(tmp_sampled_set)

    sampled_inds = np.concatenate(sampled_inds)
    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        # explicit dtype: an empty array would otherwise be float64 and
        # turn the concatenated indices into floats
        extra_inds = np.array(list(full_set - set(sampled_inds)), dtype=int)
        assert len(sampled_inds) + len(extra_inds) == len(full_set), \
            "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
                len(sampled_inds), len(extra_inds), len(full_set))

        if len(extra_inds) > num_extra:
            extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
        sampled_inds = np.concatenate([sampled_inds, extra_inds])

    return sampled_inds


def libra_sample_neg(max_overlaps,
                     max_classes,
                     neg_inds,
                     num_expected,
                     floor_thr=-1,
                     floor_fraction=0,
                     num_bins=3,
                     bg_thresh=0.5):
    """
    Sample negative indices with IoU-balanced sampling.

    Args:
        max_overlaps (ndarray): IoU value per box.
        max_classes: unused here; kept for interface compatibility.
        neg_inds (ndarray): indices of the negative boxes.
        num_expected (int): number of negatives wanted.
        floor_thr (float): negatives with IoU in ``[0, floor_thr)`` form
            the "floor" set that is sampled uniformly; a negative value
            disables the floor set.
        floor_fraction (float): share of ``num_expected`` reserved for the
            floor set.
        num_bins (int): number of IoU bins; fewer than 2 falls back to
            uniform sampling.
        bg_thresh (float): forwarded to ``libra_sample_via_interval``.

    Returns:
        ``neg_inds`` itself when it already has at most ``num_expected``
        entries; otherwise a paddle Tensor with the sampled indices.
    """
    if len(neg_inds) <= num_expected:
        return neg_inds

    # balance sampling for negative samples
    neg_set = set(neg_inds.tolist())
    if floor_thr > 0:
        floor_set = set(
            np.where(
                np.logical_and(max_overlaps >= 0, max_overlaps < floor_thr))
            [0])
        iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
    elif floor_thr == 0:
        floor_set = set(np.where(max_overlaps == 0)[0])
        iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
    else:
        floor_set = set()
        iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
        # the interval sampler below starts its bins at 0
        floor_thr = 0

    floor_neg_inds = list(floor_set & neg_set)
    iou_sampling_neg_inds = list(iou_sampling_set & neg_set)

    num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
    if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
        if num_bins >= 2:
            iou_sampled_inds = libra_sample_via_interval(
                max_overlaps,
                set(iou_sampling_neg_inds), num_expected_iou_sampling,
                floor_thr, num_bins, bg_thresh)
        else:
            iou_sampled_inds = np.random.choice(
                iou_sampling_neg_inds,
                size=num_expected_iou_sampling,
                replace=False)
    else:
        # ``np.int`` was removed in NumPy 1.24; it aliased builtin int
        iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=int)

    num_expected_floor = num_expected - len(iou_sampled_inds)
    if len(floor_neg_inds) > num_expected_floor:
        sampled_floor_inds = np.random.choice(
            floor_neg_inds, size=num_expected_floor, replace=False)
    else:
        sampled_floor_inds = np.array(floor_neg_inds, dtype=int)

    sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        # explicit dtype keeps the indices integral when the set is empty
        extra_inds = np.array(list(neg_set - set(sampled_inds)), dtype=int)
        if len(extra_inds) > num_extra:
            extra_inds = np.random.choice(
                extra_inds, size=num_extra, replace=False)
        sampled_inds = np.concatenate((sampled_inds, extra_inds))
    return paddle.to_tensor(sampled_inds)


def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
                    negative_overlap, num_classes):
    """
    Match boxes to ground truth and label them for Libra R-CNN sampling.

    Args:
        anchors: candidate boxes.
        gt_boxes: ground-truth boxes of one image.
        gt_classes: paddle Tensor with the class of each gt box; it is
            used as a column index, so it is assumed to hold integers in
            ``[0, num_classes)`` (TODO confirm against caller).
        positive_overlap (float): best IoU at or above this is foreground.
        negative_overlap (float): best IoU below this is background.
        num_classes (int): number of columns of the per-class IoU table.

    Returns:
        tuple: ``(matches, match_labels, matched_vals)``: matched gt index
        per box (0 when nothing overlaps), the 1 / 0 / -1 label, and the
        best IoU per box.
    """
    # TODO: use paddle API to speed up
    gt_classes = gt_classes.numpy()
    gt_overlaps = np.zeros((anchors.shape[0], num_classes))
    matches = np.zeros((anchors.shape[0]), dtype=np.int32)
    if len(gt_boxes) > 0:
        proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
        overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
        overlaps_max = proposal_to_gt_overlaps.max(axis=1)
        # Boxes which with non-zero overlap with gt boxes
        overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
        overlapped_gt_ind = overlaps_argmax[overlapped_boxes_ind]
        overlapped_boxes_gt_classes = gt_classes[overlapped_gt_ind]

        # Each overlapped box appears once, so one vectorized assignment
        # is equivalent to writing the entries one box at a time.
        gt_overlaps[overlapped_boxes_ind,
                    overlapped_boxes_gt_classes] = overlaps_max[
                        overlapped_boxes_ind]
        matches[overlapped_boxes_ind] = overlapped_gt_ind

    gt_overlaps = paddle.to_tensor(gt_overlaps)
    matches = paddle.to_tensor(matches)

    matched_vals = paddle.max(gt_overlaps, axis=1)
    match_labels = paddle.full(matches.shape, -1, dtype='int32')
    match_labels = paddle.where(matched_vals < negative_overlap,
                                paddle.zeros_like(match_labels), match_labels)
    match_labels = paddle.where(matched_vals >= positive_overlap,
                                paddle.ones_like(match_labels), match_labels)

    return matches, match_labels, matched_vals


def libra_sample_bbox(matches,
                      match_labels,
                      matched_vals,
                      gt_classes,
                      batch_size_per_im,
                      num_classes,
                      fg_fraction,
                      fg_thresh,
                      bg_thresh,
                      num_bins,
                      use_random=True,
                      is_cascade_rcnn=False):
    """
    Sample foreground and background RoIs for Libra R-CNN.

    Args:
        matches: matched gt index per RoI.
        match_labels: 1 / 0 / -1 label per RoI.
        matched_vals: best IoU per RoI.
        gt_classes: ground-truth classes of the image.
        batch_size_per_im (int): number of RoIs to sample.
        num_classes (int): class value assigned to background RoIs.
        fg_fraction (float): foreground share of the batch.
        fg_thresh (float): IoU at or above which a RoI is foreground.
        bg_thresh (float): IoU below which a RoI is background.
        num_bins (int): IoU bins used for negative sampling.
        use_random (bool): use the balanced random samplers; otherwise
            the leading candidates are taken.
        is_cascade_rcnn (bool): keep every foreground and background RoI
            without sampling.

    Returns:
        tuple: ``(sampled_inds, sampled_gt_classes)``.
    """
    rois_per_image = int(batch_size_per_im)
    fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
    bg_rois_per_im = rois_per_image - fg_rois_per_im

    if is_cascade_rcnn:
        # flattened so both branches yield 1-D index tensors
        fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
        bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
    else:
        matched_vals_np = matched_vals.numpy()
        match_labels_np = match_labels.numpy()

        # sample fg
        fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
        fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
        if (fg_inds.shape[0] > fg_nums) and use_random:
            fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
                                       fg_inds.numpy(), fg_rois_per_im)
        fg_inds = fg_inds[:fg_nums]

        # sample bg
        bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
        bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
        if (bg_inds.shape[0] > bg_nums) and use_random:
            # NOTE(review): the sampler is asked for bg_rois_per_im, which
            # can be smaller than bg_nums when foregrounds are scarce;
            # confirm whether bg_nums was intended.
            bg_inds = libra_sample_neg(
                matched_vals_np,
                match_labels_np,
                bg_inds.numpy(),
                bg_rois_per_im,
                num_bins=num_bins,
                bg_thresh=bg_thresh)
        bg_inds = bg_inds[:bg_nums]

    # Shared tail. It used to sit inside the ``else`` branch, so the cascade
    # path returned None and the caller's tuple unpacking failed.
    sampled_inds = paddle.concat([fg_inds, bg_inds])

    gt_classes = paddle.gather(gt_classes, matches)
    gt_classes = paddle.where(match_labels == 0,
                              paddle.ones_like(gt_classes) * num_classes,
                              gt_classes)
    gt_classes = paddle.where(match_labels == -1,
                              paddle.ones_like(gt_classes) * -1, gt_classes)
    sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)

    return sampled_inds, sampled_gt_classes


def libra_generate_proposal_target(rpn_rois,
                                   gt_classes,
                                   gt_boxes,
                                   batch_size_per_im,
                                   fg_fraction,
                                   fg_thresh,
                                   bg_thresh,
                                   num_classes,
                                   use_random=True,
                                   is_cascade_rcnn=False,
                                   max_overlaps=None,
                                   num_bins=3):
    """
    Label and sample the RoIs of every image with Libra R-CNN sampling.

    Args:
        rpn_rois: per-image proposals.
        gt_classes: per-image ground-truth classes (last axis is squeezed).
        gt_boxes: per-image ground-truth boxes.
        batch_size_per_im (int): number of RoIs to sample per image.
        fg_fraction (float): foreground share of the batch.
        fg_thresh (float): IoU threshold for foreground.
        bg_thresh (float): IoU threshold for background.
        num_classes (int): class value assigned to background RoIs.
        use_random (bool): use the balanced random samplers.
        is_cascade_rcnn (bool): filter the proposals with ``filter_roi``
            and ``max_overlaps`` before labeling.
        max_overlaps: per-image overlaps, only read in cascade mode.
        num_bins (int): IoU bins used for negative sampling.

    Returns:
        tuple: ``(rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds,
        new_rois_num)``; the first four are per-image lists, the last is
        the concatenated per-image RoI counts.
    """
    rois_with_gt, tgt_labels, tgt_bboxes = [], [], []
    sampled_max_overlaps = []
    tgt_gt_inds, new_rois_num = [], []

    for img_id, rpn_roi in enumerate(rpn_rois):
        # max_overlaps is only touched in cascade mode
        if is_cascade_rcnn:
            max_overlap = max_overlaps[img_id]
        else:
            max_overlap = None
        gt_bbox = gt_boxes[img_id]
        gt_class = paddle.squeeze(gt_classes[img_id], axis=-1)
        if is_cascade_rcnn:
            rpn_roi = filter_roi(rpn_roi, max_overlap)
        # the gt boxes always join the candidate RoIs
        bbox = paddle.concat([rpn_roi, gt_bbox])

        # Step1: label bbox
        matches, match_labels, matched_vals = libra_label_box(
            bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)

        # Step2: sample bbox
        sampled_inds, sampled_gt_classes = libra_sample_bbox(
            matches, match_labels, matched_vals, gt_class, batch_size_per_im,
            num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
            use_random, is_cascade_rcnn)

        # Step3: make output
        rois_per_image = paddle.gather(bbox, sampled_inds)
        sampled_gt_ind = paddle.gather(matches, sampled_inds)
        sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
        sampled_overlap = paddle.gather(matched_vals, sampled_inds)

        for target in (rois_per_image, sampled_gt_ind, sampled_bbox,
                       sampled_overlap):
            target.stop_gradient = True

        tgt_labels.append(sampled_gt_classes)
        tgt_bboxes.append(sampled_bbox)
        rois_with_gt.append(rois_per_image)
        # collected per image but not part of the returned tuple
        sampled_max_overlaps.append(sampled_overlap)
        tgt_gt_inds.append(sampled_gt_ind)
        new_rois_num.append(paddle.shape(sampled_inds)[0])

    new_rois_num = paddle.concat(new_rois_num)
    return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num