提交 481d8bce 编写于 作者: J jerrywgz

add box clip op

上级 3f815e07
......@@ -318,6 +318,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.box_clip ArgSpec(args=['input_box', 'im_info', 'inplace', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
......@@ -494,6 +495,7 @@ paddle.reader.buffered ArgSpec(args=['reader', 'size'], varargs=None, keywords=N
paddle.reader.compose ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None)
paddle.reader.chain ArgSpec(args=[], varargs='readers', keywords=None, defaults=None)
paddle.reader.shuffle ArgSpec(args=['reader', 'buf_size'], varargs=None, keywords=None, defaults=None)
paddle.reader.ComposeNotAligned.__init__
paddle.reader.firstn ArgSpec(args=['reader', 'n'], varargs=None, keywords=None, defaults=None)
paddle.reader.xmap_readers ArgSpec(args=['mapper', 'reader', 'process_num', 'buffer_size', 'order'], varargs=None, keywords=None, defaults=(False,))
paddle.reader.PipeReader.__init__ ArgSpec(args=['self', 'command', 'bufsize', 'file_type'], varargs=None, keywords=None, defaults=(8192, 'plain'))
......
......@@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
polygon_box_transform_op.cu)
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc)
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
......@@ -93,5 +93,29 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
}
}
template <class T>
void ClipTiledBoxes(const platform::DeviceContext& ctx,
const framework::Tensor& im_info,
const framework::Tensor& input_boxes,
framework::Tensor* out) {
T* out_data = out->mutable_data<T>(ctx.GetPlace());
const T* im_info_data = im_info.data<T>();
const T* input_boxes_data = input_boxes.data<T>();
T zero(0);
T im_w = round(im_info_data[1] / im_info_data[2]);
T im_h = round(im_info_data[0] / im_info_data[2]);
for (int64_t i = 0; i < input_boxes.numel(); ++i) {
if (i % 4 == 0) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
} else if (i % 4 == 1) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero);
} else if (i % 4 == 2) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
} else {
out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero);
}
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_clip_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class BoxClipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("InputBox"),
"Input(InputBox) of BoxClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
"Input(ImInfo) of BoxClipOp should not be null.");
auto input_box_dims = ctx->GetInputDim("InputBox");
auto im_info_dims = ctx->GetInputDim("ImInfo");
if (ctx->IsRuntime()) {
auto input_box_size = input_box_dims.size();
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
"The last dimension of InputBox must be 4");
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(InputBox) in BoxClipOp must be 2");
PADDLE_ENFORCE_EQ(im_info_dims[1], 2,
"The last dimension of ImInfo must be 2");
}
ctx->ShareDim("InputBox", /*->*/ "OutputBox");
ctx->ShareLoD("InputBox", /*->*/ "OutputBox");
}
};
class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("InputBox",
"(LoDTensor) "
"InputBox is a LoDTensor with shape [..., 4] holds 4 points"
"in last dimension in format [xmin, ymin, xmax, ymax]");
AddInput("ImInfo",
"(Tensor) Information for image reshape is in shape (N, 2), "
"in format (height, width)");
AddOutput("OutputBox",
"(LoDTensor) "
"OutputBox is a LoDTensor with the same shape as InputBox"
"and it is the result after clip");
AddComment(R"DOC(
This operator clips input boxes to original input images.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(box_clip, ops::BoxClipOp, ops::BoxClipOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
box_clip, ops::BoxClipKernel<paddle::platform::CPUDeviceContext, float>,
ops::BoxClipKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class BoxClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_box = context.Input<LoDTensor>("InputBox");
auto* im_info = context.Input<LoDTensor>("ImInfo");
auto* output_box = context.Output<LoDTensor>("OutputBox");
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
output_box->mutable_data<T>(context.GetPlace());
if (input_box->lod().size()) {
PADDLE_ENFORCE_EQ(input_box->lod().size(), 1UL,
"Only support 1 level of LoD.");
}
auto box_lod = input_box->lod().back();
int64_t n = static_cast<int64_t>(box_lod.size() - 1);
for (int i = 0; i < n; ++i) {
Tensor im_info_slice = im_info->Slice(i, i + 1);
Tensor box_slice = input_box->Slice(box_lod[i], box_lod[i + 1]);
Tensor output_slice = output_box->Slice(box_lod[i], box_lod[i + 1]);
ClipTiledBoxes<T>(dev_ctx, im_info_slice, box_slice, &output_slice);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -31,23 +31,11 @@ import numpy
from functools import reduce
__all__ = [
'prior_box',
'density_prior_box',
'multi_box_head',
'bipartite_match',
'target_assign',
'detection_output',
'ssd_loss',
'detection_map',
'rpn_target_assign',
'anchor_generator',
'roi_perspective_transform',
'generate_proposal_labels',
'generate_proposals',
'iou_similarity',
'box_coder',
'polygon_box_transform',
'yolov3_loss',
'prior_box', 'density_prior_box', 'multi_box_head', 'bipartite_match',
'target_assign', 'detection_output', 'ssd_loss', 'detection_map',
'rpn_target_assign', 'anchor_generator', 'roi_perspective_transform',
'generate_proposal_labels', 'generate_proposals', 'iou_similarity',
'box_coder', 'polygon_box_transform', 'yolov3_loss', 'box_clip'
]
......@@ -1810,3 +1798,47 @@ def generate_proposals(scores,
rpn_roi_probs.stop_gradient = True
return rpn_rois, rpn_roi_probs
def box_clip(input_box, im_info, inplace=False, name=None):
"""
Clip the box into the size given by im_info
Args:
input_box(variable): The input box, the last dimension is 4.
im_info(variable): The information of image with shape [N, 3].
inplace(bool): Must use :attr:`False` if :attr:`input_box` is used in
multiple operators. If this flag is set :attr:`True`,
reuse input :attr:`input_box` to clip, which will
change the value of tensor variable :attr:`input_box`
and might cause errors when :attr:`input_box` is used
in multiple operators. If :attr:`False`, preserve the
value pf :attr:`input_box` and create a new output
tensor variable whose data is copied from input x but
cliped.
name (str): The name of this layer. It is optional.
Returns:
Variable: The cliped tensor variable.
Examples:
.. code-block:: python
boxes = fluid.layers.data(
name='data', shape=[8, 4], dtype='float32', lod_level=1)
im_info = fluid.layers.data(name='im_info', shape=[3])
out = fluid.layers.box_clip(
input_box=boxes, im_info=im_info, inplace=True)
"""
inputs = {"InputBox": input_box, "ImInfo": im_info}
helper = LayerHelper("box_clip", **locals())
output = helper.create_variable_for_type_inference(dtype=input_box.dtype)
helper.append_op(
type="box_clip",
inputs=inputs,
attrs={"inplace:": inplace},
outputs={"OutputBox": output})
return output
......@@ -354,8 +354,7 @@ class TestGenerateProposals(unittest.TestCase):
data_shape = [20, 64, 64]
images = fluid.layers.data(
name='images', shape=data_shape, dtype='float32')
im_info = fluid.layers.data(
name='im_info', shape=[1, 3], dtype='float32')
im_info = fluid.layers.data(name='im_info', shape=[3], dtype='float32')
anchors, variances = fluid.layers.anchor_generator(
name='anchor_generator',
input=images,
......@@ -401,5 +400,16 @@ class TestYoloDetection(unittest.TestCase):
self.assertIsNotNone(loss)
class TestBoxClip(unittest.TestCase):
def test_box_clip(self):
program = Program()
with program_guard(program):
input_box = layers.data(
name='input_box', shape=[7, 4], dtype='float32', lod_level=1)
im_info = layers.data(name='im_info', shape=[3], dtype='float32')
out = layers.box_clip(input_box, im_info)
self.assertIsNotNone(out)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
import math
from op_test import OpTest
import copy
def box_clip(input_box, im_info, output_box):
im_w = round(im_info[1] / im_info[2])
im_h = round(im_info[0] / im_info[2])
output_box[:, :, 0] = np.maximum(
np.minimum(input_box[:, :, 0], im_w - 1), 0)
output_box[:, :, 1] = np.maximum(
np.minimum(input_box[:, :, 1], im_h - 1), 0)
output_box[:, :, 2] = np.maximum(
np.minimum(input_box[:, :, 2], im_w - 1), 0)
output_box[:, :, 3] = np.maximum(
np.minimum(input_box[:, :, 3], im_h - 1), 0)
def batch_box_clip(input_boxes, im_info, lod):
n = input_boxes.shape[0]
m = input_boxes.shape[1]
output_boxes = np.zeros((n, m, 4), dtype=np.float32)
cur_offset = 0
for i in range(len(lod)):
box_clip(input_boxes[cur_offset:(cur_offset + lod[i]), :, :],
im_info[i, :],
output_boxes[cur_offset:(cur_offset + lod[i]), :, :])
cur_offset += lod[i]
return output_boxes
class TestBoxClipOp(OpTest):
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "box_clip"
lod = [[1, 2, 3]]
input_boxes = np.random.random((6, 10, 4)) * 5
im_info = np.array([[5, 8, 1.], [6, 6, 1.], [7, 5, 1.]])
output_boxes = batch_box_clip(input_boxes, im_info, lod[0])
self.inputs = {
'InputBox': (input_boxes.astype('float32'), lod),
'ImInfo': im_info.astype('float32'),
}
self.outputs = {'OutputBox': output_boxes}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册