From 481d8bce2fa10c5c729b146c6925e46d434d22d6 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Wed, 16 Jan 2019 06:42:31 +0000 Subject: [PATCH] add box clip op --- paddle/fluid/API.spec | 2 + .../fluid/operators/detection/CMakeLists.txt | 1 + paddle/fluid/operators/detection/bbox_util.h | 24 ++++++ .../fluid/operators/detection/box_clip_op.cc | 74 +++++++++++++++++++ .../fluid/operators/detection/box_clip_op.h | 50 +++++++++++++ python/paddle/fluid/layers/detection.py | 66 ++++++++++++----- python/paddle/fluid/tests/test_detection.py | 14 +++- .../fluid/tests/unittests/test_box_clip_op.py | 70 ++++++++++++++++++ 8 files changed, 282 insertions(+), 19 deletions(-) create mode 100644 paddle/fluid/operators/detection/box_clip_op.cc create mode 100644 paddle/fluid/operators/detection/box_clip_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_box_clip_op.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 50ffef72baa..cfde0fdf0c8 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -318,6 +318,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None)) +paddle.fluid.layers.box_clip ArgSpec(args=['input_box', 'im_info', 'inplace', 'name'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) @@ -494,6 +495,7 @@ paddle.reader.buffered ArgSpec(args=['reader', 'size'], varargs=None, keywords=N paddle.reader.compose ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None) paddle.reader.chain ArgSpec(args=[], varargs='readers', keywords=None, defaults=None) paddle.reader.shuffle ArgSpec(args=['reader', 'buf_size'], varargs=None, keywords=None, defaults=None) +paddle.reader.ComposeNotAligned.__init__ paddle.reader.firstn ArgSpec(args=['reader', 'n'], varargs=None, keywords=None, defaults=None) paddle.reader.xmap_readers ArgSpec(args=['mapper', 'reader', 'process_num', 'buffer_size', 'order'], varargs=None, keywords=None, defaults=(False,)) paddle.reader.PipeReader.__init__ ArgSpec(args=['self', 'command', 'bufsize', 'file_type'], varargs=None, keywords=None, defaults=(8192, 'plain')) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 6c85f1577e0..b0f023935d2 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) +detection_library(box_clip_op SRCS box_clip_op.cc) if(WITH_GPU) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) diff --git a/paddle/fluid/operators/detection/bbox_util.h b/paddle/fluid/operators/detection/bbox_util.h index 6abeca1da44..ba16c9565f3 100644 --- a/paddle/fluid/operators/detection/bbox_util.h +++ b/paddle/fluid/operators/detection/bbox_util.h @@ -93,5 +93,29 @@ void BboxOverlaps(const framework::Tensor& r_boxes, } } +template +void ClipTiledBoxes(const platform::DeviceContext& ctx, + const framework::Tensor& im_info, + const framework::Tensor& input_boxes, + framework::Tensor* out) { + T* out_data = out->mutable_data(ctx.GetPlace()); + const T* im_info_data = im_info.data(); + const T* input_boxes_data = input_boxes.data(); + T zero(0); + T im_w = round(im_info_data[1] / im_info_data[2]); + T im_h = round(im_info_data[0] / im_info_data[2]); + for (int64_t i = 0; i < input_boxes.numel(); ++i) { + if (i % 4 == 0) { + out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero); + } else if (i % 4 == 1) { + out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero); + } else if (i % 4 == 2) { + out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero); + } else { + out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero); + } + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detection/box_clip_op.cc b/paddle/fluid/operators/detection/box_clip_op.cc new file mode 100644 index 00000000000..b185f127961 --- /dev/null +++ b/paddle/fluid/operators/detection/box_clip_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/box_clip_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class BoxClipOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("InputBox"), + "Input(InputBox) of BoxClipOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ImInfo"), + "Input(ImInfo) of BoxClipOp should not be null."); + + auto input_box_dims = ctx->GetInputDim("InputBox"); + auto im_info_dims = ctx->GetInputDim("ImInfo"); + + if (ctx->IsRuntime()) { + auto input_box_size = input_box_dims.size(); + PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4, + "The last dimension of InputBox must be 4"); + PADDLE_ENFORCE_EQ(im_info_dims.size(), 2, + "The rank of Input(InputBox) in BoxClipOp must be 2"); + PADDLE_ENFORCE_EQ(im_info_dims[1], 2, + "The last dimension of ImInfo must be 2"); + } + ctx->ShareDim("InputBox", /*->*/ "OutputBox"); + ctx->ShareLoD("InputBox", /*->*/ "OutputBox"); + } +}; + +class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("InputBox", + "(LoDTensor) " + "InputBox is a LoDTensor with shape [..., 4] holds 4 points" + "in last dimension in format [xmin, ymin, xmax, ymax]"); + AddInput("ImInfo", + "(Tensor) Information for image reshape is in shape (N, 2), " + "in format (height, width)"); + AddOutput("OutputBox", + "(LoDTensor) " + "OutputBox is a LoDTensor with the same shape as InputBox" + "and it is the result after clip"); + AddComment(R"DOC( + This operator clips input boxes to original input images. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(box_clip, ops::BoxClipOp, ops::BoxClipOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + box_clip, ops::BoxClipKernel, + ops::BoxClipKernel); diff --git a/paddle/fluid/operators/detection/box_clip_op.h b/paddle/fluid/operators/detection/box_clip_op.h new file mode 100644 index 00000000000..88d35d2a88c --- /dev/null +++ b/paddle/fluid/operators/detection/box_clip_op.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detection/bbox_util.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class BoxClipKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input_box = context.Input("InputBox"); + auto* im_info = context.Input("ImInfo"); + auto* output_box = context.Output("OutputBox"); + auto& dev_ctx = + context.template device_context(); + output_box->mutable_data(context.GetPlace()); + if (input_box->lod().size()) { + PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL, + "Only support 1 level of LoD."); + } + auto box_lod = input_box->lod().back(); + int64_t n = static_cast(box_lod.size() - 1); + for (int i = 0; i < n; ++i) { + Tensor im_info_slice = im_info->Slice(i, i + 1); + Tensor box_slice = input_box->Slice(box_lod[i], box_lod[i + 1]); + Tensor output_slice = output_box->Slice(box_lod[i], box_lod[i + 1]); + ClipTiledBoxes(dev_ctx, im_info_slice, box_slice, &output_slice); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 8aed97dc59b..daeb10c1d69 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -31,23 +31,11 @@ import numpy from functools import reduce __all__ = [ - 'prior_box', - 'density_prior_box', - 'multi_box_head', - 'bipartite_match', - 'target_assign', - 'detection_output', - 'ssd_loss', - 'detection_map', - 'rpn_target_assign', - 'anchor_generator', - 'roi_perspective_transform', - 'generate_proposal_labels', - 'generate_proposals', - 'iou_similarity', - 'box_coder', - 'polygon_box_transform', - 'yolov3_loss', + 'prior_box', 'density_prior_box', 'multi_box_head', 'bipartite_match', + 'target_assign', 'detection_output', 'ssd_loss', 'detection_map', + 'rpn_target_assign', 'anchor_generator', 'roi_perspective_transform', + 'generate_proposal_labels', 'generate_proposals', 'iou_similarity', + 'box_coder', 'polygon_box_transform', 'yolov3_loss', 'box_clip' ] @@ -1810,3 +1798,47 @@ def generate_proposals(scores, rpn_roi_probs.stop_gradient = True return rpn_rois, rpn_roi_probs + + +def box_clip(input_box, im_info, inplace=False, name=None): + """ + Clip the box into the size given by im_info + + Args: + input_box(variable): The input box, the last dimension is 4. + im_info(variable): The information of image with shape [N, 3]. + inplace(bool): Must use :attr:`False` if :attr:`input_box` is used in + multiple operators. If this flag is set :attr:`True`, + reuse input :attr:`input_box` to clip, which will + change the value of tensor variable :attr:`input_box` + and might cause errors when :attr:`input_box` is used + in multiple operators. If :attr:`False`, preserve the + value pf :attr:`input_box` and create a new output + tensor variable whose data is copied from input x but + cliped. + name (str): The name of this layer. It is optional. + + Returns: + Variable: The cliped tensor variable. + + Examples: + .. code-block:: python + + boxes = fluid.layers.data( + name='data', shape=[8, 4], dtype='float32', lod_level=1) + im_info = fluid.layers.data(name='im_info', shape=[3]) + out = fluid.layers.box_clip( + input_box=boxes, im_info=im_info, inplace=True) + """ + + inputs = {"InputBox": input_box, "ImInfo": im_info} + + helper = LayerHelper("box_clip", **locals()) + output = helper.create_variable_for_type_inference(dtype=input_box.dtype) + helper.append_op( + type="box_clip", + inputs=inputs, + attrs={"inplace:": inplace}, + outputs={"OutputBox": output}) + + return output diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index d99eaa0634f..bbc372da1a8 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -354,8 +354,7 @@ class TestGenerateProposals(unittest.TestCase): data_shape = [20, 64, 64] images = fluid.layers.data( name='images', shape=data_shape, dtype='float32') - im_info = fluid.layers.data( - name='im_info', shape=[1, 3], dtype='float32') + im_info = fluid.layers.data(name='im_info', shape=[3], dtype='float32') anchors, variances = fluid.layers.anchor_generator( name='anchor_generator', input=images, @@ -401,5 +400,16 @@ class TestYoloDetection(unittest.TestCase): self.assertIsNotNone(loss) +class TestBoxClip(unittest.TestCase): + def test_box_clip(self): + program = Program() + with program_guard(program): + input_box = layers.data( + name='input_box', shape=[7, 4], dtype='float32', lod_level=1) + im_info = layers.data(name='im_info', shape=[3], dtype='float32') + out = layers.box_clip(input_box, im_info) + self.assertIsNotNone(out) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_box_clip_op.py b/python/paddle/fluid/tests/unittests/test_box_clip_op.py new file mode 100644 index 00000000000..6cd3f21a6e2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_box_clip_op.py @@ -0,0 +1,70 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +import math +from op_test import OpTest +import copy + + +def box_clip(input_box, im_info, output_box): + im_w = round(im_info[1] / im_info[2]) + im_h = round(im_info[0] / im_info[2]) + output_box[:, :, 0] = np.maximum( + np.minimum(input_box[:, :, 0], im_w - 1), 0) + output_box[:, :, 1] = np.maximum( + np.minimum(input_box[:, :, 1], im_h - 1), 0) + output_box[:, :, 2] = np.maximum( + np.minimum(input_box[:, :, 2], im_w - 1), 0) + output_box[:, :, 3] = np.maximum( + np.minimum(input_box[:, :, 3], im_h - 1), 0) + + +def batch_box_clip(input_boxes, im_info, lod): + n = input_boxes.shape[0] + m = input_boxes.shape[1] + output_boxes = np.zeros((n, m, 4), dtype=np.float32) + cur_offset = 0 + for i in range(len(lod)): + box_clip(input_boxes[cur_offset:(cur_offset + lod[i]), :, :], + im_info[i, :], + output_boxes[cur_offset:(cur_offset + lod[i]), :, :]) + cur_offset += lod[i] + return output_boxes + + +class TestBoxClipOp(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_clip" + lod = [[1, 2, 3]] + input_boxes = np.random.random((6, 10, 4)) * 5 + im_info = np.array([[5, 8, 1.], [6, 6, 1.], [7, 5, 1.]]) + output_boxes = batch_box_clip(input_boxes, im_info, lod[0]) + + self.inputs = { + 'InputBox': (input_boxes.astype('float32'), lod), + 'ImInfo': im_info.astype('float32'), + } + self.outputs = {'OutputBox': output_boxes} + + +if __name__ == '__main__': + unittest.main() -- GitLab