# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import sys
import math
import paddle.fluid as fluid
from op_test import OpTest


def generate_proposal_labels_in_python(
        rpn_rois, gt_classes, gt_boxes, im_scales, batch_size_per_im,
        fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
        class_nums):
    """NumPy reference for the generate_proposal_labels op.

    ``rpn_rois``, ``gt_classes``, ``gt_boxes`` and ``im_scales`` are lists
    with one entry per image; each image is processed independently by
    ``_sample_rois``.

    Returns:
        A 6-tuple ``(rois, labels_int32, bbox_targets, bbox_inside_weights,
        bbox_outside_weights, lod)``. The first five are lists holding one
        array per image; ``lod`` lists the number of sampled RoIs of each
        image (a length-based LoD).
    """
    rois = []
    labels_int32 = []
    bbox_targets = []
    bbox_inside_weights = []
    bbox_outside_weights = []
    lod = []
    assert len(rpn_rois) == len(
        im_scales), 'batch size of rpn_rois and ground_truth is not matched'

    for im_i in range(len(im_scales)):
        frcn_blobs = _sample_rois(
            rpn_rois[im_i], gt_classes[im_i], gt_boxes[im_i], im_scales[im_i],
            batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
            bg_thresh_lo, bbox_reg_weights, class_nums)

        # Per-image RoI count (length-based LoD, not a running offset).
        lod.append(frcn_blobs['rois'].shape[0])

        rois.append(frcn_blobs['rois'])
        labels_int32.append(frcn_blobs['labels_int32'])
        bbox_targets.append(frcn_blobs['bbox_targets'])
        bbox_inside_weights.append(frcn_blobs['bbox_inside_weights'])
        bbox_outside_weights.append(frcn_blobs['bbox_outside_weights'])

    return (rois, labels_int32, bbox_targets, bbox_inside_weights,
            bbox_outside_weights, lod)


def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
                 fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo,
                 bbox_reg_weights, class_nums):
    """Sample foreground/background RoIs for one image and build targets.

    Ground-truth boxes are added to the candidate pool together with the
    (rescaled) RPN proposals. Each candidate's best IoU with any gt box
    decides whether it is foreground (IoU >= fg_thresh) or background
    (bg_thresh_lo <= IoU < bg_thresh_hi). At most
    ``round(fg_fraction * batch_size_per_im)`` foreground and the remainder
    of ``batch_size_per_im`` as background are kept, sub-sampled with
    ``np.random.choice`` when there are too many.

    Returns:
        A dict with keys ``rois``, ``labels_int32``, ``bbox_targets``,
        ``bbox_inside_weights`` and ``bbox_outside_weights``. Foreground
        rows come first, followed by background rows.
    """
    rois_per_image = int(batch_size_per_im)
    fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))

    # Roidb: bring RPN proposals back to the original image scale.
    inv_im_scale = 1. / im_scale
    rpn_rois = rpn_rois * inv_im_scale

    # Candidate pool: gt boxes first, then proposals.
    boxes = np.vstack([gt_boxes, rpn_rois])
    # gt_overlaps[i, c]: best IoU of box i with a gt box, filed under that
    # gt box's class c (0 everywhere else).
    gt_overlaps = np.zeros((boxes.shape[0], class_nums))
    # Index of the gt box each candidate overlaps most (0 if no overlap).
    box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32)
    if len(gt_boxes) > 0:
        proposal_to_gt_overlaps = _bbox_overlaps(boxes, gt_boxes)
        overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
        overlaps_max = proposal_to_gt_overlaps.max(axis=1)
        # Boxes with a non-zero overlap with some gt box
        overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
        overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
            overlapped_boxes_ind]]
        gt_overlaps[overlapped_boxes_ind,
                    overlapped_boxes_gt_classes] = overlaps_max[
                        overlapped_boxes_ind]
        box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[
            overlapped_boxes_ind]

    max_overlaps = gt_overlaps.max(axis=1)
    max_classes = gt_overlaps.argmax(axis=1)

    # Foreground
    fg_inds = np.where(max_overlaps >= fg_thresh)[0]
    fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
    # Sample foreground if there are too many
    if fg_inds.shape[0] > fg_rois_per_this_image:
        fg_inds = np.random.choice(
            fg_inds, size=fg_rois_per_this_image, replace=False)

    # Background
    bg_inds = np.where((max_overlaps < bg_thresh_hi) &
                       (max_overlaps >= bg_thresh_lo))[0]
    bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
    bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
                                        bg_inds.shape[0])
    # Sample background if there are too many
    if bg_inds.shape[0] > bg_rois_per_this_image:
        bg_inds = np.random.choice(
            bg_inds, size=bg_rois_per_this_image, replace=False)

    keep_inds = np.append(fg_inds, bg_inds)
    sampled_labels = max_classes[keep_inds]
    # Background rows are labeled 0.
    sampled_labels[fg_rois_per_this_image:] = 0
    sampled_boxes = boxes[keep_inds]
    sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]]
    # Background rows get gt_boxes[0] as a placeholder regression target;
    # their targets are dropped by _expand_bbox_targets (label 0 is not > 0).
    # NOTE(review): this indexing assumes at least one gt box per image.
    sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0]
    bbox_label_targets = _compute_targets(sampled_boxes, sampled_gts,
                                          sampled_labels, bbox_reg_weights)
    bbox_targets, bbox_inside_weights = _expand_bbox_targets(
        bbox_label_targets, class_nums)
    bbox_outside_weights = np.array(
        bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)

    # Scale rois back to the network input scale
    sampled_rois = sampled_boxes * im_scale

    # Faster RCNN blobs
    frcn_blobs = dict(
        rois=sampled_rois,
        labels_int32=sampled_labels,
        bbox_targets=bbox_targets,
        bbox_inside_weights=bbox_inside_weights,
        bbox_outside_weights=bbox_outside_weights)
    return frcn_blobs


