From d8704f283b35e5bf94185e4e8bf6013536024f97 Mon Sep 17 00:00:00 2001 From: sunxl1988 <47514455+sunxl1988@users.noreply.github.com> Date: Fri, 24 Jul 2020 15:09:21 +0800 Subject: [PATCH] test=dygraph split target op into label&sample op (#1093) split target op into label&sample op --- ppdet/py_op/post_process.py | 106 +++---- ppdet/py_op/target.py | 534 ++++++++++++++++++------------------ 2 files changed, 328 insertions(+), 312 deletions(-) diff --git a/ppdet/py_op/post_process.py b/ppdet/py_op/post_process.py index 805389236..a2f972e66 100755 --- a/ppdet/py_op/post_process.py +++ b/ppdet/py_op/post_process.py @@ -6,6 +6,7 @@ from .bbox import delta2bbox, clip_bbox, expand_bbox, nms def bbox_post_process(bboxes, + bbox_nums, bbox_probs, bbox_deltas, im_info, @@ -14,30 +15,32 @@ def bbox_post_process(bboxes, nms_thresh=0.5, class_nums=81, bbox_reg_weights=[0.1, 0.1, 0.2, 0.2]): - bbox_nums = [0, bboxes.shape[0]] - bboxes_v = np.array(bboxes) - bbox_probs_v = np.array(bbox_probs) - bbox_deltas_v = np.array(bbox_deltas) - variance_v = np.array(bbox_reg_weights) - new_bboxes = [[] for _ in range(len(bbox_nums) - 1)] + + new_bboxes = [[] for _ in range(len(bbox_nums))] new_bbox_nums = [0] - for i in range(len(bbox_nums) - 1): - start = bbox_nums[i] - end = bbox_nums[i + 1] - if start == end: - continue - - bbox_deltas_n = bbox_deltas_v[start:end, :] # box delta - rois_n = bboxes_v[start:end, :] # box - rois_n = rois_n / im_info[i][2] # scale - rois_n = delta2bbox(bbox_deltas_n, rois_n, variance_v) - rois_n = clip_bbox(rois_n, im_info[i][:2] / im_info[i][2]) + st_num = 0 + end_num = 0 + for i in range(len(bbox_nums)): + bbox_num = bbox_nums[i] + end_num += bbox_num + + bbox = bboxes[st_num:end_num, :] # bbox + bbox = bbox / im_info[i][2] # scale + bbox_delta = bbox_deltas[st_num:end_num, :] # bbox delta + + # step1: decode + bbox = delta2bbox(bbox_delta, bbox, bbox_reg_weights) + + # step2: clip + bbox = clip_bbox(bbox, im_info[i][:2] / im_info[i][2]) + + # step3: nms cls_boxes = [[] for _ in range(class_nums)] - scores_n = bbox_probs_v[start:end, :] + scores_n = bbox_probs[st_num:end_num, :] for j in range(1, class_nums): inds = np.where(scores_n[:, j] > score_thresh)[0] scores_j = scores_n[inds, j] - rois_j = rois_n[inds, j * 4:(j + 1) * 4] + rois_j = bbox[inds, j * 4:(j + 1) * 4] dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype( np.float32, copy=False) keep = nms(dets_j, nms_thresh) @@ -48,6 +51,8 @@ def bbox_post_process(bboxes, np.float32, copy=False) cls_boxes[j] = nms_dets + st_num += bbox_num + # Limit to max_per_image detections **over all classes** image_scores = np.hstack( [cls_boxes[j][:, 1] for j in range(1, class_nums)]) @@ -58,7 +63,7 @@ def bbox_post_process(bboxes, cls_boxes[j] = cls_boxes[j][keep, :] new_bboxes_n = np.vstack([cls_boxes[j] for j in range(1, class_nums)]) new_bboxes[i] = new_bboxes_n - new_bbox_nums.append(len(new_bboxes_n) + new_bbox_nums[-1]) + new_bbox_nums.append(len(new_bboxes_n)) labels = new_bboxes_n[:, 0] scores = new_bboxes_n[:, 1] boxes = new_bboxes_n[:, 2:] @@ -68,27 +73,29 @@ def bbox_post_process(bboxes, @jit -def mask_post_process(bbox_nums, bboxes, masks, im_info): - bboxes = np.array(bboxes) - M = cfg.resolution - scale = (M + 2.0) / M - masks_v = np.array(masks) +def mask_post_process(bboxes, bbox_nums, masks, im_info, resolution=14): + scale = (resolution + 2.0) / resolution boxes = bboxes[:, 2:] labels = bboxes[:, 0] - segms_results = [[] for _ in range(len(bbox_nums) - 1)] + segms_results = [[] for _ in range(len(bbox_nums))] sum = 0 - for i in range(len(bbox_nums) - 1): - bboxes_n = bboxes[bbox_nums[i]:bbox_nums[i + 1]] + st_num = 0 + end_num = 0 + for i in range(len(bbox_nums)): + bbox_num = bbox_nums[i] + end_num += bbox_num + cls_segms = [] - masks_n = masks_v[bbox_nums[i]:bbox_nums[i + 1]] - boxes_n = boxes[bbox_nums[i]:bbox_nums[i + 1]] - labels_n = labels[bbox_nums[i]:bbox_nums[i + 1]] + boxes_n = boxes[st_num:end_num] + labels_n = labels[st_num:end_num] + masks_n = masks[st_num:end_num] + im_h = int(round(im_info[i][0] / im_info[i][2])) im_w = int(round(im_info[i][1] / im_info[i][2])) boxes_n = expand_boxes(boxes_n, scale) boxes_n = boxes_n.astype(np.int32) padded_mask = np.zeros((M + 2, M + 2), dtype=np.float32) - for j in range(len(bboxes_n)): + for j in range(len(boxes_n)): class_id = int(labels_n[j]) padded_mask[1:-1, 1:-1] = masks_n[j, class_id, :, :] @@ -114,28 +121,24 @@ def mask_post_process(bbox_nums, bboxes, masks, im_info): im_mask[:, :, np.newaxis], order='F'))[0] cls_segms.append(rle) segms_results[i] = np.array(cls_segms)[:, np.newaxis] - segms_results = np.vstack([segms_results[k] for k in range(len(lod) - 1)]) + segms_results = np.vstack([segms_results[k] for k in range(len(bbox_nums))]) bboxes = np.hstack([segms_results, bboxes]) return bboxes[:, :3] @jit -def get_det_res(bbox_nums, bbox, image_id, num_id_to_cat_id_map, batch_size=1): +def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map, + batch_size=1): det_res = [] - bbox_v = np.array(bbox) - if bbox_v.shape == ( - 1, - 1, ): - return dts_res - assert (len(bbox_nums) == batch_size + 1), \ - "Error bbox_nums Tensor offset dimension. bbox_nums({}) vs. batch_size({})"\ - .format(len(bbox_nums), batch_size) k = 0 - for i in range(batch_size): - dt_num_this_img = bbox_nums[i + 1] - bbox_nums[i] + for i in range(len(bbox_nums)): image_id = int(image_id[i][0]) - for j in range(dt_num_this_img): - dt = bbox_v[k] + image_width = int(image_shape[i][1]) + image_height = int(image_shape[i][2]) + + det_nums = bbox_nums[i] + for j in range(det_nums): + dt = bboxes[k] k = k + 1 num_id, score, xmin, ymin, xmax, ymax = dt.tolist() category_id = num_id_to_cat_id_map[num_id] @@ -153,15 +156,14 @@ def get_det_res(bbox_nums, bbox, image_id, num_id_to_cat_id_map, batch_size=1): @jit -def get_seg_res(mask_nums, mask, image_id, num_id_to_cat_id_map, batch_size=1): +def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map): seg_res = [] - mask_v = np.array(mask) k = 0 - for i in range(batch_size): + for i in range(len(mask_nums)): image_id = int(image_id[i][0]) - dt_num_this_img = mask_nums[i + 1] - mask_nums[i] - for j in range(dt_num_this_img): - dt = mask_v[k] + det_nums = mask_nums[i] + for j in range(det_nums): + dt = masks[k] k = k + 1 sg, num_id, score = dt.tolist() cat_id = num_id_to_cat_id_map[num_id] diff --git a/ppdet/py_op/target.py b/ppdet/py_op/target.py index 004e2a538..fb949ea39 100755 --- a/ppdet/py_op/target.py +++ b/ppdet/py_op/target.py @@ -7,7 +7,7 @@ from .mask import * @jit -def generate_rpn_anchor_target(anchor_box, +def generate_rpn_anchor_target(anchors, gt_boxes, is_crowd, im_info, @@ -16,85 +16,106 @@ def generate_rpn_anchor_target(anchor_box, rpn_positive_overlap, rpn_negative_overlap, rpn_fg_fraction, - use_random=True): - anchor_num = anchor_box.shape[0] + use_random=True, + anchor_reg_weights=[1., 1., 1., 1.]): + anchor_num = anchors.shape[0] batch_size = gt_boxes.shape[0] + loc_indexes = [] + cls_indexes = [] + tgt_labels = [] + tgt_deltas = [] + anchor_inside_weights = [] + for i in range(batch_size): + + # TODO: move anchor filter into anchor generator im_height = im_info[i][0] im_width = im_info[i][1] im_scale = im_info[i][2] if rpn_straddle_thresh >= 0: - # Only keep anchors inside the image by a margin of straddle_thresh - inds_inside = np.where( - (anchor_box[:, 0] >= -rpn_straddle_thresh - ) & (anchor_box[:, 1] >= -rpn_straddle_thresh) & ( - anchor_box[:, 2] < im_width + rpn_straddle_thresh) & ( - anchor_box[:, 3] < im_height + rpn_straddle_thresh))[0] - # keep only inside anchors - inside_anchors = anchor_box[inds_inside, :] + anchor_inds = np.where((anchors[:, 0] >= -rpn_straddle_thresh) & ( + anchors[:, 1] >= -rpn_straddle_thresh) & ( + anchors[:, 2] < im_width + rpn_straddle_thresh) & ( + anchors[:, 3] < im_height + rpn_straddle_thresh))[0] + anchor = anchors[anchor_inds, :] else: - inds_inside = np.arange(anchor_box.shape[0]) - inside_anchors = anchor_box - gt_boxes_slice = gt_boxes[i] * im_scale - is_crowd_slice = is_crowd[i] + anchor_inds = np.arange(anchors.shape[0]) + anchor = anchors + gt_bbox = gt_boxes[i] * im_scale + is_crowd_slice = is_crowd[i] not_crowd_inds = np.where(is_crowd_slice == 0)[0] - gt_boxes_slice = gt_boxes_slice[not_crowd_inds] - iou = bbox_overlaps(inside_anchors, gt_boxes_slice) - - loc_inds, score_inds, labels, gt_inds, bbox_inside_weight = _sample_anchor( - iou, rpn_batch_size_per_im, rpn_positive_overlap, - rpn_negative_overlap, rpn_fg_fraction, use_random) - # unmap to all anchor - loc_inds = inds_inside[loc_inds] - score_inds = inds_inside[score_inds] - sampled_anchor = anchor_box[loc_inds] - sampled_gt = gt_boxes_slice[gt_inds] - box_deltas = bbox2delta(sampled_anchor, sampled_gt, [1., 1., 1., 1.]) - - if i == 0: - loc_indexes = loc_inds - score_indexes = score_inds - tgt_labels = labels - tgt_bboxes = box_deltas - bbox_inside_weights = bbox_inside_weight - else: - loc_indexes = np.concatenate( - [loc_indexes, loc_inds + i * anchor_num]) - score_indexes = np.concatenate( - [score_indexes, score_inds + i * anchor_num]) - tgt_labels = np.concatenate([tgt_labels, labels]) - tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) - bbox_inside_weights = np.vstack([bbox_inside_weights, \ - bbox_inside_weight]) - tgt_labels = tgt_labels.astype('float32') - tgt_bboxes = tgt_bboxes.astype('float32') - return loc_indexes, score_indexes, tgt_labels, tgt_bboxes, bbox_inside_weights + gt_bbox = gt_bbox[not_crowd_inds] + + # Step1: match anchor and gt_bbox + anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = label_anchor(anchor, + gt_bbox) + + # Step2: sample anchor + fg_inds, bg_inds, fg_fake_inds, fake_num = sample_anchor( + anchor_gt_bbox_iou, labels, rpn_positive_overlap, + rpn_negative_overlap, rpn_batch_size_per_im, rpn_fg_fraction, + use_random) + + # Step3: make output + loc_inds = np.hstack([fg_fake_inds, fg_inds]) + cls_inds = np.hstack([fg_inds, bg_inds]) + + sampled_labels = labels[cls_inds] + + sampled_anchors = anchor[loc_inds] + sampled_gt_boxes = gt_bbox[anchor_gt_bbox_inds[loc_inds]] + sampled_deltas = bbox2delta(sampled_anchors, sampled_gt_boxes, + anchor_reg_weights) + + anchor_inside_weight = np.zeros((len(loc_inds), 4), dtype=np.float32) + anchor_inside_weight[fake_num:, :] = 1 + + loc_indexes.append(anchor_inds[loc_inds] + i * anchor_num) + cls_indexes.append(anchor_inds[cls_inds] + i * anchor_num) + tgt_labels.append(sampled_labels) + tgt_deltas.append(sampled_deltas) + anchor_inside_weights.append(anchor_inside_weight) + + loc_indexes = np.concatenate(loc_indexes) + cls_indexes = np.concatenate(cls_indexes) + tgt_labels = np.concatenate(tgt_labels).astype('float32') + tgt_deltas = np.vstack(tgt_deltas).astype('float32') + anchor_inside_weights = np.vstack(anchor_inside_weights) + + return loc_indexes, cls_indexes, tgt_labels, tgt_deltas, anchor_inside_weights @jit -def _sample_anchor(anchor_by_gt_overlap, - rpn_batch_size_per_im, - rpn_positive_overlap, - rpn_negative_overlap, - rpn_fg_fraction, - use_random=True): - - anchor_to_gt_argmax = anchor_by_gt_overlap.argmax(axis=1) - anchor_to_gt_max = anchor_by_gt_overlap[np.arange( - anchor_by_gt_overlap.shape[0]), anchor_to_gt_argmax] - - gt_to_anchor_argmax = anchor_by_gt_overlap.argmax(axis=0) - gt_to_anchor_max = anchor_by_gt_overlap[gt_to_anchor_argmax, np.arange( - anchor_by_gt_overlap.shape[1])] - anchors_with_max_overlap = np.where( - anchor_by_gt_overlap == gt_to_anchor_max)[0] - - labels = np.ones((anchor_by_gt_overlap.shape[0], ), dtype=np.int32) * -1 - labels[anchors_with_max_overlap] = 1 - labels[anchor_to_gt_max >= rpn_positive_overlap] = 1 +def label_anchor(anchors, gt_boxes): + iou = compute_iou(anchors, gt_boxes) + + # every gt's anchor's index + gt_bbox_anchor_inds = iou.argmax(axis=0) + gt_bbox_anchor_iou = iou[gt_bbox_anchor_inds, np.arange(iou.shape[1])] + gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0] + + # every anchor's gt bbox's index + anchor_gt_bbox_inds = iou.argmax(axis=1) + anchor_gt_bbox_iou = iou[np.arange(iou.shape[0]), anchor_gt_bbox_inds] + + labels = np.ones((iou.shape[0], ), dtype=np.int32) * -1 + labels[gt_bbox_anchor_iou_inds] = 1 + + return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels + +@jit +def sample_anchor(anchor_gt_bbox_iou, + labels, + rpn_positive_overlap, + rpn_negative_overlap, + rpn_batch_size_per_im, + rpn_fg_fraction, + use_random=True): + + labels[anchor_gt_bbox_iou >= rpn_positive_overlap] = 1 num_fg = int(rpn_fg_fraction * rpn_batch_size_per_im) fg_inds = np.where(labels == 1)[0] if len(fg_inds) > num_fg and use_random: @@ -102,12 +123,11 @@ def _sample_anchor(anchor_by_gt_overlap, fg_inds, size=(len(fg_inds) - num_fg), replace=False) else: disable_inds = fg_inds[num_fg:] - labels[disable_inds] = -1 fg_inds = np.where(labels == 1)[0] num_bg = rpn_batch_size_per_im - np.sum(labels == 1) - bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] + bg_inds = np.where(anchor_gt_bbox_iou < rpn_negative_overlap)[0] if len(bg_inds) > num_bg and use_random: enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] else: @@ -125,15 +145,7 @@ def _sample_anchor(anchor_by_gt_overlap, fg_inds = np.where(labels == 1)[0] bg_inds = np.where(labels == 0)[0] - loc_index = np.hstack([fg_fake_inds, fg_inds]) - score_index = np.hstack([fg_inds, bg_inds]) - labels = labels[score_index] - - gt_inds = anchor_to_gt_argmax[loc_index] - - bbox_inside_weight = np.zeros((len(loc_index), 4), dtype=np.float32) - bbox_inside_weight[fake_num:, :] = 1 - return loc_index, score_index, labels, gt_inds, bbox_inside_weight + return fg_inds, bg_inds, fg_fake_inds, fake_num @jit @@ -155,148 +167,153 @@ def generate_proposal_target(rpn_rois, is_cascade_rcnn=False): rois = [] - labels_int32 = [] - bbox_targets = [] - bbox_inside_weights = [] - bbox_outside_weights = [] + tgt_labels = [] + tgt_deltas = [] + rois_inside_weights = [] + rois_outside_weights = [] rois_nums = [] - batch_size = gt_boxes.shape[0] - # TODO: modify here - # rpn_rois = rpn_rois.reshape(batch_size, -1, 4) st_num = 0 + end_num = 0 for im_i in range(len(rpn_rois_nums)): rpn_rois_num = rpn_rois_nums[im_i] - frcn_blobs = _sample_rois( - rpn_rois[st_num:rpn_rois_num], gt_classes[im_i], is_crowd[im_i], - gt_boxes[im_i], im_info[im_i], batch_size_per_im, fg_fraction, - fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums, + end_num += rpn_rois_num + + rpn_roi = rpn_rois[st_num:end_num] + im_scale = im_info[im_i][2] + rpn_roi = rpn_roi / im_scale + gt_bbox = gt_boxes[im_i] + + if is_cascade_rcnn: + rpn_roi = rpn_roi[gt_bbox.shape[0]:, :] + bbox = np.vstack([gt_bbox, rpn_roi]) + + # Step1: label bbox + roi_gt_bbox_inds, roi_gt_bbox_iou, labels, = label_bbox( + bbox, gt_bbox, gt_classes[im_i], is_crowd[im_i]) + + # Step2: sample bbox + if is_cascade_rcnn: + ws = bbox[:, 2] - bbox[:, 0] + 1 + hs = bbox[:, 3] - bbox[:, 1] + 1 + keep = np.where((ws > 0) & (hs > 0))[0] + bbox = bbox[keep] + + fg_inds, bg_inds, fg_nums = sample_bbox( + roi_gt_bbox_iou, batch_size_per_im, fg_fraction, fg_thresh, + bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums, use_random, is_cls_agnostic, is_cascade_rcnn) - st_num = rpn_rois_num - rois.append(frcn_blobs['rois']) - labels_int32.append(frcn_blobs['labels_int32']) - bbox_targets.append(frcn_blobs['bbox_targets']) - bbox_inside_weights.append(frcn_blobs['bbox_inside_weights']) - bbox_outside_weights.append(frcn_blobs['bbox_outside_weights']) - rois_nums.append(frcn_blobs['rois'].shape[0]) + # Step3: make output + sampled_inds = np.append(fg_inds, bg_inds) + + sampled_labels = labels[sampled_inds] + sampled_labels[fg_nums:] = 0 + + sampled_boxes = bbox[sampled_inds] + sampled_gt_boxes = gt_bbox[roi_gt_bbox_inds[sampled_inds]] + sampled_gt_boxes[fg_nums:, :] = gt_bbox[0] + sampled_deltas = compute_bbox_targets(sampled_boxes, sampled_gt_boxes, + sampled_labels, bbox_reg_weights) + sampled_deltas, bbox_inside_weights = expand_bbox_targets( + sampled_deltas, class_nums, is_cls_agnostic) + bbox_outside_weights = np.array( + bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype) + + roi = sampled_boxes * im_scale + st_num += rpn_rois_num + + rois.append(roi) + rois_nums.append(roi.shape[0]) + tgt_labels.append(sampled_labels) + tgt_deltas.append(sampled_deltas) + rois_inside_weights.append(bbox_inside_weights) + rois_outside_weights.append(bbox_outside_weights) rois = np.concatenate(rois, axis=0).astype(np.float32) - bbox_labels = np.concatenate( - labels_int32, axis=0).astype(np.int32).reshape(-1, 1) - bbox_gts = np.concatenate(bbox_targets, axis=0).astype(np.float32) - bbox_inside_weights = np.concatenate( - bbox_inside_weights, axis=0).astype(np.float32) - bbox_outside_weights = np.concatenate( - bbox_outside_weights, axis=0).astype(np.float32) + tgt_labels = np.concatenate( + tgt_labels, axis=0).astype(np.int32).reshape(-1, 1) + tgt_deltas = np.concatenate(tgt_deltas, axis=0).astype(np.float32) + rois_inside_weights = np.concatenate( + rois_inside_weights, axis=0).astype(np.float32) + rois_outside_weights = np.concatenate( + rois_outside_weights, axis=0).astype(np.float32) rois_nums = np.asarray(rois_nums, np.int32) - return rois, bbox_labels, bbox_gts, bbox_inside_weights, bbox_outside_weights, rois_nums + return rois, tgt_labels, tgt_deltas, rois_inside_weights, rois_outside_weights, rois_nums @jit -def _sample_rois(rpn_rois, - gt_classes, - is_crowd, - gt_boxes, - im_info, - batch_size_per_im, - fg_fraction, - fg_thresh, - bg_thresh_hi, - bg_thresh_lo, - bbox_reg_weights, - class_nums, - use_random=True, - is_cls_agnostic=False, - is_cascade_rcnn=False): - rois_per_image = int(batch_size_per_im) - fg_rois_per_im = int(np.round(fg_fraction * rois_per_image)) - - # Roidb - im_scale = im_info[2] - inv_im_scale = 1. / im_scale - rpn_rois = rpn_rois * inv_im_scale - if is_cascade_rcnn: - rpn_rois = rpn_rois[gt_boxes.shape[0]:, :] - boxes = np.vstack([gt_boxes, rpn_rois]) - gt_overlaps = np.zeros((boxes.shape[0], class_nums)) - box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32) - if len(gt_boxes) > 0: - proposal_to_gt_overlaps = bbox_overlaps(boxes, gt_boxes) - overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1) - overlaps_max = proposal_to_gt_overlaps.max(axis=1) - # Boxes which with non-zero overlap with gt boxes - overlapped_boxes_ind = np.where(overlaps_max > 0)[0].astype('int32') - overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[ - overlapped_boxes_ind]].astype('int32') - gt_overlaps[overlapped_boxes_ind, - overlapped_boxes_gt_classes] = overlaps_max[ - overlapped_boxes_ind] - box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[ - overlapped_boxes_ind] +def label_bbox(boxes, + gt_boxes, + gt_classes, + is_crowd, + class_nums=81, + is_cascade_rcnn=False): + + iou = compute_iou(boxes, gt_boxes) + + # every roi's gt box's index + roi_gt_bbox_inds = np.zeros((boxes.shape[0]), dtype=np.int32) + roi_gt_bbox_iou = np.zeros((boxes.shape[0], class_nums)) + + iou_argmax = iou.argmax(axis=1) + iou_max = iou.max(axis=1) + overlapped_boxes_ind = np.where(iou_max > 0)[0].astype('int32') + roi_gt_bbox_inds[overlapped_boxes_ind] = iou_argmax[overlapped_boxes_ind] + overlapped_boxes_gt_classes = gt_classes[iou_argmax[ + overlapped_boxes_ind]].astype('int32') + roi_gt_bbox_iou[overlapped_boxes_ind, + overlapped_boxes_gt_classes] = iou_max[overlapped_boxes_ind] crowd_ind = np.where(is_crowd)[0] - gt_overlaps[crowd_ind] = -1 + roi_gt_bbox_iou[crowd_ind] = -1 + + labels = roi_gt_bbox_iou.argmax(axis=1) - max_overlaps = gt_overlaps.max(axis=1) - max_classes = gt_overlaps.argmax(axis=1) + return roi_gt_bbox_inds, roi_gt_bbox_iou, labels + + +@jit +def sample_bbox(roi_gt_bbox_iou, + batch_size_per_im, + fg_fraction, + fg_thresh, + bg_thresh_hi, + bg_thresh_lo, + bbox_reg_weights, + class_nums, + use_random=True, + is_cls_agnostic=False, + is_cascade_rcnn=False): + + roi_gt_bbox_iou_max = roi_gt_bbox_iou.max(axis=1) + rois_per_image = int(batch_size_per_im) + fg_rois_per_im = int(np.round(fg_fraction * rois_per_image)) - # Cascade RCNN Decode Filter if is_cascade_rcnn: - ws = boxes[:, 2] - boxes[:, 0] + 1 - hs = boxes[:, 3] - boxes[:, 1] + 1 - keep = np.where((ws > 0) & (hs > 0))[0] - boxes = boxes[keep] - fg_inds = np.where(max_overlaps >= fg_thresh)[0] - bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >= - bg_thresh_lo))[0] - fg_rois_per_this_image = fg_inds.shape[0] - bg_rois_per_this_image = bg_inds.shape[0] + fg_inds = np.where(roi_gt_bbox_iou_max >= fg_thresh)[0] + bg_inds = np.where((roi_gt_bbox_iou_max < bg_thresh_hi) & ( + roi_gt_bbox_iou_max >= bg_thresh_lo))[0] + fg_nums = fg_inds.shape[0] + bg_nums = bg_inds.shape[0] else: - # Foreground - fg_inds = np.where(max_overlaps >= fg_thresh)[0] - fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0]) - # Sample foreground if there are too many - if (fg_inds.shape[0] > fg_rois_per_this_image) and use_random: - fg_inds = np.random.choice( - fg_inds, size=fg_rois_per_this_image, replace=False) - fg_inds = fg_inds[:fg_rois_per_this_image] - # Background - bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >= - bg_thresh_lo))[0] - bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image - bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, - bg_inds.shape[0]) - # Sample background if there are too many - if (bg_inds.shape[0] > bg_rois_per_this_image) and use_random: - bg_inds = np.random.choice( - bg_inds, size=bg_rois_per_this_image, replace=False) - bg_inds = bg_inds[:bg_rois_per_this_image] - - keep_inds = np.append(fg_inds, bg_inds) - sampled_labels = max_classes[keep_inds] - sampled_labels[fg_rois_per_this_image:] = 0 - sampled_boxes = boxes[keep_inds] - sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]] - sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0] - bbox_label_targets = compute_bbox_targets(sampled_boxes, sampled_gts, - sampled_labels, bbox_reg_weights) - bbox_targets, bbox_inside_weights = expand_bbox_targets( - bbox_label_targets, class_nums, is_cls_agnostic) - bbox_outside_weights = np.array( - bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype) + # sampe fg + fg_inds = np.where(roi_gt_bbox_iou_max >= fg_thresh)[0] + fg_nums = np.minimum(fg_rois_per_im, fg_inds.shape[0]) + if (fg_inds.shape[0] > fg_nums) and use_random: + fg_inds = np.random.choice(fg_inds, size=fg_nums, replace=False) + fg_inds = fg_inds[:fg_nums] - # Scale rois - sampled_rois = sampled_boxes * im_scale + # sample bg + bg_inds = np.where((roi_gt_bbox_iou_max < bg_thresh_hi) & ( + roi_gt_bbox_iou_max >= bg_thresh_lo))[0] + bg_nums = rois_per_image - fg_nums + bg_nums = np.minimum(bg_nums, bg_inds.shape[0]) + if (bg_inds.shape[0] > bg_nums) and use_random: + bg_inds = np.random.choice(bg_inds, size=bg_nums, replace=False) + bg_inds = bg_inds[:bg_nums] - # Faster RCNN blobs - frcn_blobs = dict( - rois=sampled_rois, - labels_int32=sampled_labels, - bbox_targets=bbox_targets, - bbox_inside_weights=bbox_inside_weights, - bbox_outside_weights=bbox_outside_weights) - return frcn_blobs + return fg_inds, bg_inds, fg_nums @jit @@ -306,16 +323,42 @@ def generate_mask_target(im_info, gt_classes, is_crowd, gt_segms, rois, rois_has_mask_int32 = [] mask_int32 = [] st_num = 0 - for i in range(len(rois_nums)): - rois_num = rois_nums[i] - mask_blob = _sample_mask( - rois[st_num:rois_num], labels_int32[st_num:rois_num], gt_segms[i], - im_info[i], gt_classes[i], is_crowd[i], num_classes, resolution) - - st_num = rois_num - mask_rois.append(mask_blob['mask_rois']) - rois_has_mask_int32.append(mask_blob['roi_has_mask_int32']) - mask_int32.append(mask_blob['mask_int32']) + end_num = 0 + for k in range(len(rois_nums)): + rois_num = rois_nums[k] + end_num += rois_num + + # remove padding + gt_polys = gt_segms[k] + new_gt_polys = [] + for i in range(gt_polys.shape[0]): + gt_segs = [] + for j in range(gt_polys[i].shape[0]): + new_poly = [] + polys = gt_polys[i][j] + for ii in range(polys.shape[0]): + x, y = polys[ii] + if (x == -1 and y == -1): + continue + elif (x >= 0 and y >= 0): + new_poly.append([x, y]) # array, one poly + if len(new_poly) > 0: + gt_segs.append(new_poly) + new_gt_polys.append(gt_segs) + + im_scale = im_info[k][2] + boxes = rois[st_num:end_num] / im_scale + + bbox_fg, bbox_has_mask, masks = sample_mask( + boxes, new_gt_polys, labels_int32[st_num:rois_num], gt_classes[k], + is_crowd[k], num_classes, resolution) + + st_num += rois_num + + mask_rois.append(bbox_fg * im_scale) + rois_has_mask_int32.append(bbox_has_mask) + mask_int32.append(masks) + mask_rois = np.concatenate(mask_rois, axis=0).astype(np.float32) rois_has_mask_int32 = np.concatenate( rois_has_mask_int32, axis=0).astype(np.int32) @@ -325,73 +368,44 @@ def generate_mask_target(im_info, gt_classes, is_crowd, gt_segms, rois, @jit -def _sample_mask( - rois, - label_int32, +def sample_mask( + boxes, gt_polys, - im_info, + label_int32, gt_classes, is_crowd, num_classes, resolution, ): - # remove padding - new_gt_polys = [] - for i in range(gt_polys.shape[0]): - gt_segs = [] - for j in range(gt_polys[i].shape[0]): - new_poly = [] - polys = gt_polys[i][j] - for ii in range(polys.shape[0]): - x, y = polys[ii] - if (x == -1 and y == -1): - continue - elif (x >= 0 and y >= 0): - new_poly.append([x, y]) # array, one poly - if len(new_poly) > 0: - gt_segs.append(new_poly) - new_gt_polys.append(gt_segs) - - im_scale = im_info[2] - sample_boxes = rois / im_scale - - polys_gt_inds = np.where((gt_classes > 0) & (is_crowd == 0))[0] - - polys_gt = [new_gt_polys[i] for i in polys_gt_inds] - boxes_from_polys = polys_to_boxes(polys_gt) + gt_polys_inds = np.where((gt_classes > 0) & (is_crowd == 0))[0] + _gt_polys = [gt_polys[i] for i in gt_polys_inds] + boxes_from_polys = polys_to_boxes(_gt_polys) + fg_inds = np.where(label_int32 > 0)[0] - roi_has_mask = fg_inds.copy() + bbox_has_mask = fg_inds.copy() if fg_inds.shape[0] > 0: - mask_class_labels = label_int32[fg_inds] - masks = np.zeros((fg_inds.shape[0], resolution**2), dtype=np.int32) - rois_fg = sample_boxes[fg_inds] + labels_fg = label_int32[fg_inds] + masks_fg = np.zeros((fg_inds.shape[0], resolution**2), dtype=np.int32) + bbox_fg = boxes[fg_inds] - overlaps_bbfg_bbpolys = bbox_overlaps_mask(rois_fg, boxes_from_polys) - fg_polys_inds = np.argmax(overlaps_bbfg_bbpolys, axis=1) + iou = bbox_overlaps_mask(bbox_fg, boxes_from_polys) + fg_polys_inds = np.argmax(iou, axis=1) - for i in range(rois_fg.shape[0]): - fg_polys_ind = fg_polys_inds[i] - poly_gt = polys_gt[fg_polys_ind] - roi_fg = rois_fg[i] + for i in range(bbox_fg.shape[0]): + poly_gt = _gt_polys[fg_polys_inds[i]] + roi_fg = bbox_fg[i] mask = polys_to_mask_wrt_box(poly_gt, roi_fg, resolution) mask = np.array(mask > 0, dtype=np.int32) - masks[i, :] = np.reshape(mask, resolution**2) + masks_fg[i, :] = np.reshape(mask, resolution**2) else: bg_inds = np.where(label_int32 == 0)[0] - rois_fg = sample_boxes[bg_inds[0]].reshape((1, -1)) - masks = -np.ones((1, resolution**2), dtype=np.int32) - mask_class_labels = np.zeros((1, )) - roi_has_mask = np.append(roi_has_mask, 0) - - masks = expand_mask_targets(masks, mask_class_labels, resolution, - num_classes) + bbox_fg = boxes[bg_inds[0]].reshape((1, -1)) + masks_fg = -np.ones((1, resolution**2), dtype=np.int32) + labels_fg = np.zeros((1, )) + bbox_has_mask = np.append(bbox_has_mask, 0) - rois_fg *= im_scale - mask_blob = dict() - mask_blob['mask_rois'] = rois_fg - mask_blob['roi_has_mask_int32'] = roi_has_mask - mask_blob['mask_int32'] = masks + masks = expand_mask_targets(masks_fg, labels_fg, resolution, num_classes) - return mask_blob + return bbox_fg, bbox_has_mask, masks -- GitLab