diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index d965e1ace5fc3182f79e5e92906f0ee448bce24d..35c41e0dc93da0a367b8e98b4e4a4882bcea0822 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -47,6 +47,7 @@ elseif(WITH_MLU)
   detection_library(iou_similarity_op SRCS iou_similarity_op.cc
                     iou_similarity_op_mlu.cc)
   detection_library(prior_box_op SRCS prior_box_op.cc)
+  detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op_mlu.cc)
 elseif(WITH_ASCEND_CL)
   detection_library(iou_similarity_op SRCS iou_similarity_op.cc
                     iou_similarity_op_npu.cc)
@@ -55,6 +56,7 @@ else()
   detection_library(iou_similarity_op SRCS iou_similarity_op.cc
                     iou_similarity_op.cu)
   detection_library(prior_box_op SRCS prior_box_op.cc)
+  detection_library(yolo_box_op SRCS yolo_box_op.cc)
   # detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
 endif()
 
@@ -73,7 +75,6 @@ detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc)
 detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc)
 detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
 detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
-detection_library(yolo_box_op SRCS yolo_box_op.cc)
 detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc
                   box_decoder_and_assign_op.cu)
 detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc
diff --git a/paddle/fluid/operators/detection/yolo_box_op_mlu.cc b/paddle/fluid/operators/detection/yolo_box_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..739c05805d68a2efb9edd9c284018383ebc4dcea
--- /dev/null
+++ b/paddle/fluid/operators/detection/yolo_box_op_mlu.cc
@@ -0,0 +1,137 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
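+
+// MLU kernel for the yolo_box operator. mluOpYoloBox emits boxes with shape
+// [N, anchor_num, 4, H*W] and scores with shape [N, anchor_num, class_num,
+// H*W]; this kernel transposes both results into the [N, anchor_num, H*W, 4]
+// and [N, anchor_num, H*W, class_num] layout expected by Paddle's yolo_box
+// outputs.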
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+template <typename T>
+class YoloBoxMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<phi::DenseTensor>("X");
+    auto* img_size = ctx.Input<phi::DenseTensor>("ImgSize");
+    auto* boxes = ctx.Output<phi::DenseTensor>("Boxes");
+    auto* scores = ctx.Output<phi::DenseTensor>("Scores");
+    const std::vector<int> anchors = ctx.Attr<std::vector<int>>("anchors");
+    auto class_num = ctx.Attr<int>("class_num");
+    auto conf_thresh = ctx.Attr<float>("conf_thresh");
+    auto downsample_ratio = ctx.Attr<int>("downsample_ratio");
+    auto clip_bbox = ctx.Attr<bool>("clip_bbox");
+    auto scale = ctx.Attr<float>("scale_x_y");
+    auto iou_aware = ctx.Attr<bool>("iou_aware");
+    auto iou_aware_factor = ctx.Attr<float>("iou_aware_factor");
+
+    int anchor_num = anchors.size() / 2;
+    int64_t size = anchors.size();
+    auto dim_x = x->dims();
+    int n = dim_x[0];
+    int s = anchor_num;
+    int h = dim_x[2];
+    int w = dim_x[3];
+
+    // The output of mluOpYoloBox: A 4-D tensor with shape [N, anchor_num, 4,
+    // H*W], the coordinates of boxes, and a 4-D tensor with shape [N,
+    // anchor_num, :attr:`class_num`, H*W], the classification scores of boxes.
+    std::vector<int> boxes_dim_mluops({n, s, 4, h * w});
+    std::vector<int> scores_dim_mluops({n, s, class_num, h * w});
+
+    // In Paddle framework: A 3-D tensor with shape [N, M, 4], the coordinates
+    // of boxes, and a 3-D tensor with shape [N, M, :attr:`class_num`], the
+    // classification scores of boxes.
+    std::vector<int> boxes_out_dim({n, s, h * w, 4});
+    std::vector<int> scores_out_dim({n, s, h * w, class_num});
+
+    auto& dev_ctx = ctx.template device_context<MLUDeviceContext>();
+    phi::DenseTensor boxes_tensor_mluops =
+        ctx.AllocateTmpTensor<T, MLUDeviceContext>({n, s, 4, h * w}, dev_ctx);
+    phi::DenseTensor scores_tensor_mluops =
+        ctx.AllocateTmpTensor<T, MLUDeviceContext>({n, s, class_num, h * w},
+                                                   dev_ctx);
+    MLUOpTensorDesc boxes_trans_desc_mluops(
+        4, boxes_dim_mluops.data(), ToMluOpDataType<T>());
+    MLUCnnlTensorDesc boxes_trans_desc_cnnl(
+        4, boxes_dim_mluops.data(), ToCnnlDataType<T>());
+    MLUOpTensorDesc scores_trans_desc_mluops(
+        4, scores_dim_mluops.data(), ToMluOpDataType<T>());
+    MLUCnnlTensorDesc scores_trans_desc_cnnl(
+        4, scores_dim_mluops.data(), ToCnnlDataType<T>());
+
+    boxes->mutable_data<T>(ctx.GetPlace());
+    scores->mutable_data<T>(ctx.GetPlace());
+    FillMLUTensorWithHostValue(ctx, static_cast<T>(0), boxes);
+    FillMLUTensorWithHostValue(ctx, static_cast<T>(0), scores);
+
+    MLUOpTensorDesc x_desc(*x, MLUOP_LAYOUT_ARRAY, ToMluOpDataType<T>());
+    MLUOpTensorDesc img_size_desc(
+        *img_size, MLUOP_LAYOUT_ARRAY, ToMluOpDataType<int32_t>());
+    Tensor anchors_temp(framework::TransToPhiDataType(VT::INT32));
+    anchors_temp.Resize({size});
+    paddle::framework::TensorFromVector(
+        anchors, ctx.device_context(), &anchors_temp);
+    MLUOpTensorDesc anchors_desc(anchors_temp);
+    MLUCnnlTensorDesc boxes_desc_cnnl(
+        4, boxes_out_dim.data(), ToCnnlDataType<T>());
+    MLUCnnlTensorDesc scores_desc_cnnl(
+        4, scores_out_dim.data(), ToCnnlDataType<T>());
+
+    MLUOP::OpYoloBox(ctx,
+                     x_desc.get(),
+                     GetBasePtr(x),
+                     img_size_desc.get(),
+                     GetBasePtr(img_size),
+                     anchors_desc.get(),
+                     GetBasePtr(&anchors_temp),
+                     class_num,
+                     conf_thresh,
+                     downsample_ratio,
+                     clip_bbox,
+                     scale,
+                     iou_aware,
+                     iou_aware_factor,
+                     boxes_trans_desc_mluops.get(),
+                     GetBasePtr(&boxes_tensor_mluops),
+                     scores_trans_desc_mluops.get(),
+                     GetBasePtr(&scores_tensor_mluops));
+    const std::vector<int> perm = {0, 1, 3, 2};
+
+    // transpose the boxes from [N, S, 4, H*W] to [N, S, H*W, 4]
+    MLUCnnl::Transpose(ctx,
+                       perm,
+                       4,
+                       boxes_trans_desc_cnnl.get(),
+                       GetBasePtr(&boxes_tensor_mluops),
+                       boxes_desc_cnnl.get(),
+                       GetBasePtr(boxes));
+
+    // transpose the scores from [N, S, class_num, H*W] to [N, S, H*W,
+    // class_num]
+    MLUCnnl::Transpose(ctx,
+                       perm,
+                       4,
+                       scores_trans_desc_cnnl.get(),
+                       GetBasePtr(&scores_tensor_mluops),
+                       scores_desc_cnnl.get(),
+                       GetBasePtr(scores));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(yolo_box, ops::YoloBoxMLUKernel<float>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index c2f55811ee57bcc9e39c8879d62b88ad6b7c2363..03e5a49e28ac3d12c912a118a082524110f67e1d 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -5418,5 +5418,45 @@ MLURNNDesc::~MLURNNDesc() {
                                           diff_x));
 }
 
+/* static */ void MLUOP::OpYoloBox(const ExecutionContext& ctx,
+                                   const mluOpTensorDescriptor_t x_desc,
+                                   const void* x,
+                                   const mluOpTensorDescriptor_t img_size_desc,
+                                   const void* img_size,
+                                   const mluOpTensorDescriptor_t anchors_desc,
+                                   const void* anchors,
+                                   const int class_num,
+                                   const float conf_thresh,
+                                   const int downsample_ratio,
+                                   const bool clip_bbox,
+                                   const float scale,
+                                   const bool iou_aware,
+                                   const float iou_aware_factor,
+                                   const mluOpTensorDescriptor_t boxes_desc,
+                                   void* boxes,
+                                   const mluOpTensorDescriptor_t scores_desc,
+                                   void* scores) {
+  mluOpHandle_t handle = GetMLUOpHandleFromCTX(ctx);
+
+  PADDLE_ENFORCE_MLU_SUCCESS(mluOpYoloBox(handle,
+                                          x_desc,
+                                          x,
+                                          img_size_desc,
+                                          img_size,
+                                          anchors_desc,
+                                          anchors,
+                                          class_num,
+                                          conf_thresh,
+                                          downsample_ratio,
+                                          clip_bbox,
+                                          scale,
+                                          iou_aware,
+                                          iou_aware_factor,
+                                          boxes_desc,
+                                          boxes,
+                                          scores_desc,
+                                          scores));
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 3d6cc2255ec5e5abe6599b697a6704c06d1dc8a2..354c5fe3f3d8e5871722d1aa7456f8e40f6fc08f 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -2292,6 +2292,27 @@ class MLUCnnl {
                              void* diff_x);
 };
 
+class MLUOP {
+ public:
+  static void OpYoloBox(const ExecutionContext& ctx,
+                        const mluOpTensorDescriptor_t x_desc,
+                        const void* x,
+                        const mluOpTensorDescriptor_t img_size_desc,
+                        const void* img_size,
+                        const mluOpTensorDescriptor_t anchors_desc,
+                        const void* anchors,
+                        const int class_num,
+                        const float conf_thresh,
+                        const int downsample_ratio,
+                        const bool clip_bbox,
+                        const float scale,
+                        const bool iou_aware,
+                        const float iou_aware_factor,
+                        const mluOpTensorDescriptor_t boxes_desc,
+                        void* boxes,
+                        const mluOpTensorDescriptor_t scores_desc,
+                        void* scores);
+};
 const std::map<const std::string, std::pair<std::vector<int>, std::vector<int>>>
     TransPermMap = {
         // trans_mode, (forward_perm, backward_perm)
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_yolo_box_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_yolo_box_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..443ad4b22365c582e61e7dface516f7f0f2e96d9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_yolo_box_op_mlu.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+import sys
+
+sys.path.append("..")
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle
+from paddle.fluid import core
+import paddle.fluid as fluid
+from paddle.fluid.op import Operator
+from paddle.fluid.executor import Executor
+from paddle.fluid.framework import _test_eager_guard
+
+paddle.enable_static()
+
+
+def sigmoid(x):
+    return (1.0 / (1.0 + np.exp(((-1.0) * x))))
+
+
+def YoloBox(x, img_size, attrs):
+    (n, c, h, w) = x.shape
+    anchors = attrs['anchors']
+    an_num = int((len(anchors) // 2))
+    class_num = attrs['class_num']
+    conf_thresh = attrs['conf_thresh']
+    downsample = attrs['downsample_ratio']
+    clip_bbox = attrs['clip_bbox']
+    scale_x_y = attrs['scale_x_y']
+    iou_aware = attrs['iou_aware']
+    iou_aware_factor = attrs['iou_aware_factor']
+    bias_x_y = ((-0.5) * (scale_x_y - 1.0))
+    input_h = (downsample * h)
+    input_w = (downsample * w)
+    if iou_aware:
+        ioup = x[:, :an_num, :, :]
+        ioup = np.expand_dims(ioup, axis=(-1))
+        x = x[:, an_num:, :, :]
+    x = x.reshape((n, an_num, (5 + class_num), h, w)).transpose((0, 1, 3, 4, 2))
+    pred_box = x[:, :, :, :, :4].copy()
+    grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
+    grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
+    pred_box[:, :, :, :, 0] = ((
+        (grid_x + (sigmoid(pred_box[:, :, :, :, 0]) * scale_x_y)) + bias_x_y) /
+                               w)
+    pred_box[:, :, :, :, 1] = ((
+        (grid_y + (sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y)) + bias_x_y) /
+                               h)
+    anchors = [(anchors[i], anchors[(i + 1)])
+               for i in range(0, len(anchors), 2)]
+    anchors_s = np.array([((an_w / input_w), (an_h / input_h))
+                          for (an_w, an_h) in anchors])
+    anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
+    anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
+    pred_box[:, :, :, :, 2] = (np.exp(pred_box[:, :, :, :, 2]) * anchor_w)
+    pred_box[:, :, :, :, 3] = (np.exp(pred_box[:, :, :, :, 3]) * anchor_h)
+    if iou_aware:
+        pred_conf = ((sigmoid(x[:, :, :, :, 4:5])**(1 - iou_aware_factor)) *
+                     (sigmoid(ioup)**iou_aware_factor))
+    else:
+        pred_conf = sigmoid(x[:, :, :, :, 4:5])
+    pred_conf[(pred_conf < conf_thresh)] = 0.0
+    pred_score = (sigmoid(x[:, :, :, :, 5:]) * pred_conf)
+    pred_box = (pred_box * (pred_conf > 0.0).astype('float32'))
+    pred_box = pred_box.reshape((n, (-1), 4))
+    (pred_box[:, :, :2],
+     pred_box[:, :, 2:4]) = ((pred_box[:, :, :2] - (pred_box[:, :, 2:4] / 2.0)),
+                             (pred_box[:, :, :2] + (pred_box[:, :, 2:4] / 2.0)))
+    pred_box[:, :, 0] = (pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis])
+    pred_box[:, :, 1] = (pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis])
+    pred_box[:, :, 2] = (pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis])
+    pred_box[:, :, 3] = (pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis])
+    if clip_bbox:
+        for i in range(len(pred_box)):
+            pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf)
+            pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf)
+            pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], (-np.inf),
+                                        (img_size[(i, 1)] - 1))
+            pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], (-np.inf),
+                                        (img_size[(i, 0)] - 1))
+    return (pred_box,
+            pred_score.reshape((n, (-1), class_num)))
+
+
+class TestYoloBoxOp(OpTest):
+
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'yolo_box'
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.__class__.no_need_check_grad = True
+        self.python_api = paddle.vision.ops.yolo_box
+        x = np.random.random(self.x_shape).astype('float32')
+        img_size = np.random.randint(10, 20,
+                                     self.imgsize_shape).astype('int32')
+        self.attrs = {
+            'anchors': self.anchors,
+            'class_num': self.class_num,
+            'conf_thresh': self.conf_thresh,
+            'downsample_ratio': self.downsample,
+            'clip_bbox': self.clip_bbox,
+            'scale_x_y': self.scale_x_y,
+            'iou_aware': self.iou_aware,
+            'iou_aware_factor': self.iou_aware_factor
+        }
+        self.inputs = {'X': x, 'ImgSize': img_size}
+        (boxes, scores) = YoloBox(x, img_size, self.attrs)
+        self.outputs = {'Boxes': boxes, 'Scores': scores}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False, atol=1e-5)
+
+    def initTestCase(self):
+        self.anchors = [10, 13, 16, 30, 33, 23]
+        an_num = int((len(self.anchors) // 2))
+        self.batch_size = 32
+        self.class_num = 2
+        self.conf_thresh = 0.5
+        self.downsample = 32
+        self.clip_bbox = True
+        self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
+                        13)
+        self.imgsize_shape = (self.batch_size, 2)
+        self.scale_x_y = 1.0
+        self.iou_aware = False
+        self.iou_aware_factor = 0.5
+
+
+class TestYoloBoxOpNoClipBbox(TestYoloBoxOp):
+
+    def initTestCase(self):
+        self.anchors = [10, 13, 16, 30, 33, 23]
+        an_num = int((len(self.anchors) // 2))
+        self.batch_size = 32
+        self.class_num = 2
+        self.conf_thresh = 0.5
+        self.downsample = 32
+        self.clip_bbox = False
+        self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
+                        13)
+        self.imgsize_shape = (self.batch_size, 2)
+        self.scale_x_y = 1.0
+        self.iou_aware = False
+        self.iou_aware_factor = 0.5
+
+
+class TestYoloBoxOpScaleXY(TestYoloBoxOp):
+
+    def initTestCase(self):
+        self.anchors = [10, 13, 16, 30, 33, 23]
+        an_num = int((len(self.anchors) // 2))
+        self.batch_size = 32
+        self.class_num = 2
+        self.conf_thresh = 0.5
+        self.downsample = 32
+        self.clip_bbox = True
+        self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
+                        13)
+        self.imgsize_shape = (self.batch_size, 2)
+        self.scale_x_y = 1.2
+        self.iou_aware = False
+        self.iou_aware_factor = 0.5
+
+
+class TestYoloBoxOpIoUAware(TestYoloBoxOp):
+
+    def initTestCase(self):
+        self.anchors = [10, 13, 16, 30, 33, 23]
+        an_num = int((len(self.anchors) // 2))
+        self.batch_size = 32
+        self.class_num = 2
+        self.conf_thresh = 0.5
+        self.downsample = 32
+        self.clip_bbox = True
+        self.x_shape = (self.batch_size, (an_num * (6 + self.class_num)), 13,
+                        13)
+        self.imgsize_shape = (self.batch_size, 2)
+        self.scale_x_y = 1.0
+        self.iou_aware = True
+        self.iou_aware_factor = 0.5
+
+
+class TestYoloBoxDygraph(unittest.TestCase):
+
+    def test_dygraph(self):
+        paddle.disable_static()
+        img_size = np.ones((2, 2)).astype('int32')
+        img_size = paddle.to_tensor(img_size)
+        x1 = np.random.random([2, 14, 8, 8]).astype('float32')
+        x1 = paddle.to_tensor(x1)
+        (boxes, scores) = paddle.vision.ops.yolo_box(x1,
+                                                     img_size=img_size,
+                                                     anchors=[10, 13, 16, 30],
+                                                     class_num=2,
+                                                     conf_thresh=0.01,
+                                                     downsample_ratio=8,
+                                                     clip_bbox=True,
+                                                     scale_x_y=1.0)
+        assert ((boxes is not None) and (scores is not None))
+        x2 = np.random.random([2, 16, 8, 8]).astype('float32')
+        x2 = paddle.to_tensor(x2)
+        (boxes, scores) = paddle.vision.ops.yolo_box(x2,
+                                                     img_size=img_size,
+                                                     anchors=[10, 13, 16, 30],
+                                                     class_num=2,
+                                                     conf_thresh=0.01,
+                                                     downsample_ratio=8,
+                                                     clip_bbox=True,
+                                                     scale_x_y=1.0,
+                                                     iou_aware=True,
+                                                     iou_aware_factor=0.5)
+        paddle.enable_static()
+
+
+class TestYoloBoxStatic(unittest.TestCase):
+
+    def test_static(self):
+        x1 = paddle.static.data('x1', [2, 14, 8, 8], 'float32')
+        img_size = paddle.static.data('img_size', [2, 2], 'int32')
+        (boxes, scores) = paddle.vision.ops.yolo_box(x1,
+                                                     img_size=img_size,
+                                                     anchors=[10, 13, 16, 30],
+                                                     class_num=2,
+                                                     conf_thresh=0.01,
+                                                     downsample_ratio=8,
+                                                     clip_bbox=True,
+                                                     scale_x_y=1.0)
+        assert ((boxes is not None) and (scores is not None))
+        x2 = paddle.static.data('x2', [2, 16, 8, 8], 'float32')
+        (boxes, scores) = paddle.vision.ops.yolo_box(x2,
+                                                     img_size=img_size,
+                                                     anchors=[10, 13, 16, 30],
+                                                     class_num=2,
+                                                     conf_thresh=0.01,
+                                                     downsample_ratio=8,
+                                                     clip_bbox=True,
+                                                     scale_x_y=1.0,
+                                                     iou_aware=True,
+                                                     iou_aware_factor=0.5)
+        assert ((boxes is not None) and (scores is not None))
+
+
+class TestYoloBoxOpHW(TestYoloBoxOp):
+
+    def initTestCase(self):
+        self.anchors = [10, 13, 16, 30, 33, 23]
+        an_num = int((len(self.anchors) // 2))
+        self.batch_size = 32
+        self.class_num = 2
+        self.conf_thresh = 0.5
+        self.downsample = 32
+        self.clip_bbox = False
+        self.x_shape = (self.batch_size, (an_num * (5 + self.class_num)), 13,
+                        9)
+        self.imgsize_shape = (self.batch_size, 2)
+        self.scale_x_y = 1.0
+        self.iou_aware = False
+        self.iou_aware_factor = 0.5
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
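
Reviewer note (not part of the patch): a minimal NumPy sketch of the layout conversion the MLU kernel performs after calling mluOpYoloBox, as described in the comments of yolo_box_op_mlu.cc. The dimension values below are arbitrary and chosen only for illustration.

import numpy as np

# Example sizes, made up for this sketch.
n, anchor_num, h, w, class_num = 2, 3, 13, 13, 2

# mluOpYoloBox emits boxes as [N, anchor_num, 4, H*W] and scores as
# [N, anchor_num, class_num, H*W].
boxes_mluops = np.random.rand(n, anchor_num, 4, h * w).astype('float32')
scores_mluops = np.random.rand(n, anchor_num, class_num, h * w).astype('float32')

# The kernel applies perm = (0, 1, 3, 2) via MLUCnnl::Transpose, swapping the
# last two axes; flattening the middle axes then yields the [N, M, 4] and
# [N, M, class_num] shapes Paddle expects, with M = anchor_num * H * W.
boxes_paddle = boxes_mluops.transpose(0, 1, 3, 2).reshape(n, -1, 4)
scores_paddle = scores_mluops.transpose(0, 1, 3, 2).reshape(n, -1, class_num)

assert boxes_paddle.shape == (n, anchor_num * h * w, 4)
assert scores_paddle.shape == (n, anchor_num * h * w, class_num)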