def _bbox_overlaps(roi_boxes, gt_boxes):
    """Return the ``(len(roi_boxes), len(gt_boxes))`` IoU matrix.

    Boxes are ``(x1, y1, x2, y2)`` rows. Widths and heights use the
    inclusive ``x2 - x1 + 1`` convention, clamped at 0.
    """
    w1 = np.maximum(roi_boxes[:, 2] - roi_boxes[:, 0] + 1, 0)
    h1 = np.maximum(roi_boxes[:, 3] - roi_boxes[:, 1] + 1, 0)
    w2 = np.maximum(gt_boxes[:, 2] - gt_boxes[:, 0] + 1, 0)
    h2 = np.maximum(gt_boxes[:, 3] - gt_boxes[:, 1] + 1, 0)
    area1 = w1 * h1
    area2 = w2 * h2

    overlaps = np.zeros((roi_boxes.shape[0], gt_boxes.shape[0]))
    # Scalar loops are kept on purpose: they fix the floating-point
    # evaluation order the expected values are computed with.
    for ind1 in range(roi_boxes.shape[0]):
        for ind2 in range(gt_boxes.shape[0]):
            inter_x1 = np.maximum(roi_boxes[ind1, 0], gt_boxes[ind2, 0])
            inter_y1 = np.maximum(roi_boxes[ind1, 1], gt_boxes[ind2, 1])
            inter_x2 = np.minimum(roi_boxes[ind1, 2], gt_boxes[ind2, 2])
            inter_y2 = np.minimum(roi_boxes[ind1, 3], gt_boxes[ind2, 3])
            inter_w = np.maximum(inter_x2 - inter_x1 + 1, 0)
            inter_h = np.maximum(inter_y2 - inter_y1 + 1, 0)
            inter_area = inter_w * inter_h
            iou = inter_area / (area1[ind1] + area2[ind2] - inter_area)
            overlaps[ind1, ind2] = iou
    return overlaps


