diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 84b4677777a79b30ba8936025a60e8d6d9186a2c..f50a38842a21c795c979f859e88a9b16c3e54bd8 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -325,6 +325,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.box_clip ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index cace42bc1bae93287c330e54d12126efbf9a14bb..f6fbe97565c43c306ea885c765c0a665492fa317 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) +detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) if(WITH_GPU) diff --git a/paddle/fluid/operators/detection/bbox_util.h b/paddle/fluid/operators/detection/bbox_util.h index b99edb5bf05f94e762b377a8882e4c3fcdb5afad..a7bc3e027229884e78721d29428a8ab3f08a6ebc 100644 --- a/paddle/fluid/operators/detection/bbox_util.h +++ b/paddle/fluid/operators/detection/bbox_util.h @@ -99,5 +99,29 @@ void BboxOverlaps(const framework::Tensor& r_boxes, } } +template +void ClipTiledBoxes(const platform::DeviceContext& ctx, + const framework::Tensor& im_info, + const framework::Tensor& input_boxes, + framework::Tensor* out) { + T* out_data = out->mutable_data(ctx.GetPlace()); + const T* im_info_data = im_info.data(); + const T* input_boxes_data = input_boxes.data(); + T zero(0); + T im_w = round(im_info_data[1] / im_info_data[2]); + T im_h = round(im_info_data[0] / im_info_data[2]); + for (int64_t i = 0; i < input_boxes.numel(); ++i) { + if (i % 4 == 0) { + out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero); + } else if (i % 4 == 1) { + out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero); + } else if (i % 4 == 2) { + out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero); + } else { + out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero); + } + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detection/box_clip_op.cc b/paddle/fluid/operators/detection/box_clip_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3aa766559a530bc31fbb277f2bcd474da776e63b --- /dev/null +++ b/paddle/fluid/operators/detection/box_clip_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/box_clip_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class BoxClipOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of BoxClipOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ImInfo"), + "Input(ImInfo) of BoxClipOp should not be null."); + + auto input_box_dims = ctx->GetInputDim("Input"); + auto im_info_dims = ctx->GetInputDim("ImInfo"); + + if (ctx->IsRuntime()) { + auto input_box_size = input_box_dims.size(); + PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4, + "The last dimension of Input must be 4"); + PADDLE_ENFORCE_EQ(im_info_dims.size(), 2, + "The rank of Input(Input) in BoxClipOp must be 2"); + PADDLE_ENFORCE_EQ(im_info_dims[1], 3, + "The last dimension of ImInfo must be 3"); + } + ctx->ShareDim("Input", /*->*/ "Output"); + ctx->ShareLoD("Input", /*->*/ "Output"); + } +}; + +class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", + "(LoDTensor) " + "Input is a LoDTensor with shape [..., 4] holds 4 points" + "in last dimension in format [xmin, ymin, xmax, ymax]"); + AddInput("ImInfo", + "(Tensor) Information for image reshape is in shape (N, 3), " + "in format (height, width, im_scale)"); + AddOutput("Output", + "(LoDTensor) " + "Output is a LoDTensor with the same shape as Input" + "and it is the result after clip"); + AddComment(R"DOC( +This operator clips input boxes to original input images. + +For each input box, The formula is given as follows: + + $$xmin = \max(\min(xmin, im_w - 1), 0)$$ + $$ymin = \max(\min(ymin, im_h - 1), 0)$$ + $$xmax = \max(\min(xmax, im_w - 1), 0)$$ + $$ymax = \max(\min(ymax, im_h - 1), 0)$$ + +where im_w and im_h are computed from ImInfo, the formula is given as follows: + + $$im_w = \round(width / im_scale)$$ + $$im_h = \round(height / im_scale)$$ +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(box_clip, ops::BoxClipOp, ops::BoxClipOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + box_clip, ops::BoxClipKernel, + ops::BoxClipKernel); diff --git a/paddle/fluid/operators/detection/box_clip_op.cu b/paddle/fluid/operators/detection/box_clip_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..b727da5f7b736b6f22407d1dfbca708ed0cf04d9 --- /dev/null +++ b/paddle/fluid/operators/detection/box_clip_op.cu @@ -0,0 +1,74 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detection/box_clip_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTenso = framework::LoDTensor; + +static constexpr int ImInfoSize = 3; + +template +static __global__ void GPUBoxClip(const T *input, const size_t *lod, + const size_t width, const T *im_info, + T *output) { + T im_w = round(im_info[blockIdx.x * ImInfoSize + 1] / + im_info[blockIdx.x * ImInfoSize + 2]); + T im_h = round(im_info[blockIdx.x * ImInfoSize] / + im_info[blockIdx.x * ImInfoSize + 2]); + for (int i = threadIdx.x; i < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * width; + i += BlockSize) { + int idx = lod[blockIdx.x] * width + i; + T im_size = (idx % 2 == 0) ? im_w : im_h; + output[idx] = max(min(input[idx], im_size - 1), T(0.)); + } +} + +template +class GPUBoxClipKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), + "This kernel only runs on GPU device."); + auto *input = context.Input("Input"); + auto *im_info = context.Input("ImInfo"); + auto *output = context.Output("Output"); + const int64_t num = input->dims()[0]; + const int64_t bbox_width = input->numel() / num; + auto lod = input->lod(); + framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); + auto &dev_ctx = context.template device_context(); + auto stream = dev_ctx.stream(); + const size_t batch_size = lod.back().size() - 1; + T *output_data = output->mutable_data(dev_ctx.GetPlace()); + GPUBoxClip<<>>( + input->data(), abs_offset_lod[0].CUDAMutableData(dev_ctx.GetPlace()), + bbox_width, im_info->data(), output_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + box_clip, ops::GPUBoxClipKernel, + ops::GPUBoxClipKernel); diff --git a/paddle/fluid/operators/detection/box_clip_op.h b/paddle/fluid/operators/detection/box_clip_op.h new file mode 100644 index 0000000000000000000000000000000000000000..74e1f88f8d8b28e490d170934760bd9bffc807bc --- /dev/null +++ b/paddle/fluid/operators/detection/box_clip_op.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detection/bbox_util.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class BoxClipKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input_box = context.Input("Input"); + auto* im_info = context.Input("ImInfo"); + auto* output_box = context.Output("Output"); + auto& dev_ctx = + context.template device_context(); + output_box->mutable_data(context.GetPlace()); + if (input_box->lod().size()) { + PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL, + "Only support 1 level of LoD."); + } + auto box_lod = input_box->lod().back(); + int64_t n = static_cast(box_lod.size() - 1); + for (int i = 0; i < n; ++i) { + Tensor im_info_slice = im_info->Slice(i, i + 1); + Tensor box_slice = input_box->Slice(box_lod[i], box_lod[i + 1]); + Tensor output_slice = output_box->Slice(box_lod[i], box_lod[i + 1]); + ClipTiledBoxes(dev_ctx, im_info_slice, box_slice, &output_slice); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 0602d7a19481fbf0210a7cb4bd15a1033b0e8900..c983e2a44b25c5943df5e822e2e363b2557a6ac3 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -49,6 +49,7 @@ __all__ = [ 'box_coder', 'polygon_box_transform', 'yolov3_loss', + 'box_clip', 'multiclass_nms', ] @@ -2055,6 +2056,54 @@ def generate_proposals(scores, return rpn_rois, rpn_roi_probs +def box_clip(input, im_info, name=None): + """ + Clip the box into the size given by im_info + For each input box, The formula is given as follows: + + .. code-block:: text + + xmin = max(min(xmin, im_w - 1), 0) + ymin = max(min(ymin, im_h - 1), 0) + xmax = max(min(xmax, im_w - 1), 0) + ymax = max(min(ymax, im_h - 1), 0) + + where im_w and im_h are computed from im_info: + + .. code-block:: text + + im_h = round(height / scale) + im_w = round(weight / scale) + + Args: + input(variable): The input box, the last dimension is 4. + im_info(variable): The information of image with shape [N, 3] with + layout (height, width, scale). height and width + is the input size and scale is the ratio of input + size and original size. + name (str): The name of this layer. It is optional. + + Returns: + Variable: The cliped tensor variable. + + Examples: + .. code-block:: python + + boxes = fluid.layers.data( + name='data', shape=[8, 4], dtype='float32', lod_level=1) + im_info = fluid.layers.data(name='im_info', shape=[3]) + out = fluid.layers.box_clip( + input=boxes, im_info=im_info, inplace=True) + """ + + helper = LayerHelper("box_clip", **locals()) + output = helper.create_variable_for_type_inference(dtype=input.dtype) + inputs = {"Input": input, "ImInfo": im_info} + helper.append_op(type="box_clip", inputs=inputs, outputs={"Output": output}) + + return output + + def multiclass_nms(bboxes, scores, score_threshold, @@ -2132,9 +2181,11 @@ def multiclass_nms(bboxes, (After version 1.3, when no boxes detected, the lod is changed from {0} to {1}) + Examples: .. code-block:: python + boxes = fluid.layers.data(name='bboxes', shape=[81, 4], dtype='float32', lod_level=1) scores = fluid.layers.data(name='scores', shape=[81], diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 77dfa1cb519db3faa9ef8b7b27f7a39b5d31f2a8..0d39a139eed87f900b1f59fd0569b6acaec0962b 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -482,6 +482,17 @@ class TestYoloDetection(unittest.TestCase): self.assertIsNotNone(loss) +class TestBoxClip(unittest.TestCase): + def test_box_clip(self): + program = Program() + with program_guard(program): + input_box = layers.data( + name='input_box', shape=[7, 4], dtype='float32', lod_level=1) + im_info = layers.data(name='im_info', shape=[3], dtype='float32') + out = layers.box_clip(input_box, im_info) + self.assertIsNotNone(out) + + class TestMulticlassNMS(unittest.TestCase): def test_multiclass_nms(self): program = Program() diff --git a/python/paddle/fluid/tests/unittests/test_box_clip_op.py b/python/paddle/fluid/tests/unittests/test_box_clip_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b0598f31dd27e12e5ce329129129b5e0f1caf0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_box_clip_op.py @@ -0,0 +1,70 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +import math +from op_test import OpTest +import copy + + +def box_clip(input_box, im_info, output_box): + im_w = round(im_info[1] / im_info[2]) + im_h = round(im_info[0] / im_info[2]) + output_box[:, :, 0] = np.maximum( + np.minimum(input_box[:, :, 0], im_w - 1), 0) + output_box[:, :, 1] = np.maximum( + np.minimum(input_box[:, :, 1], im_h - 1), 0) + output_box[:, :, 2] = np.maximum( + np.minimum(input_box[:, :, 2], im_w - 1), 0) + output_box[:, :, 3] = np.maximum( + np.minimum(input_box[:, :, 3], im_h - 1), 0) + + +def batch_box_clip(input_boxes, im_info, lod): + n = input_boxes.shape[0] + m = input_boxes.shape[1] + output_boxes = np.zeros((n, m, 4), dtype=np.float32) + cur_offset = 0 + for i in range(len(lod)): + box_clip(input_boxes[cur_offset:(cur_offset + lod[i]), :, :], + im_info[i, :], + output_boxes[cur_offset:(cur_offset + lod[i]), :, :]) + cur_offset += lod[i] + return output_boxes + + +class TestBoxClipOp(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_clip" + lod = [[1, 2, 3]] + input_boxes = np.random.random((6, 10, 4)) * 5 + im_info = np.array([[5, 8, 1.], [6, 6, 1.], [7, 5, 1.]]) + output_boxes = batch_box_clip(input_boxes, im_info, lod[0]) + + self.inputs = { + 'Input': (input_boxes.astype('float32'), lod), + 'ImInfo': im_info.astype('float32'), + } + self.outputs = {'Output': output_boxes} + + +if __name__ == '__main__': + unittest.main()