diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index fdd23681af5401ca9d1f64b818514924ccf96fed..d28eb0a8088b41c7dfdd24c39580c8dd2f655bda 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -331,6 +331,7 @@ paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=Non paddle.fluid.layers.box_coder (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)), ('document', '032d0f4b7d8f6235ee5d91e473344f0e')) paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0e5ac2507723a0b5adec473f9556799b')) paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'gtscore', 'use_label_smooth', 'name'], varargs=None, keywords=None, defaults=(None, True, None)), ('document', '57fa96922e42db8f064c3fb77f2255e8')) +paddle.fluid.layers.yolo_box (ArgSpec(args=['x', 'img_size', 'anchors', 'class_num', 'conf_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5566169a5ab993d177792c023c7fb340')) paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e')) paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0')) paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d')) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index c87837e69424335ac926bf05664e5f79940390b5..94a2016aa53212c3ae5af6d86cccb117855cc3b4 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) +detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) if(WITH_GPU) diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0d7e25d944cf2321799da4c73de9f74d9fd287d --- /dev/null +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/detection/yolo_box_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class YoloBoxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of YoloBoxOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ImgSize"), + "Input(ImgSize) of YoloBoxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Boxes"), + "Output(Boxes) of YoloBoxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Scores"), + "Output(Scores) of YoloBoxOp should not be null."); + + auto dim_x = ctx->GetInputDim("X"); + auto dim_imgsize = ctx->GetInputDim("ImgSize"); + auto anchors = ctx->Attrs().Get>("anchors"); + int anchor_num = anchors.size() / 2; + auto class_num = ctx->Attrs().Get("class_num"); + + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); + PADDLE_ENFORCE_EQ( + dim_x[1], anchor_num * (5 + class_num), + "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " + "+ class_num))."); + PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2, + "Input(ImgSize) should be a 2-D tensor."); + PADDLE_ENFORCE_EQ( + dim_imgsize[0], dim_x[0], + "Input(ImgSize) dim[0] and Input(X) dim[0] should be same."); + PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2."); + PADDLE_ENFORCE_GT(anchors.size(), 0, + "Attr(anchors) length should be greater than 0."); + PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, + "Attr(anchors) length should be even integer."); + PADDLE_ENFORCE_GT(class_num, 0, + "Attr(class_num) should be an integer greater than 0."); + + int box_num = dim_x[2] * dim_x[3] * anchor_num; + std::vector dim_boxes({dim_x[0], box_num, 4}); + ctx->SetOutputDim("Boxes", framework::make_ddim(dim_boxes)); + + std::vector dim_scores({dim_x[0], box_num, class_num}); + ctx->SetOutputDim("Scores", framework::make_ddim(dim_scores)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of YoloBox operator is a 4-D tensor with " + "shape of [N, C, H, W]. The second dimension(C) stores " + "box locations, confidence score and classification one-hot " + "keys of each anchor box. Generally, X should be the output " + "of YOLOv3 network."); + AddInput("ImgSize", + "The image size tensor of YoloBox operator, " + "This is a 2-D tensor with shape of [N, 2]. This tensor holds " + "height and width of each input image used for resizing output " + "box in input image scale."); + AddOutput("Boxes", + "The output tensor of detection boxes of YoloBox operator, " + "This is a 3-D tensor with shape of [N, M, 4], N is the " + "batch num, M is output box number, and the 3rd dimension " + "stores [xmin, ymin, xmax, ymax] coordinates of boxes."); + AddOutput("Scores", + "The output tensor of detection boxes scores of YoloBox " + "operator, This is a 3-D tensor with shape of " + "[N, M, :attr:`class_num`], N is the batch num, M is " + "output box number."); + + AddAttr("class_num", "The number of classes to predict."); + AddAttr>("anchors", + "The anchor width and height, " + "it will be parsed pair by pair.") + .SetDefault(std::vector{}); + AddAttr("downsample_ratio", + "The downsample ratio from network input to YoloBox operator " + "input, so 32, 16, 8 should be set for the first, second, " + "and thrid YoloBox operators.") + .SetDefault(32); + AddAttr("conf_thresh", + "The confidence scores threshold of detection boxes. " + "Boxes with confidence scores under threshold should " + "be ignored.") + .SetDefault(0.01); + AddComment(R"DOC( + This operator generates YOLO detection boxes from output of YOLOv3 network. + + The output of previous network is in shape [N, C, H, W], while H and W + should be the same, H and W specify the grid size, each grid point predict + given number boxes, this given number, which following will be represented as S, + is specified by the number of anchors. In the second dimension(the channel + dimension), C should be equal to S * (5 + class_num), class_num is the object + category number of source dataset(such as 80 in coco dataset), so the + second(channel) dimension, apart from 4 box location coordinates x, y, w, h, + also includes confidence score of the box and class one-hot key of each anchor + box. + + Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box + predictions should be as follows: + + $$ + b_x = \\sigma(t_x) + c_x + $$ + $$ + b_y = \\sigma(t_y) + c_y + $$ + $$ + b_w = p_w e^{t_w} + $$ + $$ + b_h = p_h e^{t_h} + $$ + + in the equation above, :math:`c_x, c_y` is the left top corner of current grid + and :math:`p_w, p_h` is specified by anchors. + + The logistic regression value of the 5th channel of each anchor prediction boxes + represents the confidence score of each prediction box, and the logistic + regression value of the last :attr:`class_num` channels of each anchor prediction + boxes represents the classifcation scores. Boxes with confidence scores less than + :attr:`conf_thresh` should be ignored, and box final scores is the product of + confidence scores and classification scores. + + $$ + score_{pred} = score_{conf} * score_{class} + $$ + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel, + ops::YoloBoxKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5a882958e66a79507e053a96b15be8cbbcc83164 --- /dev/null +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -0,0 +1,120 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/yolo_box_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__global__ void KeYoloBoxFw(const T* input, const int* imgsize, T* boxes, + T* scores, const float conf_thresh, + const int* anchors, const int n, const int h, + const int w, const int an_num, const int class_num, + const int box_num, int input_size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + T box[4]; + for (; tid < n * box_num; tid += stride) { + int grid_num = h * w; + int i = tid / box_num; + int j = (tid % box_num) / grid_num; + int k = (tid % grid_num) / w; + int l = tid % w; + + int an_stride = (5 + class_num) * grid_num; + int img_height = imgsize[2 * i]; + int img_width = imgsize[2 * i + 1]; + + int obj_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 4); + T conf = sigmoid(input[obj_idx]); + if (conf < conf_thresh) { + continue; + } + + int box_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 0); + GetYoloBox(box, input, anchors, l, k, j, h, input_size, box_idx, + grid_num, img_height, img_width); + box_idx = (i * box_num + j * grid_num + k * w + l) * 4; + CalcDetectionBox(boxes, box, box_idx, img_height, img_width); + + int label_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, grid_num, 5); + int score_idx = (i * box_num + j * grid_num + k * w + l) * class_num; + CalcLabelScore(scores, input, label_idx, score_idx, class_num, conf, + grid_num); + } +} + +template +class YoloBoxOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* img_size = ctx.Input("ImgSize"); + auto* boxes = ctx.Output("Boxes"); + auto* scores = ctx.Output("Scores"); + + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float conf_thresh = ctx.Attr("conf_thresh"); + int downsample_ratio = ctx.Attr("downsample_ratio"); + + const int n = input->dims()[0]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int box_num = boxes->dims()[1]; + const int an_num = anchors.size() / 2; + int input_size = downsample_ratio * h; + + auto& dev_ctx = ctx.cuda_device_context(); + auto& allocator = + platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); + int bytes = sizeof(int) * anchors.size(); + auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size()); + int* anchors_data = reinterpret_cast(anchors_ptr->ptr()); + const auto gplace = boost::get(ctx.GetPlace()); + const auto cplace = platform::CPUPlace(); + memory::Copy(gplace, anchors_data, cplace, anchors.data(), bytes, + dev_ctx.stream()); + + const T* input_data = input->data(); + const int* imgsize_data = img_size->data(); + T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); + T* scores_data = + scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(dev_ctx, boxes, static_cast(0)); + set_zero(dev_ctx, scores, static_cast(0)); + + int grid_dim = (n * box_num + 512 - 1) / 512; + grid_dim = grid_dim > 8 ? 8 : grid_dim; + + KeYoloBoxFw<<>>( + input_data, imgsize_data, boxes_data, scores_data, conf_thresh, + anchors_data, n, h, w, an_num, class_num, box_num, input_size); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(yolo_box, ops::YoloBoxOpCUDAKernel, + ops::YoloBoxOpCUDAKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8b7c7df0f3cf754f59c994dbe5b1cc2ac5fb773b --- /dev/null +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -0,0 +1,149 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +HOSTDEVICE inline T sigmoid(T x) { + return 1.0 / (1.0 + std::exp(-x)); +} + +template +HOSTDEVICE inline void GetYoloBox(T* box, const T* x, const int* anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride, + int img_height, int img_width) { + box[0] = (i + sigmoid(x[index])) * img_width / grid_size; + box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size; + box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width / + input_size; + box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * + img_height / input_size; +} + +HOSTDEVICE inline int GetEntryIndex(int batch, int an_idx, int hw_idx, + int an_num, int an_stride, int stride, + int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +template +HOSTDEVICE inline void CalcDetectionBox(T* boxes, T* box, const int box_idx, + const int img_height, + const int img_width) { + boxes[box_idx] = box[0] - box[2] / 2; + boxes[box_idx + 1] = box[1] - box[3] / 2; + boxes[box_idx + 2] = box[0] + box[2] / 2; + boxes[box_idx + 3] = box[1] + box[3] / 2; + + boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast(0); + boxes[box_idx + 1] = + boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast(0); + boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1 + ? boxes[box_idx + 2] + : static_cast(img_width - 1); + boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1 + ? boxes[box_idx + 3] + : static_cast(img_height - 1); +} + +template +HOSTDEVICE inline void CalcLabelScore(T* scores, const T* input, + const int label_idx, const int score_idx, + const int class_num, const T conf, + const int stride) { + for (int i = 0; i < class_num; i++) { + scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); + } +} + +template +class YoloBoxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* imgsize = ctx.Input("ImgSize"); + auto* boxes = ctx.Output("Boxes"); + auto* scores = ctx.Output("Scores"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float conf_thresh = ctx.Attr("conf_thresh"); + int downsample_ratio = ctx.Attr("downsample_ratio"); + + const int n = input->dims()[0]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int box_num = boxes->dims()[1]; + const int an_num = anchors.size() / 2; + int input_size = downsample_ratio * h; + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + Tensor anchors_; + auto anchors_data = + anchors_.mutable_data({an_num * 2}, ctx.GetPlace()); + std::copy(anchors.begin(), anchors.end(), anchors_data); + + const T* input_data = input->data(); + const int* imgsize_data = imgsize->data(); + T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); + memset(boxes_data, 0, boxes->numel() * sizeof(T)); + T* scores_data = + scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); + memset(scores_data, 0, scores->numel() * sizeof(T)); + + T box[4]; + for (int i = 0; i < n; i++) { + int img_height = imgsize_data[2 * i]; + int img_width = imgsize_data[2 * i + 1]; + + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4); + T conf = sigmoid(input_data[obj_idx]); + if (conf < conf_thresh) { + continue; + } + + int box_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); + GetYoloBox(box, input_data, anchors_data, l, k, j, h, input_size, + box_idx, stride, img_height, img_width); + box_idx = (i * box_num + j * stride + k * w + l) * 4; + CalcDetectionBox(boxes_data, box, box_idx, img_height, + img_width); + + int label_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); + int score_idx = (i * box_num + j * stride + k * w + l) * class_num; + CalcLabelScore(scores_data, input_data, label_idx, score_idx, + class_num, conf, stride); + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 9183bfd43b1d6f903d56fc275355b787a53ef511..0a1ddbc1dba51692e75fa76856dd689b77ab9f35 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -49,6 +49,7 @@ __all__ = [ 'box_coder', 'polygon_box_transform', 'yolov3_loss', + 'yolo_box', 'box_clip', 'multiclass_nms', 'distribute_fpn_proposals', @@ -628,6 +629,83 @@ def yolov3_loss(x, return loss +@templatedoc(op_type="yolo_box") +def yolo_box(x, + img_size, + anchors, + class_num, + conf_thresh, + downsample_ratio, + name=None): + """ + ${comment} + + Args: + x (Variable): ${x_comment} + img_size (Variable): ${img_size_comment} + anchors (list|tuple): ${anchors_comment} + class_num (int): ${class_num_comment} + conf_thresh (float): ${conf_thresh_comment} + downsample_ratio (int): ${downsample_ratio_comment} + name (string): the name of yolo box layer. Default None. + + Returns: + Variable: A 3-D tensor with shape [N, M, 4], the coordinates of boxes, + and a 3-D tensor with shape [N, M, :attr:`class_num`], the classification + scores of boxes. + + Raises: + TypeError: Input x of yolov_box must be Variable + TypeError: Attr anchors of yolo box must be list or tuple + TypeError: Attr class_num of yolo box must be an integer + TypeError: Attr conf_thresh of yolo box must be a float number + + Examples: + + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') + anchors = [10, 13, 16, 30, 33, 23] + loss = fluid.layers.yolo_box(x=x, class_num=80, anchors=anchors, + conf_thresh=0.01, downsample_ratio=32) + """ + helper = LayerHelper('yolo_box', **locals()) + + if not isinstance(x, Variable): + raise TypeError("Input x of yolo_box must be Variable") + if not isinstance(img_size, Variable): + raise TypeError("Input img_size of yolo_box must be Variable") + if not isinstance(anchors, list) and not isinstance(anchors, tuple): + raise TypeError("Attr anchors of yolo_box must be list or tuple") + if not isinstance(class_num, int): + raise TypeError("Attr class_num of yolo_box must be an integer") + if not isinstance(conf_thresh, float): + raise TypeError("Attr ignore_thresh of yolo_box must be a float number") + + boxes = helper.create_variable_for_type_inference(dtype=x.dtype) + scores = helper.create_variable_for_type_inference(dtype=x.dtype) + + attrs = { + "anchors": anchors, + "class_num": class_num, + "conf_thresh": conf_thresh, + "downsample_ratio": downsample_ratio, + } + + helper.append_op( + type='yolo_box', + inputs={ + "X": x, + "ImgSize": img_size, + }, + outputs={ + 'Boxes': boxes, + 'Scores': scores, + }, + attrs=attrs) + return boxes, scores + + @templatedoc() def detection_map(detect_res, label, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index b756c532cad00b57c39349c5a6959b38f70f09fa..7d1b869cf5991dc5ef960ff4d72289979aae158a 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -489,6 +489,16 @@ class TestYoloDetection(unittest.TestCase): self.assertIsNotNone(loss) + def test_yolo_box(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + img_size = layers.data(name='img_size', shape=[2], dtype='int32') + boxes, scores = layers.yolo_box(x, img_size, [10, 13, 30, 13], 10, + 0.01, 32) + self.assertIsNotNone(boxes) + self.assertIsNotNone(scores) + class TestBoxClip(unittest.TestCase): def test_box_clip(self): diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py new file mode 100644 index 0000000000000000000000000000000000000000..416e6ea9f412d86db877fc36175e8b910b0613fe --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -0,0 +1,117 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import unittest +import numpy as np +from op_test import OpTest + +from paddle.fluid import core + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def YoloBox(x, img_size, attrs): + n, c, h, w = x.shape + anchors = attrs['anchors'] + an_num = int(len(anchors) // 2) + class_num = attrs['class_num'] + conf_thresh = attrs['conf_thresh'] + downsample = attrs['downsample'] + input_size = downsample * h + + x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + + pred_box = x[:, :, :, :, :4].copy() + grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) + grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) + pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w + pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + + anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in anchors]) + anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) + anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) + pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w + pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h + + pred_conf = sigmoid(x[:, :, :, :, 4:5]) + pred_conf[pred_conf < conf_thresh] = 0. + pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf + pred_box = pred_box * (pred_conf > 0.).astype('float32') + + pred_box = pred_box.reshape((n, -1, 4)) + pred_box[:, :, :2], pred_box[:, :, 2:4] = \ + pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \ + pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0 + pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis] + pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis] + pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis] + pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis] + + for i in range(len(pred_box)): + pred_box[i, :, 0] = np.clip(pred_box[i, :, 0], 0, np.inf) + pred_box[i, :, 1] = np.clip(pred_box[i, :, 1], 0, np.inf) + pred_box[i, :, 2] = np.clip(pred_box[i, :, 2], -np.inf, + img_size[i, 1] - 1) + pred_box[i, :, 3] = np.clip(pred_box[i, :, 3], -np.inf, + img_size[i, 0] - 1) + + return pred_box, pred_score.reshape((n, -1, class_num)) + + +class TestYoloBoxOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'yolo_box' + x = np.random.random(self.x_shape).astype('float32') + img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32') + + self.attrs = { + "anchors": self.anchors, + "class_num": self.class_num, + "conf_thresh": self.conf_thresh, + "downsample": self.downsample, + } + + self.inputs = { + 'X': x, + 'ImgSize': img_size, + } + boxes, scores = YoloBox(x, img_size, self.attrs) + self.outputs = { + "Boxes": boxes, + "Scores": scores, + } + + def test_check_output(self): + self.check_output() + + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int(len(self.anchors) // 2) + self.batch_size = 32 + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 13, 13) + self.imgsize_shape = (self.batch_size, 2) + + +if __name__ == "__main__": + unittest.main()