提交 1c558ad3 编写于 作者: J jerrywgz

add gpu kernel for box clip, test=develop

上级 5246285e
......@@ -31,7 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
polygon_box_transform_op.cu)
detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
if(WITH_GPU)
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
......
......@@ -21,51 +21,58 @@ class BoxClipOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("InputBox"),
"Input(InputBox) of BoxClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of BoxClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
"Input(ImInfo) of BoxClipOp should not be null.");
auto input_box_dims = ctx->GetInputDim("InputBox");
auto input_box_dims = ctx->GetInputDim("Input");
auto im_info_dims = ctx->GetInputDim("ImInfo");
if (ctx->IsRuntime()) {
auto input_box_size = input_box_dims.size();
PADDLE_ENFORCE_EQ(input_box_dims[input_box_size - 1], 4,
"The last dimension of InputBox must be 4");
"The last dimension of Input must be 4");
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(InputBox) in BoxClipOp must be 2");
"The rank of Input(Input) in BoxClipOp must be 2");
PADDLE_ENFORCE_EQ(im_info_dims[1], 3,
"The last dimension of ImInfo must be 3");
}
ctx->ShareDim("InputBox", /*->*/ "OutputBox");
ctx->ShareLoD("InputBox", /*->*/ "OutputBox");
ctx->ShareDim("Input", /*->*/ "Output");
ctx->ShareLoD("Input", /*->*/ "Output");
}
/*
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("InputBox"));
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Input"));
return framework::OpKernelType(data_type, platform::CPUPlace());
}
*/
};
class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("InputBox",
AddInput("Input",
"(LoDTensor) "
"InputBox is a LoDTensor with shape [..., 4] holds 4 points"
"Input is a LoDTensor with shape [..., 4] holds 4 points"
"in last dimension in format [xmin, ymin, xmax, ymax]");
AddInput("ImInfo",
"(Tensor) Information for image reshape is in shape (N, 3), "
"in format (height, width, im_scale)");
AddOutput("OutputBox",
AddOutput("Output",
"(LoDTensor) "
"OutputBox is a LoDTensor with the same shape as InputBox"
"Output is a LoDTensor with the same shape as Input"
"and it is the result after clip");
AddComment(R"DOC(
This operator clips input boxes to original input images.
This operator clips input boxes to original input images.
The formula is given as follows:
$$height_out = \max(\min(height_loc, im_h), 0)$$
$$width_out = \max(\min(width_loc, im_w), 0)$$
)DOC");
}
};
......
......@@ -25,9 +25,9 @@ template <typename DeviceContext, typename T>
class BoxClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input_box = context.Input<LoDTensor>("InputBox");
auto* input_box = context.Input<LoDTensor>("Input");
auto* im_info = context.Input<LoDTensor>("ImInfo");
auto* output_box = context.Output<LoDTensor>("OutputBox");
auto* output_box = context.Output<LoDTensor>("Output");
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
output_box->mutable_data<T>(context.GetPlace());
......
......@@ -31,11 +31,24 @@ import numpy
from functools import reduce
__all__ = [
'prior_box', 'density_prior_box', 'multi_box_head', 'bipartite_match',
'target_assign', 'detection_output', 'ssd_loss', 'detection_map',
'rpn_target_assign', 'anchor_generator', 'roi_perspective_transform',
'generate_proposal_labels', 'generate_proposals', 'iou_similarity',
'box_coder', 'polygon_box_transform', 'yolov3_loss', 'box_clip'
'prior_box',
'density_prior_box',
'multi_box_head',
'bipartite_match',
'target_assign',
'detection_output',
'ssd_loss',
'detection_map',
'rpn_target_assign',
'anchor_generator',
'roi_perspective_transform',
'generate_proposal_labels',
'generate_proposals',
'iou_similarity',
'box_coder',
'polygon_box_transform',
'yolov3_loss',
'box_clip',
]
......@@ -1800,13 +1813,22 @@ def generate_proposals(scores,
return rpn_rois, rpn_roi_probs
def box_clip(input_box, im_info, inplace=False, name=None):
def box_clip(input, im_info, inplace=False, name=None):
"""
Clip the box into the size given by im_info
The formula is given as follows:
.. code-block:: text
height_out = max(min(height_loc, im_h), 0)
width_out = max(min(width_loc, im_w), 0)
Args:
input_box(variable): The input box, the last dimension is 4.
im_info(variable): The information of image with shape [N, 3].
im_info(variable): The information of image with shape [N, 3] with
layout (height, width, scale). height and width
is the input size and scale is the ratio of input
size and original size.
inplace(bool): Must use :attr:`False` if :attr:`input_box` is used in
multiple operators. If this flag is set :attr:`True`,
reuse input :attr:`input_box` to clip, which will
......@@ -1832,12 +1854,12 @@ def box_clip(input_box, im_info, inplace=False, name=None):
"""
helper = LayerHelper("box_clip", **locals())
output = helper.create_variable_for_type_inference(dtype=input_box.dtype)
inputs = {"InputBox": input_box, "ImInfo": im_info}
output = helper.create_variable_for_type_inference(dtype=input.dtype)
inputs = {"Input": input, "ImInfo": im_info}
helper.append_op(
type="box_clip",
inputs=inputs,
attrs={"inplace:": inplace},
outputs={"OutputBox": output})
outputs={"Output": output})
return output
......@@ -60,10 +60,10 @@ class TestBoxClipOp(OpTest):
output_boxes = batch_box_clip(input_boxes, im_info, lod[0])
self.inputs = {
'InputBox': (input_boxes.astype('float32'), lod),
'Input': (input_boxes.astype('float32'), lod),
'ImInfo': im_info.astype('float32'),
}
self.outputs = {'OutputBox': output_boxes}
self.outputs = {'Output': output_boxes}
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册