提交 1c558ad3 编写于 作者: J jerrywgz

add gpu kernel for box clip, test=develop

上级 5246285e
...@@ -31,7 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc ...@@ -31,7 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
polygon_box_transform_op.cu) polygon_box_transform_op.cu)
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
if(WITH_GPU) if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
...@@ -21,51 +21,58 @@ class BoxClipOp : public framework::OperatorWithKernel { ...@@ -21,51 +21,58 @@ class BoxClipOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("InputBox"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(InputBox) of BoxClipOp should not be null."); "Input(Input) of BoxClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"), PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
"Input(ImInfo) of BoxClipOp should not be null."); "Input(ImInfo) of BoxClipOp should not be null.");
auto input_box_dims = ctx->GetInputDim("InputBox"); auto input_box_dims = ctx->GetInputDim("Input");
auto im_info_dims = ctx->GetInputDim("ImInfo"); auto im_info_dims = ctx->GetInputDim("ImInfo");
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
auto input_box_size = input_box_dims.size(); auto input_box_size = input_box_dims.size();
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4, PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
"The last dimension of InputBox must be 4"); "The last dimension of Input must be 4");
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2, PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(InputBox) in BoxClipOp must be 2"); "The rank of Input(Input) in BoxClipOp must be 2");
PADDLE_ENFORCE_EQ(im_info_dims[1], 3, PADDLE_ENFORCE_EQ(im_info_dims[1], 3,
"The last dimension of ImInfo must be 3"); "The last dimension of ImInfo must be 3");
} }
ctx->ShareDim("InputBox", /*->*/ "OutputBox"); ctx->ShareDim("Input", /*->*/ "Output");
ctx->ShareLoD("InputBox", /*->*/ "OutputBox"); ctx->ShareLoD("Input", /*->*/ "Output");
} }
/*
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("InputBox")); auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Input"));
return framework::OpKernelType(data_type, platform::CPUPlace()); return framework::OpKernelType(data_type, platform::CPUPlace());
} }
*/
}; };
class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker { class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("InputBox", AddInput("Input",
"(LoDTensor) " "(LoDTensor) "
"InputBox is a LoDTensor with shape [..., 4] holds 4 points" "Input is a LoDTensor with shape [..., 4] holds 4 points"
"in last dimension in format [xmin, ymin, xmax, ymax]"); "in last dimension in format [xmin, ymin, xmax, ymax]");
AddInput("ImInfo", AddInput("ImInfo",
"(Tensor) Information for image reshape is in shape (N, 3), " "(Tensor) Information for image reshape is in shape (N, 3), "
"in format (height, width, im_scale)"); "in format (height, width, im_scale)");
AddOutput("OutputBox", AddOutput("Output",
"(LoDTensor) " "(LoDTensor) "
"OutputBox is a LoDTensor with the same shape as InputBox" "Output is a LoDTensor with the same shape as Input"
"and it is the result after clip"); "and it is the result after clip");
AddComment(R"DOC( AddComment(R"DOC(
This operator clips input boxes to original input images. This operator clips input boxes to original input images.
The formula is given as follows:
$$height_out = \max(\min(height_loc, im_h), 0)$$
$$width_out = \max(\min(width_loc, im_w), 0)$$
)DOC"); )DOC");
} }
}; };
......
...@@ -25,9 +25,9 @@ template <typename DeviceContext, typename T> ...@@ -25,9 +25,9 @@ template <typename DeviceContext, typename T>
class BoxClipKernel : public framework::OpKernel<T> { class BoxClipKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input_box = context.Input<LoDTensor>("InputBox"); auto* input_box = context.Input<LoDTensor>("Input");
auto* im_info = context.Input<LoDTensor>("ImInfo"); auto* im_info = context.Input<LoDTensor>("ImInfo");
auto* output_box = context.Output<LoDTensor>("OutputBox"); auto* output_box = context.Output<LoDTensor>("Output");
auto& dev_ctx = auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>(); context.template device_context<platform::CPUDeviceContext>();
output_box->mutable_data<T>(context.GetPlace()); output_box->mutable_data<T>(context.GetPlace());
......
...@@ -31,11 +31,24 @@ import numpy ...@@ -31,11 +31,24 @@ import numpy
from functools import reduce from functools import reduce
__all__ = [ __all__ = [
'prior_box', 'density_prior_box', 'multi_box_head', 'bipartite_match', 'prior_box',
'target_assign', 'detection_output', 'ssd_loss', 'detection_map', 'density_prior_box',
'rpn_target_assign', 'anchor_generator', 'roi_perspective_transform', 'multi_box_head',
'generate_proposal_labels', 'generate_proposals', 'iou_similarity', 'bipartite_match',
'box_coder', 'polygon_box_transform', 'yolov3_loss', 'box_clip' 'target_assign',
'detection_output',
'ssd_loss',
'detection_map',
'rpn_target_assign',
'anchor_generator',
'roi_perspective_transform',
'generate_proposal_labels',
'generate_proposals',
'iou_similarity',
'box_coder',
'polygon_box_transform',
'yolov3_loss',
'box_clip',
] ]
...@@ -1800,13 +1813,22 @@ def generate_proposals(scores, ...@@ -1800,13 +1813,22 @@ def generate_proposals(scores,
return rpn_rois, rpn_roi_probs return rpn_rois, rpn_roi_probs
def box_clip(input_box, im_info, inplace=False, name=None): def box_clip(input, im_info, inplace=False, name=None):
""" """
Clip the box into the size given by im_info Clip the box into the size given by im_info
The formula is given as follows:
.. code-block:: text
height_out = max(min(height_loc, im_h), 0)
width_out = max(min(width_loc, im_w), 0)
Args: Args:
input_box(variable): The input box, the last dimension is 4. input_box(variable): The input box, the last dimension is 4.
im_info(variable): The information of image with shape [N, 3]. im_info(variable): The information of image with shape [N, 3] with
layout (height, width, scale). height and width
is the input size and scale is the ratio of input
size and original size.
inplace(bool): Must use :attr:`False` if :attr:`input_box` is used in inplace(bool): Must use :attr:`False` if :attr:`input_box` is used in
multiple operators. If this flag is set :attr:`True`, multiple operators. If this flag is set :attr:`True`,
reuse input :attr:`input_box` to clip, which will reuse input :attr:`input_box` to clip, which will
...@@ -1832,12 +1854,12 @@ def box_clip(input_box, im_info, inplace=False, name=None): ...@@ -1832,12 +1854,12 @@ def box_clip(input_box, im_info, inplace=False, name=None):
""" """
helper = LayerHelper("box_clip", **locals()) helper = LayerHelper("box_clip", **locals())
output = helper.create_variable_for_type_inference(dtype=input_box.dtype) output = helper.create_variable_for_type_inference(dtype=input.dtype)
inputs = {"InputBox": input_box, "ImInfo": im_info} inputs = {"Input": input, "ImInfo": im_info}
helper.append_op( helper.append_op(
type="box_clip", type="box_clip",
inputs=inputs, inputs=inputs,
attrs={"inplace:": inplace}, attrs={"inplace:": inplace},
outputs={"OutputBox": output}) outputs={"Output": output})
return output return output
...@@ -60,10 +60,10 @@ class TestBoxClipOp(OpTest): ...@@ -60,10 +60,10 @@ class TestBoxClipOp(OpTest):
output_boxes = batch_box_clip(input_boxes, im_info, lod[0]) output_boxes = batch_box_clip(input_boxes, im_info, lod[0])
self.inputs = { self.inputs = {
'InputBox': (input_boxes.astype('float32'), lod), 'Input': (input_boxes.astype('float32'), lod),
'ImInfo': im_info.astype('float32'), 'ImInfo': im_info.astype('float32'),
} }
self.outputs = {'OutputBox': output_boxes} self.outputs = {'Output': output_boxes}
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册