def _compute_targets(roi_boxes, gt_boxes, labels, bbox_reg_weights):
    """Return float32 rows ``[label, dx, dy, dw, dh]`` for each RoI."""
    assert roi_boxes.shape[0] == gt_boxes.shape[0]
    assert roi_boxes.shape[1] == 4
    assert gt_boxes.shape[1] == 4

    bbox_reg_weights = np.asarray(bbox_reg_weights)
    targets = _box_to_delta(
        ex_boxes=roi_boxes, gt_boxes=gt_boxes, weights=bbox_reg_weights)

    return np.hstack([labels[:, np.newaxis], targets]).astype(
        np.float32, copy=False)


def _box_to_delta(ex_boxes, gt_boxes, weights):
    """Encode ``gt_boxes`` relative to ``ex_boxes`` as (dx, dy, dw, dh).

    Center offsets are normalized by the example box size and size deltas
    are log ratios; each component is then divided by the matching entry
    of ``weights``.
    """
    ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1
    ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1
    ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w
    ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h

    gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1
    gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1
    gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w
    gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h

    dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
    dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
    dw = (np.log(gt_w / ex_w)) / weights[2]
    dh = (np.log(gt_h / ex_h)) / weights[3]

    targets = np.vstack([dx, dy, dw, dh]).transpose()
    return targets


def _expand_bbox_targets(bbox_targets_input, class_nums):
    """Scatter ``[label, d0..d3]`` rows into per-class target slots.

    Returns ``(bbox_targets, bbox_inside_weights)``, both of shape
    ``(N, 4 * class_nums)``. For rows with label > 0, columns
    ``[4 * label, 4 * label + 4)`` receive the deltas and inside weights of
    1.0; all other entries, including every row labeled 0, stay zero.
    """
    class_labels = bbox_targets_input[:, 0]
    fg_inds = np.where(class_labels > 0)[0]

    bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums))
    bbox_inside_weights = np.zeros(bbox_targets.shape)
    for ind in fg_inds:
        class_label = int(class_labels[ind])
        start_ind = class_label * 4
        end_ind = class_label * 4 + 4
        bbox_targets[ind, start_ind:end_ind] = bbox_targets_input[ind, 1:]
        bbox_inside_weights[ind, start_ind:end_ind] = (1.0, 1.0, 1.0, 1.0)

    return bbox_targets, bbox_inside_weights


class TestGenerateProposalLabelsOp(OpTest):
    """Compare the generate_proposal_labels op with the NumPy reference."""

    def set_data(self):
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()
        # NOTE: only image 0's tensors are fed (and checked); this matches
        # the single-image setup in init_test_input.
        self.inputs = {
            'RpnRois': (self.rpn_rois[0], self.rpn_rois_lod),
            'GtClasses': (self.gt_classes[0], self.gts_lod),
            'GtBoxes': (self.gt_boxes[0], self.gts_lod),
            'ImScales': self.im_scales[0]
        }
        self.attrs = {
            'batch_size_per_im': self.batch_size_per_im,
            'fg_fraction': self.fg_fraction,
            'fg_thresh': self.fg_thresh,
            'bg_thresh_hi': self.bg_thresh_hi,
            'bg_thresh_lo': self.bg_thresh_lo,
            'bbox_reg_weights': self.bbox_reg_weights,
            'class_nums': self.class_nums
        }
        self.outputs = {
            'Rois': (self.rois[0], [self.lod]),
            'LabelsInt32': (self.labels_int32[0], [self.lod]),
            'BboxTargets': (self.bbox_targets[0], [self.lod]),
            'BboxInsideWeights': (self.bbox_inside_weights[0], [self.lod]),
            'BboxOutsideWeights': (self.bbox_outside_weights[0], [self.lod]),
        }

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        self.op_type = 'generate_proposal_labels'
        self.set_data()

    def init_test_params(self):
        self.batch_size_per_im = 10
        self.fg_fraction = 1.0
        self.fg_thresh = 0.5
        self.bg_thresh_hi = 0.5
        self.bg_thresh_lo = 0.0
        self.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2]
        self.class_nums = 81

    def init_test_input(self):
        np.random.seed(0)
        image_nums = 1
        gt_nums = 6  # Keep same with batch_size_per_im for unittest
        # gt_nums + proposal_nums == batch_size_per_im, presumably so that
        # _sample_rois never has to sub-sample randomly (with
        # fg_fraction=1.0 every candidate is kept).
        proposal_nums = self.batch_size_per_im - gt_nums
        images_shape = []
        self.im_scales = []
        for i in range(image_nums):
            images_shape.append(np.random.randint(200, size=2))
            self.im_scales.append(np.ones((1)).astype(np.float32))

        self.rpn_rois, self.rpn_rois_lod = _generate_proposals(images_shape,
                                                               proposal_nums)
        ground_truth, self.gts_lod = _generate_groundtruth(
            images_shape, self.class_nums, gt_nums)
        self.gt_classes = [gt['gt_classes'] for gt in ground_truth]
        self.gt_boxes = [gt['boxes'] for gt in ground_truth]

    def init_test_output(self):
        self.rois, self.labels_int32, self.bbox_targets, \
        self.bbox_inside_weights, self.bbox_outside_weights, \
        self.lod = generate_proposal_labels_in_python(
            self.rpn_rois, self.gt_classes, self.gt_boxes, self.im_scales,
            self.batch_size_per_im, self.fg_fraction,
            self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
            self.bbox_reg_weights, self.class_nums
        )


