From d3e9e73a8ad636c2743fb7edf66468c85a6e36c1 Mon Sep 17 00:00:00 2001 From: zqw_1997 <118182234+zhengqiwen1997@users.noreply.github.com> Date: Wed, 7 Dec 2022 21:03:05 +0800 Subject: [PATCH] [fluid remove]: remove paddle.fluid.layers.yolo_box and paddle.fluid.layers.yolov3_loss (#48722) * remove paddle.fluid.layers.nn.temporal_shift * code check * rm unittest * remove fluid.yolo_box * remove fluid.yolov3_loss * change the comments of yolov3_loss to yolo_loss --- python/paddle/fluid/layers/detection.py | 258 ------------------ python/paddle/fluid/tests/test_detection.py | 69 ----- .../unittests/dygraph_to_static/yolov3.py | 4 +- .../unittests/ipu/test_yolo_box_op_ipu.py | 2 +- .../ir/inference/test_trt_yolo_box_op.py | 14 +- .../tests/unittests/test_device_guard.py | 4 +- .../unittests/xpu/test_device_guard_xpu.py | 4 +- 7 files changed, 14 insertions(+), 341 deletions(-) diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 3d277705aa9..27491919782 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -52,8 +52,6 @@ __all__ = [ 'iou_similarity', 'box_coder', 'polygon_box_transform', - 'yolov3_loss', - 'yolo_box', 'box_clip', 'multiclass_nms', 'locality_aware_nms', @@ -435,262 +433,6 @@ def polygon_box_transform(input, name=None): return output -@deprecated(since="2.0.0", update_to="paddle.vision.ops.yolo_loss") -@templatedoc(op_type="yolov3_loss") -def yolov3_loss( - x, - gt_box, - gt_label, - anchors, - anchor_mask, - class_num, - ignore_thresh, - downsample_ratio, - gt_score=None, - use_label_smooth=True, - name=None, - scale_x_y=1.0, -): - """ - - ${comment} - - Args: - x (Variable): ${x_comment}The data type is float32 or float64. - gt_box (Variable): groud truth boxes, should be in shape of [N, B, 4], - in the third dimension, x, y, w, h should be stored. - x,y is the center coordinate of boxes, w, h are the - width and height, x, y, w, h should be divided by - input image height to scale to [0, 1]. - N is the batch number and B is the max box number in - an image.The data type is float32 or float64. - gt_label (Variable): class id of ground truth boxes, should be in shape - of [N, B].The data type is int32. - anchors (list|tuple): ${anchors_comment} - anchor_mask (list|tuple): ${anchor_mask_comment} - class_num (int): ${class_num_comment} - ignore_thresh (float): ${ignore_thresh_comment} - downsample_ratio (int): ${downsample_ratio_comment} - name (string): The default value is None. Normally there is no need - for user to set this property. For more information, - please refer to :ref:`api_guide_Name` - gt_score (Variable): mixup score of ground truth boxes, should be in shape - of [N, B]. Default None. - use_label_smooth (bool): ${use_label_smooth_comment} - scale_x_y (float): ${scale_x_y_comment} - - Returns: - Variable: A 1-D tensor with shape [N], the value of yolov3 loss - - Raises: - TypeError: Input x of yolov3_loss must be Variable - TypeError: Input gtbox of yolov3_loss must be Variable - TypeError: Input gtlabel of yolov3_loss must be Variable - TypeError: Input gtscore of yolov3_loss must be None or Variable - TypeError: Attr anchors of yolov3_loss must be list or tuple - TypeError: Attr class_num of yolov3_loss must be an integer - TypeError: Attr ignore_thresh of yolov3_loss must be a float number - TypeError: Attr use_label_smooth of yolov3_loss must be a bool value - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - x = fluid.data(name='x', shape=[None, 255, 13, 13], dtype='float32') - gt_box = fluid.data(name='gt_box', shape=[None, 6, 4], dtype='float32') - gt_label = fluid.data(name='gt_label', shape=[None, 6], dtype='int32') - gt_score = fluid.data(name='gt_score', shape=[None, 6], dtype='float32') - anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] - anchor_mask = [0, 1, 2] - loss = fluid.layers.yolov3_loss(x=x, gt_box=gt_box, gt_label=gt_label, - gt_score=gt_score, anchors=anchors, - anchor_mask=anchor_mask, class_num=80, - ignore_thresh=0.7, downsample_ratio=32) - """ - - if not isinstance(x, Variable): - raise TypeError("Input x of yolov3_loss must be Variable") - if not isinstance(gt_box, Variable): - raise TypeError("Input gtbox of yolov3_loss must be Variable") - if not isinstance(gt_label, Variable): - raise TypeError("Input gtlabel of yolov3_loss must be Variable") - if gt_score is not None and not isinstance(gt_score, Variable): - raise TypeError("Input gtscore of yolov3_loss must be Variable") - if not isinstance(anchors, list) and not isinstance(anchors, tuple): - raise TypeError("Attr anchors of yolov3_loss must be list or tuple") - if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): - raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") - if not isinstance(class_num, int): - raise TypeError("Attr class_num of yolov3_loss must be an integer") - if not isinstance(ignore_thresh, float): - raise TypeError( - "Attr ignore_thresh of yolov3_loss must be a float number" - ) - if not isinstance(use_label_smooth, bool): - raise TypeError( - "Attr use_label_smooth of yolov3_loss must be a bool value" - ) - - if _non_static_mode(): - attrs = ( - "anchors", - anchors, - "anchor_mask", - anchor_mask, - "class_num", - class_num, - "ignore_thresh", - ignore_thresh, - "downsample_ratio", - downsample_ratio, - "use_label_smooth", - use_label_smooth, - "scale_x_y", - scale_x_y, - ) - loss, _, _ = _legacy_C_ops.yolov3_loss( - x, gt_box, gt_label, gt_score, *attrs - ) - return loss - - helper = LayerHelper('yolov3_loss', **locals()) - loss = helper.create_variable_for_type_inference(dtype=x.dtype) - objectness_mask = helper.create_variable_for_type_inference(dtype='int32') - gt_match_mask = helper.create_variable_for_type_inference(dtype='int32') - - inputs = { - "X": x, - "GTBox": gt_box, - "GTLabel": gt_label, - } - if gt_score is not None: - inputs["GTScore"] = gt_score - - attrs = { - "anchors": anchors, - "anchor_mask": anchor_mask, - "class_num": class_num, - "ignore_thresh": ignore_thresh, - "downsample_ratio": downsample_ratio, - "use_label_smooth": use_label_smooth, - "scale_x_y": scale_x_y, - } - - helper.append_op( - type='yolov3_loss', - inputs=inputs, - outputs={ - 'Loss': loss, - 'ObjectnessMask': objectness_mask, - 'GTMatchMask': gt_match_mask, - }, - attrs=attrs, - ) - return loss - - -@deprecated(since="2.0.0", update_to="paddle.vision.ops.yolo_box") -@templatedoc(op_type="yolo_box") -def yolo_box( - x, - img_size, - anchors, - class_num, - conf_thresh, - downsample_ratio, - clip_bbox=True, - name=None, - scale_x_y=1.0, - iou_aware=False, - iou_aware_factor=0.5, -): - """ - - ${comment} - - Args: - x (Variable): ${x_comment} The data type is float32 or float64. - img_size (Variable): ${img_size_comment} The data type is int32. - anchors (list|tuple): ${anchors_comment} - class_num (int): ${class_num_comment} - conf_thresh (float): ${conf_thresh_comment} - downsample_ratio (int): ${downsample_ratio_comment} - clip_bbox (bool): ${clip_bbox_comment} - scale_x_y (float): ${scale_x_y_comment} - name (string): The default value is None. Normally there is no need - for user to set this property. For more information, - please refer to :ref:`api_guide_Name` - iou_aware (bool): ${iou_aware_comment} - iou_aware_factor (float): ${iou_aware_factor_comment} - - Returns: - Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, - and a 3-D tensor with shape [N, M, :attr:`class_num`], the classification - scores of boxes. - - Raises: - TypeError: Input x of yolov_box must be Variable - TypeError: Attr anchors of yolo box must be list or tuple - TypeError: Attr class_num of yolo box must be an integer - TypeError: Attr conf_thresh of yolo box must be a float number - - Examples: - - .. code-block:: python - - import paddle.fluid as fluid - import paddle - paddle.enable_static() - x = fluid.data(name='x', shape=[None, 255, 13, 13], dtype='float32') - img_size = fluid.data(name='img_size',shape=[None, 2],dtype='int64') - anchors = [10, 13, 16, 30, 33, 23] - boxes,scores = fluid.layers.yolo_box(x=x, img_size=img_size, class_num=80, anchors=anchors, - conf_thresh=0.01, downsample_ratio=32) - """ - helper = LayerHelper('yolo_box', **locals()) - - if not isinstance(x, Variable): - raise TypeError("Input x of yolo_box must be Variable") - if not isinstance(img_size, Variable): - raise TypeError("Input img_size of yolo_box must be Variable") - if not isinstance(anchors, list) and not isinstance(anchors, tuple): - raise TypeError("Attr anchors of yolo_box must be list or tuple") - if not isinstance(class_num, int): - raise TypeError("Attr class_num of yolo_box must be an integer") - if not isinstance(conf_thresh, float): - raise TypeError("Attr ignore_thresh of yolo_box must be a float number") - - boxes = helper.create_variable_for_type_inference(dtype=x.dtype) - scores = helper.create_variable_for_type_inference(dtype=x.dtype) - - attrs = { - "anchors": anchors, - "class_num": class_num, - "conf_thresh": conf_thresh, - "downsample_ratio": downsample_ratio, - "clip_bbox": clip_bbox, - "scale_x_y": scale_x_y, - "iou_aware": iou_aware, - "iou_aware_factor": iou_aware_factor, - } - - helper.append_op( - type='yolo_box', - inputs={ - "X": x, - "ImgSize": img_size, - }, - outputs={ - 'Boxes': boxes, - 'Scores': scores, - }, - attrs=attrs, - ) - return boxes, scores - - @templatedoc() def detection_map( detect_res, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index cf2523947f0..a2745bbca8e 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -528,75 +528,6 @@ class TestGenerateProposals(LayerTest): np.testing.assert_array_equal(np.array(rois_num_stat), rois_num_dy) -class TestYoloDetection(unittest.TestCase): - def test_yolov3_loss(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') - gt_box = layers.data(name='gt_box', shape=[10, 4], dtype='float32') - gt_label = layers.data(name='gt_label', shape=[10], dtype='int32') - gt_score = layers.data(name='gt_score', shape=[10], dtype='float32') - loss = layers.yolov3_loss( - x, - gt_box, - gt_label, - [10, 13, 30, 13], - [0, 1], - 10, - 0.7, - 32, - gt_score=gt_score, - use_label_smooth=False, - ) - - self.assertIsNotNone(loss) - - def test_yolo_box(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') - img_size = layers.data(name='img_size', shape=[2], dtype='int32') - boxes, scores = layers.yolo_box( - x, img_size, [10, 13, 30, 13], 10, 0.01, 32 - ) - self.assertIsNotNone(boxes) - self.assertIsNotNone(scores) - - def test_yolov3_loss_with_scale(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') - gt_box = layers.data(name='gt_box', shape=[10, 4], dtype='float32') - gt_label = layers.data(name='gt_label', shape=[10], dtype='int32') - gt_score = layers.data(name='gt_score', shape=[10], dtype='float32') - loss = layers.yolov3_loss( - x, - gt_box, - gt_label, - [10, 13, 30, 13], - [0, 1], - 10, - 0.7, - 32, - gt_score=gt_score, - use_label_smooth=False, - scale_x_y=1.2, - ) - - self.assertIsNotNone(loss) - - def test_yolo_box_with_scale(self): - program = Program() - with program_guard(program): - x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') - img_size = layers.data(name='img_size', shape=[2], dtype='int32') - boxes, scores = layers.yolo_box( - x, img_size, [10, 13, 30, 13], 10, 0.01, 32, scale_x_y=1.2 - ) - self.assertIsNotNone(boxes) - self.assertIsNotNone(scores) - - class TestBoxClip(unittest.TestCase): def test_box_clip(self): program = Program() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/yolov3.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/yolov3.py index 1c1877681c4..2fe1f652cce 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/yolov3.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/yolov3.py @@ -314,7 +314,7 @@ class YOLOv3(fluid.dygraph.Layer): for i, out in enumerate(self.outputs): anchor_mask = cfg.anchor_masks[i] if self.is_train: - loss = fluid.layers.yolov3_loss( + loss = paddle.vision.ops.yolo_loss( x=out, gt_box=self.gtbox, gt_label=self.gtlabel, @@ -333,7 +333,7 @@ class YOLOv3(fluid.dygraph.Layer): for m in anchor_mask: mask_anchors.append(cfg.anchors[2 * m]) mask_anchors.append(cfg.anchors[2 * m + 1]) - boxes, scores = fluid.layers.yolo_box( + boxes, scores = paddle.vision.ops.yolo_box( x=out, img_size=self.im_shape, anchors=mask_anchors, diff --git a/python/paddle/fluid/tests/unittests/ipu/test_yolo_box_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_yolo_box_op_ipu.py index 1248eb10921..40c56af9228 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_yolo_box_op_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_yolo_box_op_ipu.py @@ -65,7 +65,7 @@ class TestBase(IPUOpTest): 'value': 6, } img_size = paddle.fluid.layers.fill_constant(**attrs) - out = paddle.fluid.layers.yolo_box(x=x, img_size=img_size, **self.attrs) + out = paddle.vision.ops.yolo_box(x=x, img_size=img_size, **self.attrs) self.fetch_list = [x.name for x in out] def run_model(self, exec_mode): diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py index 42a65f7f79f..a578c5216f3 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py @@ -17,8 +17,8 @@ import unittest import numpy as np from inference_pass_test import InferencePassTest +import paddle import paddle.fluid as fluid -import paddle.fluid.core as core from paddle.fluid.core import AnalysisConfig, PassVersionChecker @@ -56,7 +56,7 @@ class TRTYoloBoxTest(InferencePassTest): self.downsample_ratio = 32 def append_yolobox(self, image, image_size): - return fluid.layers.yolo_box( + return paddle.vision.ops.yolo_box( x=image, img_size=image_size, class_num=self.class_num, @@ -66,7 +66,7 @@ class TRTYoloBoxTest(InferencePassTest): ) def test_check_output(self): - if core.is_compiled_with_cuda(): + if paddle.is_compiled_with_cuda(): use_gpu = True self.check_output_with_option(use_gpu, flatten=True) self.assertTrue( @@ -106,7 +106,7 @@ class TRTYoloBoxFP16Test(InferencePassTest): self.downsample_ratio = 32 def append_yolobox(self, image, image_size): - return fluid.layers.yolo_box( + return paddle.vision.ops.yolo_box( x=image, img_size=image_size, class_num=self.class_num, @@ -116,7 +116,7 @@ class TRTYoloBoxFP16Test(InferencePassTest): ) def test_check_output(self): - if core.is_compiled_with_cuda(): + if paddle.is_compiled_with_cuda(): use_gpu = True self.check_output_with_option(use_gpu, flatten=True, rtol=1e-1) self.assertTrue( @@ -160,7 +160,7 @@ class TRTYoloBoxIoUAwareTest(InferencePassTest): self.iou_aware_factor = 0.5 def append_yolobox(self, image, image_size): - return fluid.layers.yolo_box( + return paddle.vision.ops.yolo_box( x=image, img_size=image_size, class_num=self.class_num, @@ -172,7 +172,7 @@ class TRTYoloBoxIoUAwareTest(InferencePassTest): ) def test_check_output(self): - if core.is_compiled_with_cuda(): + if paddle.is_compiled_with_cuda(): use_gpu = True self.check_output_with_option(use_gpu, flatten=True) self.assertTrue( diff --git a/python/paddle/fluid/tests/unittests/test_device_guard.py b/python/paddle/fluid/tests/unittests/test_device_guard.py index d62893de97c..eff076c6a78 100644 --- a/python/paddle/fluid/tests/unittests/test_device_guard.py +++ b/python/paddle/fluid/tests/unittests/test_device_guard.py @@ -127,8 +127,8 @@ class TestDeviceGuard(unittest.TestCase): ] anchor_mask = [0, 1, 2] with paddle.static.device_guard("gpu"): - # yolov3_loss only has cpu kernel, so its cpu kernel will be executed - loss = fluid.layers.yolov3_loss( + # yolo_loss only has cpu kernel, so its cpu kernel will be executed + loss = paddle.vision.ops.yolo_loss( x=x, gt_box=gt_box, gt_label=gt_label, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py index 3e126318df2..6de4b3f07b2 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py @@ -133,8 +133,8 @@ class TestDeviceGuard(unittest.TestCase): ] anchor_mask = [0, 1, 2] with paddle.static.device_guard("xpu"): - # yolov3_loss only has cpu kernel, so its cpu kernel will be executed - loss = fluid.layers.yolov3_loss( + # yolo_loss has cpu kernel, so its cpu kernel will be executed + loss = paddle.vision.ops.yolo_loss( x=x, gt_box=gt_box, gt_label=gt_label, -- GitLab