diff --git a/ppdet/modeling/heads/roi_extractor.py b/ppdet/modeling/heads/roi_extractor.py index 5d2b1528f07003193b03a02bc1320bfb2d304a6d..cab81d5ea74dc2fced1eadaaad2f6da3190dba0c 100644 --- a/ppdet/modeling/heads/roi_extractor.py +++ b/ppdet/modeling/heads/roi_extractor.py @@ -87,7 +87,7 @@ class RoIAlign(object): offset = 2 k_min = self.start_level + offset k_max = self.end_level + offset - rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals( + rois_dist, restore_index, rois_num_dist = paddle.vision.ops.distribute_fpn_proposals( roi, k_min, k_max, diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 567c26d7233e561118e22ca3a8e7d74f7b7cf686..564261c89d31e9c3c785debba4ac03571757a15e 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -17,17 +17,15 @@ import paddle.nn.functional as F import paddle.nn as nn from paddle import ParamAttr from paddle.regularizer import L2Decay -from paddle import _C_ops +from paddle import _C_ops, _legacy_C_ops from paddle import in_dynamic_mode from paddle.common_ops_import import Variable, LayerHelper, check_variable_and_dtype, check_type, check_dtype __all__ = [ 'prior_box', - 'generate_proposals', 'box_coder', 'multiclass_nms', - 'distribute_fpn_proposals', 'matrix_nms', 'batch_norm', 'mish', @@ -115,135 +113,6 @@ def batch_norm(ch, return norm_layer -@paddle.jit.not_to_static -def distribute_fpn_proposals(fpn_rois, - min_level, - max_level, - refer_level, - refer_scale, - pixel_offset=False, - rois_num=None, - name=None): - r""" - - **This op only takes LoDTensor as input.** In Feature Pyramid Networks - (FPN) models, it is needed to distribute all proposals into different FPN - level, with respect to scale of the proposals, the referring scale and the - referring level. Besides, to restore the order of proposals, we return an - array which indicates the original index of rois in current proposals. - To compute FPN level for each roi, the formula is given as follows: - - .. math:: - - roi\_scale &= \sqrt{BBoxArea(fpn\_roi)} - - level = floor(&\log(\\frac{roi\_scale}{refer\_scale}) + refer\_level) - - where BBoxArea is a function to compute the area of each roi. - - Args: - - fpn_rois(Variable): 2-D Tensor with shape [N, 4] and data type is - float32 or float64. The input fpn_rois. - min_level(int32): The lowest level of FPN layer where the proposals come - from. - max_level(int32): The highest level of FPN layer where the proposals - come from. - refer_level(int32): The referring level of FPN layer with specified scale. - refer_scale(int32): The referring scale of FPN layer with specified level. - rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image. - The shape is [B] and data type is int32. B is the number of images. - If it is not None then return a list of 1-D Tensor. Each element - is the output RoIs' number of each image on the corresponding level - and the shape is [B]. None by default. - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - - Returns: - Tuple: - - multi_rois(List) : A list of 2-D LoDTensor with shape [M, 4] - and data type of float32 and float64. The length is - max_level-min_level+1. The proposals in each FPN level. - - restore_ind(Variable): A 2-D Tensor with shape [N, 1], N is - the number of total rois. The data type is int32. It is - used to restore the order of fpn_rois. - - rois_num_per_level(List): A list of 1-D Tensor and each Tensor is - the RoIs' number in each image on the corresponding level. The shape - is [B] and data type of int32. B is the number of images - - - Examples: - .. code-block:: python - - import paddle - from ppdet.modeling import ops - paddle.enable_static() - fpn_rois = paddle.static.data( - name='data', shape=[None, 4], dtype='float32', lod_level=1) - multi_rois, restore_ind = ops.distribute_fpn_proposals( - fpn_rois=fpn_rois, - min_level=2, - max_level=5, - refer_level=4, - refer_scale=224) - """ - num_lvl = max_level - min_level + 1 - - if in_dynamic_mode(): - assert rois_num is not None, "rois_num should not be None in dygraph mode." - attrs = ('min_level', min_level, 'max_level', max_level, 'refer_level', - refer_level, 'refer_scale', refer_scale, 'pixel_offset', - pixel_offset) - multi_rois, restore_ind, rois_num_per_level = _C_ops.distribute_fpn_proposals( - fpn_rois, rois_num, num_lvl, num_lvl, *attrs) - return multi_rois, restore_ind, rois_num_per_level - - else: - check_variable_and_dtype(fpn_rois, 'fpn_rois', ['float32', 'float64'], - 'distribute_fpn_proposals') - helper = LayerHelper('distribute_fpn_proposals', **locals()) - dtype = helper.input_dtype('fpn_rois') - multi_rois = [ - helper.create_variable_for_type_inference(dtype) - for i in range(num_lvl) - ] - - restore_ind = helper.create_variable_for_type_inference(dtype='int32') - - inputs = {'FpnRois': fpn_rois} - outputs = { - 'MultiFpnRois': multi_rois, - 'RestoreIndex': restore_ind, - } - - if rois_num is not None: - inputs['RoisNum'] = rois_num - rois_num_per_level = [ - helper.create_variable_for_type_inference(dtype='int32') - for i in range(num_lvl) - ] - outputs['MultiLevelRoIsNum'] = rois_num_per_level - else: - rois_num_per_level = None - - helper.append_op( - type='distribute_fpn_proposals', - inputs=inputs, - outputs=outputs, - attrs={ - 'min_level': min_level, - 'max_level': max_level, - 'refer_level': refer_level, - 'refer_scale': refer_scale, - 'pixel_offset': pixel_offset - }) - return multi_rois, restore_ind, rois_num_per_level - - @paddle.jit.not_to_static def prior_box(input, image, @@ -353,7 +222,7 @@ def prior_box(input, 'min_max_aspect_ratios_order', min_max_aspect_ratios_order) if cur_max_sizes is not None: attrs += ('max_sizes', cur_max_sizes) - box, var = _C_ops.prior_box(input, image, *attrs) + box, var = _legacy_C_ops.prior_box(input, image, *attrs) return box, var else: attrs = { @@ -496,8 +365,8 @@ def multiclass_nms(bboxes, score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold', nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta, 'normalized', normalized) - output, index, nms_rois_num = _C_ops.multiclass_nms3(bboxes, scores, - rois_num, *attrs) + output, index, nms_rois_num = _legacy_C_ops.multiclass_nms3( + bboxes, scores, rois_num, *attrs) if not return_index: index = None return output, nms_rois_num, index @@ -638,7 +507,7 @@ def matrix_nms(bboxes, nms_top_k, 'gaussian_sigma', gaussian_sigma, 'use_gaussian', use_gaussian, 'keep_top_k', keep_top_k, 'normalized', normalized) - out, index, rois_num = _C_ops.matrix_nms(bboxes, scores, *attrs) + out, index, rois_num = _legacy_C_ops.matrix_nms(bboxes, scores, *attrs) if not return_index: index = None if not return_rois_num: @@ -791,12 +660,12 @@ def box_coder(prior_box, if in_dynamic_mode(): if isinstance(prior_box_var, Variable): - output_box = _C_ops.box_coder( + output_box = _legacy_C_ops.box_coder( prior_box, prior_box_var, target_box, "code_type", code_type, "box_normalized", box_normalized, "axis", axis) elif isinstance(prior_box_var, list): - output_box = _C_ops.box_coder( + output_box = _legacy_C_ops.box_coder( prior_box, None, target_box, "code_type", code_type, "box_normalized", box_normalized, "axis", axis, "variance", prior_box_var) @@ -831,154 +700,6 @@ def box_coder(prior_box, return output_box -@paddle.jit.not_to_static -def generate_proposals(scores, - bbox_deltas, - im_shape, - anchors, - variances, - pre_nms_top_n=6000, - post_nms_top_n=1000, - nms_thresh=0.5, - min_size=0.1, - eta=1.0, - pixel_offset=False, - return_rois_num=False, - name=None): - """ - **Generate proposal Faster-RCNN** - This operation proposes RoIs according to each box with their - probability to be a foreground object and - the box can be calculated by anchors. Bbox_deltais and scores - to be an object are the output of RPN. Final proposals - could be used to train detection net. - For generating proposals, this operation performs following steps: - 1. Transposes and resizes scores and bbox_deltas in size of - (H*W*A, 1) and (H*W*A, 4) - 2. Calculate box locations as proposals candidates. - 3. Clip boxes to image - 4. Remove predicted boxes with small area. - 5. Apply NMS to get final proposals as output. - Args: - scores(Tensor): A 4-D Tensor with shape [N, A, H, W] represents - the probability for each box to be an object. - N is batch size, A is number of anchors, H and W are height and - width of the feature map. The data type must be float32. - bbox_deltas(Tensor): A 4-D Tensor with shape [N, 4*A, H, W] - represents the difference between predicted box location and - anchor location. The data type must be float32. - im_shape(Tensor): A 2-D Tensor with shape [N, 2] represents H, W, the - origin image size or input size. The data type can be float32 or - float64. - anchors(Tensor): A 4-D Tensor represents the anchors with a layout - of [H, W, A, 4]. H and W are height and width of the feature map, - num_anchors is the box count of each position. Each anchor is - in (xmin, ymin, xmax, ymax) format an unnormalized. The data type must be float32. - variances(Tensor): A 4-D Tensor. The expanded variances of anchors with a layout of - [H, W, num_priors, 4]. Each variance is in - (xcenter, ycenter, w, h) format. The data type must be float32. - pre_nms_top_n(float): Number of total bboxes to be kept per - image before NMS. The data type must be float32. `6000` by default. - post_nms_top_n(float): Number of total bboxes to be kept per - image after NMS. The data type must be float32. `1000` by default. - nms_thresh(float): Threshold in NMS. The data type must be float32. `0.5` by default. - min_size(float): Remove predicted boxes with either height or - width < min_size. The data type must be float32. `0.1` by default. - eta(float): Apply in adaptive NMS, if adaptive `threshold > 0.5`, - `adaptive_threshold = adaptive_threshold * eta` in each iteration. - return_rois_num(bool): When setting True, it will return a 1D Tensor with shape [N, ] that includes Rois's - num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents - the first image has 4 Rois, the second image has 5 Rois. It only used in rcnn model. - 'False' by default. - name(str, optional): For detailed information, please refer - to :ref:`api_guide_Name`. Usually name is no need to set and - None by default. - - Returns: - tuple: - A tuple with format ``(rpn_rois, rpn_roi_probs)``. - - **rpn_rois**: The generated RoIs. 2-D Tensor with shape ``[N, 4]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. - - **rpn_roi_probs**: The scores of generated RoIs. 2-D Tensor with shape ``[N, 1]`` while ``N`` is the number of RoIs. The data type is the same as ``scores``. - - Examples: - .. code-block:: python - - import paddle - from ppdet.modeling import ops - paddle.enable_static() - scores = paddle.static.data(name='scores', shape=[None, 4, 5, 5], dtype='float32') - bbox_deltas = paddle.static.data(name='bbox_deltas', shape=[None, 16, 5, 5], dtype='float32') - im_shape = paddle.static.data(name='im_shape', shape=[None, 2], dtype='float32') - anchors = paddle.static.data(name='anchors', shape=[None, 5, 4, 4], dtype='float32') - variances = paddle.static.data(name='variances', shape=[None, 5, 10, 4], dtype='float32') - rois, roi_probs = ops.generate_proposals(scores, bbox_deltas, - im_shape, anchors, variances) - """ - if in_dynamic_mode(): - assert return_rois_num, "return_rois_num should be True in dygraph mode." - attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n, - 'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta, - 'pixel_offset', pixel_offset) - rpn_rois, rpn_roi_probs, rpn_rois_num = _C_ops.generate_proposals_v2( - scores, bbox_deltas, im_shape, anchors, variances, *attrs) - if not return_rois_num: - rpn_rois_num = None - return rpn_rois, rpn_roi_probs, rpn_rois_num - - else: - helper = LayerHelper('generate_proposals_v2', **locals()) - - check_variable_and_dtype(scores, 'scores', ['float32'], - 'generate_proposals_v2') - check_variable_and_dtype(bbox_deltas, 'bbox_deltas', ['float32'], - 'generate_proposals_v2') - check_variable_and_dtype(im_shape, 'im_shape', ['float32', 'float64'], - 'generate_proposals_v2') - check_variable_and_dtype(anchors, 'anchors', ['float32'], - 'generate_proposals_v2') - check_variable_and_dtype(variances, 'variances', ['float32'], - 'generate_proposals_v2') - - rpn_rois = helper.create_variable_for_type_inference( - dtype=bbox_deltas.dtype) - rpn_roi_probs = helper.create_variable_for_type_inference( - dtype=scores.dtype) - outputs = { - 'RpnRois': rpn_rois, - 'RpnRoiProbs': rpn_roi_probs, - } - if return_rois_num: - rpn_rois_num = helper.create_variable_for_type_inference( - dtype='int32') - rpn_rois_num.stop_gradient = True - outputs['RpnRoisNum'] = rpn_rois_num - - helper.append_op( - type="generate_proposals_v2", - inputs={ - 'Scores': scores, - 'BboxDeltas': bbox_deltas, - 'ImShape': im_shape, - 'Anchors': anchors, - 'Variances': variances - }, - attrs={ - 'pre_nms_topN': pre_nms_top_n, - 'post_nms_topN': post_nms_top_n, - 'nms_thresh': nms_thresh, - 'min_size': min_size, - 'eta': eta, - 'pixel_offset': pixel_offset - }, - outputs=outputs) - rpn_rois.stop_gradient = True - rpn_roi_probs.stop_gradient = True - if not return_rois_num: - rpn_rois_num = None - - return rpn_rois, rpn_roi_probs, rpn_rois_num - - def sigmoid_cross_entropy_with_logits(input, label, ignore_index=-100, diff --git a/ppdet/modeling/proposal_generator/proposal_generator.py b/ppdet/modeling/proposal_generator/proposal_generator.py index 1fcb8b1e2de963e9867249bbb6bc2f524e07d1d2..e911909fee98ac9377c4d2b76a61595ce2f4041b 100644 --- a/ppdet/modeling/proposal_generator/proposal_generator.py +++ b/ppdet/modeling/proposal_generator/proposal_generator.py @@ -62,7 +62,7 @@ class ProposalGenerator(object): top_n = self.pre_nms_top_n if self.topk_after_collect else self.post_nms_top_n variances = paddle.ones_like(anchors) - rpn_rois, rpn_rois_prob, rpn_rois_num = ops.generate_proposals( + rpn_rois, rpn_rois_prob, rpn_rois_num = paddle.vision.ops.generate_proposals( scores, bbox_deltas, im_shape, diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py index 4ef9cbc28c0f1dcc4268571c7206d70b306682cd..3614d5b9c232ae9c1ce37b158a502b3349473cb4 100644 --- a/ppdet/modeling/tests/test_ops.py +++ b/ppdet/modeling/tests/test_ops.py @@ -48,70 +48,6 @@ def softmax(x): return exps / np.sum(exps) -class TestDistributeFpnProposals(LayerTest): - def test_distribute_fpn_proposals(self): - rois_np = np.random.rand(10, 4).astype('float32') - rois_num_np = np.array([4, 6]).astype('int32') - with self.static_graph(): - rois = paddle.static.data( - name='rois', shape=[10, 4], dtype='float32') - rois_num = paddle.static.data( - name='rois_num', shape=[None], dtype='int32') - multi_rois, restore_ind, rois_num_per_level = ops.distribute_fpn_proposals( - fpn_rois=rois, - min_level=2, - max_level=5, - refer_level=4, - refer_scale=224, - rois_num=rois_num) - fetch_list = multi_rois + [restore_ind] + rois_num_per_level - output_stat = self.get_static_graph_result( - feed={'rois': rois_np, - 'rois_num': rois_num_np}, - fetch_list=fetch_list, - with_lod=True) - output_stat_np = [] - for output in output_stat: - output_np = np.array(output) - if len(output_np) > 0: - output_stat_np.append(output_np) - - with self.dynamic_graph(): - rois_dy = paddle.to_tensor(rois_np) - rois_num_dy = paddle.to_tensor(rois_num_np) - multi_rois_dy, restore_ind_dy, rois_num_per_level_dy = ops.distribute_fpn_proposals( - fpn_rois=rois_dy, - min_level=2, - max_level=5, - refer_level=4, - refer_scale=224, - rois_num=rois_num_dy) - output_dy = multi_rois_dy + [restore_ind_dy] + rois_num_per_level_dy - output_dy_np = [] - for output in output_dy: - output_np = output.numpy() - if len(output_np) > 0: - output_dy_np.append(output_np) - - for res_stat, res_dy in zip(output_stat_np, output_dy_np): - self.assertTrue(np.array_equal(res_stat, res_dy)) - - def test_distribute_fpn_proposals_error(self): - with self.static_graph(): - fpn_rois = paddle.static.data( - name='data_error', shape=[10, 4], dtype='int32', lod_level=1) - self.assertRaises( - TypeError, - ops.distribute_fpn_proposals, - fpn_rois=fpn_rois, - min_level=2, - max_level=5, - refer_level=4, - refer_scale=224) - - paddle.disable_static() - - class TestROIAlign(LayerTest): def test_roi_align(self): b, c, h, w = 2, 12, 20, 20 @@ -516,69 +452,5 @@ class TestBoxCoder(LayerTest): paddle.disable_static() -class TestGenerateProposals(LayerTest): - def test_generate_proposals(self): - scores_np = np.random.rand(2, 3, 4, 4).astype('float32') - bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32') - im_shape_np = np.array([[8, 8], [6, 6]]).astype('float32') - anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4), - [4, 4, 3, 4]).astype('float32') - variances_np = np.ones((4, 4, 3, 4)).astype('float32') - - with self.static_graph(): - scores = paddle.static.data( - name='scores', shape=[2, 3, 4, 4], dtype='float32') - bbox_deltas = paddle.static.data( - name='bbox_deltas', shape=[2, 12, 4, 4], dtype='float32') - im_shape = paddle.static.data( - name='im_shape', shape=[2, 2], dtype='float32') - anchors = paddle.static.data( - name='anchors', shape=[4, 4, 3, 4], dtype='float32') - variances = paddle.static.data( - name='var', shape=[4, 4, 3, 4], dtype='float32') - rois, roi_probs, rois_num = ops.generate_proposals( - scores, - bbox_deltas, - im_shape, - anchors, - variances, - pre_nms_top_n=10, - post_nms_top_n=5, - return_rois_num=True) - rois_stat, roi_probs_stat, rois_num_stat = self.get_static_graph_result( - feed={ - 'scores': scores_np, - 'bbox_deltas': bbox_deltas_np, - 'im_shape': im_shape_np, - 'anchors': anchors_np, - 'var': variances_np - }, - fetch_list=[rois, roi_probs, rois_num], - with_lod=True) - - with self.dynamic_graph(): - scores_dy = paddle.to_tensor(scores_np) - bbox_deltas_dy = paddle.to_tensor(bbox_deltas_np) - im_shape_dy = paddle.to_tensor(im_shape_np) - anchors_dy = paddle.to_tensor(anchors_np) - variances_dy = paddle.to_tensor(variances_np) - rois, roi_probs, rois_num = ops.generate_proposals( - scores_dy, - bbox_deltas_dy, - im_shape_dy, - anchors_dy, - variances_dy, - pre_nms_top_n=10, - post_nms_top_n=5, - return_rois_num=True) - rois_dy = rois.numpy() - roi_probs_dy = roi_probs.numpy() - rois_num_dy = rois_num.numpy() - - self.assertTrue(np.array_equal(np.array(rois_stat), rois_dy)) - self.assertTrue(np.array_equal(np.array(roi_probs_stat), roi_probs_dy)) - self.assertTrue(np.array_equal(np.array(rois_num_stat), rois_num_dy)) - - if __name__ == '__main__': unittest.main()