def _generate_proposals(images_shape, proposal_nums):
    """Generate ``proposal_nums`` random boxes for every image.

    Returns:
        ``(rpn_rois, lod)``: ``rpn_rois`` is a list of per-image float32
        box arrays; ``lod`` is ``[[n_0, n_1, ...]]`` with the box count of
        each image (length-based, like the output LoD built by
        generate_proposal_labels_in_python).
    """
    rpn_rois = []
    rpn_rois_lod = []
    for image_shape in images_shape:
        proposals = _generate_boxes(image_shape, proposal_nums)
        rpn_rois.append(proposals)
        # Per-image length, not a running total: the LoD is length-based.
        rpn_rois_lod.append(len(proposals))
    return rpn_rois, [rpn_rois_lod]


def _generate_groundtruth(images_shape, class_nums, gt_nums):
    """Generate ``gt_nums`` random labeled gt boxes for every image.

    Returns:
        ``(ground_truth, lod)``: ``ground_truth`` is a list of dicts with
        keys ``gt_classes`` (int32, in ``[1, class_nums)``) and ``boxes``;
        ``lod`` is ``[[n_0, n_1, ...]]`` with per-image gt counts.
    """
    ground_truth = []
    gts_lod = []
    for image_shape in images_shape:
        # Avoid background: class 0 is reserved for it
        gt_classes = np.random.randint(
            low=1, high=class_nums, size=gt_nums).astype(np.int32)
        gt_boxes = _generate_boxes(image_shape, gt_nums)
        ground_truth.append(dict(gt_classes=gt_classes, boxes=gt_boxes))
        # Per-image length, not a running total: the LoD is length-based.
        gts_lod.append(len(gt_classes))
    return ground_truth, [gts_lod]


def _generate_boxes(image_size, box_nums):
    """Return ``box_nums`` random float32 ``(x1, y1, x2, y2)`` boxes.

    ``image_size`` is read as ``(width, height)``. Coordinates are clipped
    to ``[0, width - 1]`` and ``[0, height - 1]``.
    """
    width = image_size[0]
    height = image_size[1]
    xywh = np.random.rand(box_nums, 4)
    xy1 = xywh[:, [0, 1]] * image_size
    wh = xywh[:, [2, 3]] * (image_size - xy1)
    xy2 = xy1 + wh
    boxes = np.hstack([xy1, xy2])
    boxes[:, [0, 2]] = np.minimum(width - 1.,
                                  np.maximum(0., boxes[:, [0, 2]]))
    boxes[:, [1, 3]] = np.minimum(height - 1.,
                                  np.maximum(0., boxes[:, [1, 3]]))
    return boxes.astype(np.float32)


if __name__ == '__main__':
    unittest.main()