diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 29bf80270f8cc6b92698677fe7de32601a13488a..2640ed1815c97b8b5ceca32b5e1c4cc69ea225d5 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -350,7 +350,7 @@ paddle.fluid.layers.detection_map (ArgSpec(args=['detect_res', 'label', 'class_n paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True)), ('document', '1e164a56fe9376e18a56d22563d9f801')) paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a')) paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc')) -paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)), ('document', '9c601df88b251f22e9311c52939948cd')) +paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'e87c1131e98715d3657a96c44db1b910')) paddle.fluid.layers.generate_proposals (ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)), ('document', 'b7d707822b6af2a586bce608040235b1')) paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None), ('document', 'b319b10ddaf17fb4ddf03518685a17ef')) paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '72fca4a39ccf82d5c746ae62d1868a99')) diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index b9b8a5a53ae5b865d882407b4985a657cf85eccb..451e0ca85501bccd2588dd58d0c8efe7142559d9 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -109,17 +109,18 @@ std::vector> SampleFgBgGt( const platform::CPUDeviceContext& context, Tensor* iou, const Tensor& is_crowd, const int batch_size_per_im, const float fg_fraction, const float fg_thresh, const float bg_thresh_hi, - const float bg_thresh_lo, std::minstd_rand engine, const bool use_random) { + const float bg_thresh_lo, std::minstd_rand engine, const bool use_random, + const bool is_cascade_rcnn, const Tensor& rpn_rois) { std::vector fg_inds; std::vector bg_inds; - std::vector gt_inds; + std::vector mapped_gt_inds; int64_t gt_num = is_crowd.numel(); const int* crowd_data = is_crowd.data(); T* proposal_to_gt_overlaps = iou->data(); int64_t row = iou->dims()[0]; int64_t col = iou->dims()[1]; float epsilon = 0.00001; - + const T* rpn_rois_dt = rpn_rois.data(); // Follow the Faster RCNN's implementation for (int64_t i = 0; i < row; ++i) { const T* v = proposal_to_gt_overlaps + i * col; @@ -127,64 +128,82 @@ std::vector> SampleFgBgGt( if ((i < gt_num) && (crowd_data[i])) { max_overlap = -1.0; } - if (max_overlap > fg_thresh) { + if (is_cascade_rcnn && + ((rpn_rois_dt[i * 4 + 2] - rpn_rois_dt[i * 4 + 0] + 1) <= 0 || + (rpn_rois_dt[i * 4 + 3] - rpn_rois_dt[i * 4 + 1] + 1) <= 0)) { + continue; + } + if (max_overlap >= fg_thresh) { + // fg mapped gt label index for (int64_t j = 0; j < col; ++j) { T val = proposal_to_gt_overlaps[i * col + j]; auto diff = std::abs(max_overlap - val); if (diff < epsilon) { fg_inds.emplace_back(i); - gt_inds.emplace_back(j); + mapped_gt_inds.emplace_back(j); break; } } + } else if ((max_overlap >= bg_thresh_lo) && (max_overlap < bg_thresh_hi)) { + bg_inds.emplace_back(i); } else { - if ((max_overlap >= bg_thresh_lo) && (max_overlap < bg_thresh_hi)) { - bg_inds.emplace_back(i); - } + continue; } } - // Reservoir Sampling - std::uniform_real_distribution uniform(0, 1); - int fg_rois_per_im = std::floor(batch_size_per_im * fg_fraction); - int fg_rois_this_image = fg_inds.size(); - int fg_rois_per_this_image = std::min(fg_rois_per_im, fg_rois_this_image); - if (use_random) { - const int64_t fg_size = static_cast(fg_inds.size()); - if (fg_size > fg_rois_per_this_image) { - for (int64_t i = fg_rois_per_this_image; i < fg_size; ++i) { - int rng_ind = std::floor(uniform(engine) * i); - if (rng_ind < fg_rois_per_this_image) { - std::iter_swap(fg_inds.begin() + rng_ind, fg_inds.begin() + i); - std::iter_swap(gt_inds.begin() + rng_ind, gt_inds.begin() + i); + std::vector> res; + if (is_cascade_rcnn) { + res.emplace_back(fg_inds); + res.emplace_back(bg_inds); + res.emplace_back(mapped_gt_inds); + } else { + // Reservoir Sampling + // sampling fg + std::uniform_real_distribution uniform(0, 1); + int fg_rois_per_im = std::floor(batch_size_per_im * fg_fraction); + int fg_rois_this_image = fg_inds.size(); + int fg_rois_per_this_image = std::min(fg_rois_per_im, fg_rois_this_image); + if (use_random) { + const int64_t fg_size = static_cast(fg_inds.size()); + if (fg_size > fg_rois_per_this_image) { + for (int64_t i = fg_rois_per_this_image; i < fg_size; ++i) { + int rng_ind = std::floor(uniform(engine) * i); + if (rng_ind < fg_rois_per_this_image) { + std::iter_swap(fg_inds.begin() + rng_ind, fg_inds.begin() + i); + std::iter_swap(mapped_gt_inds.begin() + rng_ind, + mapped_gt_inds.begin() + i); + } } } } - } - std::vector new_fg_inds(fg_inds.begin(), - fg_inds.begin() + fg_rois_per_this_image); - std::vector new_gt_inds(gt_inds.begin(), - gt_inds.begin() + fg_rois_per_this_image); - - int bg_rois_per_image = batch_size_per_im - fg_rois_per_this_image; - int bg_rois_this_image = bg_inds.size(); - int bg_rois_per_this_image = std::min(bg_rois_per_image, bg_rois_this_image); - if (use_random) { - const int64_t bg_size = static_cast(bg_inds.size()); - if (bg_size > bg_rois_per_this_image) { - for (int64_t i = bg_rois_per_this_image; i < bg_size; ++i) { - int rng_ind = std::floor(uniform(engine) * i); - if (rng_ind < fg_rois_per_this_image) - std::iter_swap(bg_inds.begin() + rng_ind, bg_inds.begin() + i); + std::vector new_fg_inds(fg_inds.begin(), + fg_inds.begin() + fg_rois_per_this_image); + std::vector new_gt_inds( + mapped_gt_inds.begin(), + mapped_gt_inds.begin() + fg_rois_per_this_image); + // sampling bg + int bg_rois_per_image = batch_size_per_im - fg_rois_per_this_image; + int bg_rois_this_image = bg_inds.size(); + int bg_rois_per_this_image = + std::min(bg_rois_per_image, bg_rois_this_image); + if (use_random) { + const int64_t bg_size = static_cast(bg_inds.size()); + if (bg_size > bg_rois_per_this_image) { + for (int64_t i = bg_rois_per_this_image; i < bg_size; ++i) { + int rng_ind = std::floor(uniform(engine) * i); + if (rng_ind < fg_rois_per_this_image) + std::iter_swap(bg_inds.begin() + rng_ind, bg_inds.begin() + i); + } } } + std::vector new_bg_inds(bg_inds.begin(), + bg_inds.begin() + bg_rois_per_this_image); + // + res.emplace_back(new_fg_inds); + res.emplace_back(new_bg_inds); + res.emplace_back(new_gt_inds); } - std::vector new_bg_inds(bg_inds.begin(), - bg_inds.begin() + bg_rois_per_this_image); - std::vector> res; - res.emplace_back(new_fg_inds); - res.emplace_back(new_bg_inds); - res.emplace_back(new_gt_inds); + return res; } @@ -231,35 +250,50 @@ std::vector SampleRoisForOneImage( const Tensor& im_info, const int batch_size_per_im, const float fg_fraction, const float fg_thresh, const float bg_thresh_hi, const float bg_thresh_lo, const std::vector& bbox_reg_weights, const int class_nums, - std::minstd_rand engine, bool use_random) { + std::minstd_rand engine, bool use_random, bool is_cascade_rcnn, + bool is_cls_agnostic) { + // 1.1 map to original image auto im_scale = im_info.data()[2]; - + Tensor rpn_rois_slice; Tensor rpn_rois; - rpn_rois.mutable_data(rpn_rois_in.dims(), context.GetPlace()); - T* rpn_rois_dt = rpn_rois.data(); - const T* rpn_rois_in_dt = rpn_rois_in.data(); - for (int i = 0; i < rpn_rois.numel(); ++i) { - rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale; + + if (is_cascade_rcnn) { + // slice rpn_rois from gt_box_num refer to detectron + rpn_rois_slice = + rpn_rois_in.Slice(gt_boxes.dims()[0], rpn_rois_in.dims()[0]); + rpn_rois.mutable_data(rpn_rois_slice.dims(), context.GetPlace()); + const T* rpn_rois_in_dt = rpn_rois_slice.data(); + T* rpn_rois_dt = rpn_rois.data(); + for (int i = 0; i < rpn_rois.numel(); ++i) { + rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale; + } + } else { + rpn_rois.mutable_data(rpn_rois_in.dims(), context.GetPlace()); + const T* rpn_rois_in_dt = rpn_rois_in.data(); + T* rpn_rois_dt = rpn_rois.data(); + for (int i = 0; i < rpn_rois.numel(); ++i) { + rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale; + } } - Tensor boxes; + // 1.2 compute overlaps int proposals_num = gt_boxes.dims()[0] + rpn_rois.dims()[0]; + Tensor boxes; boxes.mutable_data({proposals_num, kBoxDim}, context.GetPlace()); Concat(context, gt_boxes, rpn_rois, &boxes); - - // Overlaps Tensor proposal_to_gt_overlaps; proposal_to_gt_overlaps.mutable_data({proposals_num, gt_boxes.dims()[0]}, context.GetPlace()); BboxOverlaps(boxes, gt_boxes, &proposal_to_gt_overlaps); // Generate proposal index - std::vector> fg_bg_gt = SampleFgBgGt( - context, &proposal_to_gt_overlaps, is_crowd, batch_size_per_im, - fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, engine, use_random); + std::vector> fg_bg_gt = + SampleFgBgGt(context, &proposal_to_gt_overlaps, is_crowd, + batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, + bg_thresh_lo, engine, use_random, is_cascade_rcnn, boxes); std::vector fg_inds = fg_bg_gt[0]; std::vector bg_inds = fg_bg_gt[1]; - std::vector gt_inds = fg_bg_gt[2]; + std::vector mapped_gt_inds = fg_bg_gt[2]; // mapped_gt_labels // Gather boxes and labels Tensor sampled_boxes, sampled_labels, sampled_gts; @@ -271,7 +305,8 @@ std::vector SampleRoisForOneImage( sampled_labels.mutable_data({boxes_num}, context.GetPlace()); sampled_gts.mutable_data({fg_num, kBoxDim}, context.GetPlace()); GatherBoxesLabels(context, boxes, gt_boxes, gt_classes, fg_inds, bg_inds, - gt_inds, &sampled_boxes, &sampled_labels, &sampled_gts); + mapped_gt_inds, &sampled_boxes, &sampled_labels, + &sampled_gts); // Compute targets Tensor bbox_targets_single; @@ -305,6 +340,9 @@ std::vector SampleRoisForOneImage( for (int64_t i = 0; i < boxes_num; ++i) { int label = sampled_labels_data[i]; if (label > 0) { + if (is_cls_agnostic) { + label = 1; + } int dst_idx = i * width + kBoxDim * label; int src_idx = kBoxDim * i; bbox_targets_data[dst_idx] = bbox_targets_single_data[src_idx]; @@ -356,7 +394,8 @@ class GenerateProposalLabelsKernel : public framework::OpKernel { context.Attr>("bbox_reg_weights"); int class_nums = context.Attr("class_nums"); bool use_random = context.Attr("use_random"); - + bool is_cascade_rcnn = context.Attr("is_cascade_rcnn"); + bool is_cls_agnostic = context.Attr("is_cls_agnostic"); PADDLE_ENFORCE_EQ(rpn_rois->lod().size(), 1UL, "GenerateProposalLabelsOp rpn_rois needs 1 level of LoD"); PADDLE_ENFORCE_EQ( @@ -411,7 +450,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel { dev_ctx, rpn_rois_slice, gt_classes_slice, is_crowd_slice, gt_boxes_slice, im_info_slice, batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums, - engine, use_random); + engine, use_random, is_cascade_rcnn, is_cls_agnostic); Tensor sampled_rois = tensor_output[0]; Tensor sampled_labels_int32 = tensor_output[1]; Tensor sampled_bbox_targets = tensor_output[2]; @@ -513,6 +552,13 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker { "use_random", "Use random sampling to choose foreground and background boxes.") .SetDefault(true); + AddAttr("is_cascade_rcnn", + "cascade rcnn sampling policy changed from stage 2.") + .SetDefault(false); + AddAttr( + "is_cls_agnostic", + "the box regress will only include fg and bg locations if set true ") + .SetDefault(false); AddComment(R"DOC( This operator can be, for given the GenerateProposalOp output bounding boxes and groundtruth, diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 4d187120227a5acb4729c0d82311606b4fe97650..fa85350adcd3f8cc181a8a19a2789042a01031ba 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -1916,9 +1916,13 @@ def generate_proposal_labels(rpn_rois, bg_thresh_lo=0.0, bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], class_nums=None, - use_random=True): + use_random=True, + is_cls_agnostic=False, + is_cascade_rcnn=False): """ + ** Generate Proposal Labels of Faster-RCNN ** + This operator can be, for given the GenerateProposalOp output bounding boxes and groundtruth, to sample foreground boxes and background boxes, and compute loss target. @@ -1949,6 +1953,8 @@ def generate_proposal_labels(rpn_rois, bbox_reg_weights(list|tuple): Box regression weights. class_nums(int): Class number. use_random(bool): Use random sampling to choose foreground and background boxes. + is_cls_agnostic(bool): class agnostic bbox regression will only represent fg and bg boxes. + is_cascade_rcnn(bool): cascade rcnn model will change sampling policy when settting True. Examples: .. code-block:: python @@ -2007,7 +2013,9 @@ def generate_proposal_labels(rpn_rois, 'bg_thresh_lo': bg_thresh_lo, 'bbox_reg_weights': bbox_reg_weights, 'class_nums': class_nums, - 'use_random': use_random + 'use_random': use_random, + 'is_cls_agnostic': is_cls_agnostic, + 'is_cascade_rcnn': is_cascade_rcnn }) rois.stop_gradient = True diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py index 5f6328707fd80ec8f11b96cc65e2dcaf44496d58..406c255970a52d50c14efb685f55c89947958339 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py @@ -22,10 +22,10 @@ import paddle.fluid as fluid from op_test import OpTest -def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes, - im_info, batch_size_per_im, fg_fraction, - fg_thresh, bg_thresh_hi, bg_thresh_lo, - bbox_reg_weights, class_nums): +def generate_proposal_labels_in_python( + rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, batch_size_per_im, + fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, + class_nums, is_cls_agnostic, is_cascade_rcnn): rois = [] labels_int32 = [] bbox_targets = [] @@ -36,13 +36,12 @@ def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info), 'batch size of rpn_rois and ground_truth is not matched' for im_i in range(len(im_info)): - frcn_blobs = _sample_rois( - rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i], - im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh, - bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums) - + frcn_blobs = _sample_rois(rpn_rois[im_i], gt_classes[im_i], + is_crowd[im_i], gt_boxes[im_i], im_info[im_i], + batch_size_per_im, fg_fraction, fg_thresh, + bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, + class_nums, is_cls_agnostic, is_cascade_rcnn) lod.append(frcn_blobs['rois'].shape[0]) - rois.append(frcn_blobs['rois']) labels_int32.append(frcn_blobs['labels_int32']) bbox_targets.append(frcn_blobs['bbox_targets']) @@ -54,7 +53,8 @@ def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes, def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, - bg_thresh_lo, bbox_reg_weights, class_nums): + bg_thresh_lo, bbox_reg_weights, class_nums, is_cls_agnostic, + is_cascade_rcnn): rois_per_image = int(batch_size_per_im) fg_rois_per_im = int(np.round(fg_fraction * rois_per_image)) @@ -62,7 +62,8 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, im_scale = im_info[2] inv_im_scale = 1. / im_scale rpn_rois = rpn_rois * inv_im_scale - + if is_cascade_rcnn: + rpn_rois = rpn_rois[gt_boxes.shape[0]:, :] boxes = np.vstack([gt_boxes, rpn_rois]) gt_overlaps = np.zeros((boxes.shape[0], class_nums)) box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32) @@ -87,26 +88,37 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, max_overlaps = gt_overlaps.max(axis=1) max_classes = gt_overlaps.argmax(axis=1) - # Foreground - fg_inds = np.where(max_overlaps >= fg_thresh)[0] - fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0]) - # Sample foreground if there are too many - # if fg_inds.shape[0] > fg_rois_per_this_image: - # fg_inds = np.random.choice( - # fg_inds, size=fg_rois_per_this_image, replace=False) - fg_inds = fg_inds[:fg_rois_per_this_image] - - # Background - bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >= - bg_thresh_lo))[0] - bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image - bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, - bg_inds.shape[0]) - # Sample background if there are too many - # if bg_inds.shape[0] > bg_rois_per_this_image: - # bg_inds = np.random.choice( - # bg_inds, size=bg_rois_per_this_image, replace=False) - bg_inds = bg_inds[:bg_rois_per_this_image] + # Cascade RCNN Decode Filter + if is_cascade_rcnn: + ws = boxes[:, 2] - boxes[:, 0] + 1 + hs = boxes[:, 3] - boxes[:, 1] + 1 + keep = np.where((ws > 0) & (hs > 0))[0] + boxes = boxes[keep] + fg_inds = np.where(max_overlaps >= fg_thresh)[0] + bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >= + bg_thresh_lo))[0] + fg_rois_per_this_image = fg_inds.shape[0] + bg_rois_per_this_image = bg_inds.shape[0] + else: + # Foreground + fg_inds = np.where(max_overlaps >= fg_thresh)[0] + fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0]) + # Sample foreground if there are too many + if fg_inds.shape[0] > fg_rois_per_this_image: + fg_inds = np.random.choice( + fg_inds, size=fg_rois_per_this_image, replace=False) + fg_inds = fg_inds[:fg_rois_per_this_image] + # Background + bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >= + bg_thresh_lo))[0] + bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image + bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, + bg_inds.shape[0]) + # Sample background if there are too many + if bg_inds.shape[0] > bg_rois_per_this_image: + bg_inds = np.random.choice( + bg_inds, size=bg_rois_per_this_image, replace=False) + bg_inds = bg_inds[:bg_rois_per_this_image] keep_inds = np.append(fg_inds, bg_inds) sampled_labels = max_classes[keep_inds] @@ -114,14 +126,12 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, sampled_boxes = boxes[keep_inds] sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]] sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0] - bbox_label_targets = _compute_targets(sampled_boxes, sampled_gts, sampled_labels, bbox_reg_weights) - bbox_targets, bbox_inside_weights = _expand_bbox_targets(bbox_label_targets, - class_nums) + bbox_targets, bbox_inside_weights = _expand_bbox_targets( + bbox_label_targets, class_nums, is_cls_agnostic) bbox_outside_weights = np.array( bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype) - # Scale rois sampled_rois = sampled_boxes * im_scale @@ -192,19 +202,22 @@ def _box_to_delta(ex_boxes, gt_boxes, weights): return targets -def _expand_bbox_targets(bbox_targets_input, class_nums): +def _expand_bbox_targets(bbox_targets_input, class_nums, is_cls_agnostic): class_labels = bbox_targets_input[:, 0] fg_inds = np.where(class_labels > 0)[0] - - bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums)) + #if is_cls_agnostic: + # class_labels = [1 if ll > 0 else 0 for ll in class_labels] + # class_labels = np.array(class_labels, dtype=np.int32) + # class_nums = 2 + bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums + if not is_cls_agnostic else 4 * 2)) bbox_inside_weights = np.zeros(bbox_targets.shape) for ind in fg_inds: - class_label = int(class_labels[ind]) + class_label = int(class_labels[ind]) if not is_cls_agnostic else 1 start_ind = class_label * 4 end_ind = class_label * 4 + 4 bbox_targets[ind, start_ind:end_ind] = bbox_targets_input[ind, 1:] bbox_inside_weights[ind, start_ind:end_ind] = (1.0, 1.0, 1.0, 1.0) - return bbox_targets, bbox_inside_weights @@ -228,7 +241,9 @@ class TestGenerateProposalLabelsOp(OpTest): 'bg_thresh_lo': self.bg_thresh_lo, 'bbox_reg_weights': self.bbox_reg_weights, 'class_nums': self.class_nums, - 'use_random': False + 'use_random': False, + 'is_cls_agnostic': self.is_cls_agnostic, + 'is_cascade_rcnn': self.is_cascade_rcnn } self.outputs = { 'Rois': (self.rois, [self.lod]), @@ -252,12 +267,15 @@ class TestGenerateProposalLabelsOp(OpTest): self.bg_thresh_hi = 0.5 self.bg_thresh_lo = 0.0 self.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2] - self.class_nums = 81 + #self.class_nums = 81 + self.is_cls_agnostic = False #True + self.is_cascade_rcnn = True + self.class_nums = 2 if self.is_cls_agnostic else 81 def init_test_input(self): np.random.seed(0) gt_nums = 6 # Keep same with batch_size_per_im for unittest - proposal_nums = 2000 #self.batch_size_per_im - gt_nums + proposal_nums = 2000 if not self.is_cascade_rcnn else 512 #self.batch_size_per_im - gt_nums images_shape = [[64, 64]] self.im_info = np.ones((len(images_shape), 3)).astype(np.float32) for i in range(len(images_shape)): @@ -280,7 +298,8 @@ class TestGenerateProposalLabelsOp(OpTest): self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info, self.batch_size_per_im, self.fg_fraction, self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo, - self.bbox_reg_weights, self.class_nums + self.bbox_reg_weights, self.class_nums, + self.is_cls_agnostic, self.is_cascade_rcnn ) self.rois = np.vstack(self.rois) self.labels_int32 = np.hstack(self.labels_int32)