Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1743d1a5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1743d1a5
编写于
1月 31, 2019
作者:
J
jerrywgz
提交者:
GitHub
1月 31, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15356 from jerrywgz/add_clip_op
Add box clip op
上级
43a67a26
4f18a9b8
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
368 addition
and
0 deletion
+368
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+1
-0
paddle/fluid/operators/detection/bbox_util.h
paddle/fluid/operators/detection/bbox_util.h
+24
-0
paddle/fluid/operators/detection/box_clip_op.cc
paddle/fluid/operators/detection/box_clip_op.cc
+86
-0
paddle/fluid/operators/detection/box_clip_op.cu
paddle/fluid/operators/detection/box_clip_op.cu
+74
-0
paddle/fluid/operators/detection/box_clip_op.h
paddle/fluid/operators/detection/box_clip_op.h
+50
-0
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+51
-0
python/paddle/fluid/tests/test_detection.py
python/paddle/fluid/tests/test_detection.py
+11
-0
python/paddle/fluid/tests/unittests/test_box_clip_op.py
python/paddle/fluid/tests/unittests/test_box_clip_op.py
+70
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
1743d1a5
...
...
@@ -325,6 +325,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name', 'axis'], varargs=None, keywords=None, defaults=('encode_center_size', True, None, 0))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_clip ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.multiclass_nms ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
...
...
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
1743d1a5
...
...
@@ -31,6 +31,7 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
polygon_box_transform_op.cu
)
detection_library
(
rpn_target_assign_op SRCS rpn_target_assign_op.cc
)
detection_library
(
generate_proposal_labels_op SRCS generate_proposal_labels_op.cc
)
detection_library
(
box_clip_op SRCS box_clip_op.cc box_clip_op.cu
)
detection_library
(
yolov3_loss_op SRCS yolov3_loss_op.cc
)
if
(
WITH_GPU
)
...
...
paddle/fluid/operators/detection/bbox_util.h
浏览文件 @
1743d1a5
...
...
@@ -99,5 +99,29 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
}
}
template
<
class
T
>
void
ClipTiledBoxes
(
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
im_info
,
const
framework
::
Tensor
&
input_boxes
,
framework
::
Tensor
*
out
)
{
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
const
T
*
input_boxes_data
=
input_boxes
.
data
<
T
>
();
T
zero
(
0
);
T
im_w
=
round
(
im_info_data
[
1
]
/
im_info_data
[
2
]);
T
im_h
=
round
(
im_info_data
[
0
]
/
im_info_data
[
2
]);
for
(
int64_t
i
=
0
;
i
<
input_boxes
.
numel
();
++
i
)
{
if
(
i
%
4
==
0
)
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_w
-
1
),
zero
);
}
else
if
(
i
%
4
==
1
)
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_h
-
1
),
zero
);
}
else
if
(
i
%
4
==
2
)
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_w
-
1
),
zero
);
}
else
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_h
-
1
),
zero
);
}
}
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detection/box_clip_op.cc
0 → 100644
浏览文件 @
1743d1a5
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/box_clip_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
BoxClipOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of BoxClipOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ImInfo"
),
"Input(ImInfo) of BoxClipOp should not be null."
);
auto
input_box_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
im_info_dims
=
ctx
->
GetInputDim
(
"ImInfo"
);
if
(
ctx
->
IsRuntime
())
{
auto
input_box_size
=
input_box_dims
.
size
();
PADDLE_ENFORCE_EQ
(
input_box_dims
[
input_box_size
-
1
],
4
,
"The last dimension of Input must be 4"
);
PADDLE_ENFORCE_EQ
(
im_info_dims
.
size
(),
2
,
"The rank of Input(Input) in BoxClipOp must be 2"
);
PADDLE_ENFORCE_EQ
(
im_info_dims
[
1
],
3
,
"The last dimension of ImInfo must be 3"
);
}
ctx
->
ShareDim
(
"Input"
,
/*->*/
"Output"
);
ctx
->
ShareLoD
(
"Input"
,
/*->*/
"Output"
);
}
};
class
BoxClipOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"(LoDTensor) "
"Input is a LoDTensor with shape [..., 4] holds 4 points"
"in last dimension in format [xmin, ymin, xmax, ymax]"
);
AddInput
(
"ImInfo"
,
"(Tensor) Information for image reshape is in shape (N, 3), "
"in format (height, width, im_scale)"
);
AddOutput
(
"Output"
,
"(LoDTensor) "
"Output is a LoDTensor with the same shape as Input"
"and it is the result after clip"
);
AddComment
(
R"DOC(
This operator clips input boxes to original input images.
For each input box, The formula is given as follows:
$$xmin = \max(\min(xmin, im_w - 1), 0)$$
$$ymin = \max(\min(ymin, im_h - 1), 0)$$
$$xmax = \max(\min(xmax, im_w - 1), 0)$$
$$ymax = \max(\min(ymax, im_h - 1), 0)$$
where im_w and im_h are computed from ImInfo, the formula is given as follows:
$$im_w = \round(width / im_scale)$$
$$im_h = \round(height / im_scale)$$
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
box_clip
,
ops
::
BoxClipOp
,
ops
::
BoxClipOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
box_clip
,
ops
::
BoxClipKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
BoxClipKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/detection/box_clip_op.cu
0 → 100644
浏览文件 @
1743d1a5
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/box_clip_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTenso
=
framework
::
LoDTensor
;
static
constexpr
int
ImInfoSize
=
3
;
template
<
typename
T
,
int
BlockSize
>
static
__global__
void
GPUBoxClip
(
const
T
*
input
,
const
size_t
*
lod
,
const
size_t
width
,
const
T
*
im_info
,
T
*
output
)
{
T
im_w
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
+
1
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
T
im_h
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
for
(
int
i
=
threadIdx
.
x
;
i
<
(
lod
[
blockIdx
.
x
+
1
]
-
lod
[
blockIdx
.
x
])
*
width
;
i
+=
BlockSize
)
{
int
idx
=
lod
[
blockIdx
.
x
]
*
width
+
i
;
T
im_size
=
(
idx
%
2
==
0
)
?
im_w
:
im_h
;
output
[
idx
]
=
max
(
min
(
input
[
idx
],
im_size
-
1
),
T
(
0.
));
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GPUBoxClipKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
im_info
=
context
.
Input
<
Tensor
>
(
"ImInfo"
);
auto
*
output
=
context
.
Output
<
LoDTensor
>
(
"Output"
);
const
int64_t
num
=
input
->
dims
()[
0
];
const
int64_t
bbox_width
=
input
->
numel
()
/
num
;
auto
lod
=
input
->
lod
();
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
stream
=
dev_ctx
.
stream
();
const
size_t
batch_size
=
lod
.
back
().
size
()
-
1
;
T
*
output_data
=
output
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
GPUBoxClip
<
T
,
512
><<<
batch_size
,
512
,
0
,
stream
>>>
(
input
->
data
<
T
>
(),
abs_offset_lod
[
0
].
CUDAMutableData
(
dev_ctx
.
GetPlace
()),
bbox_width
,
im_info
->
data
<
T
>
(),
output_data
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
box_clip
,
ops
::
GPUBoxClipKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
GPUBoxClipKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/detection/box_clip_op.h
0 → 100644
浏览文件 @
1743d1a5
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
BoxClipKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input_box
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
im_info
=
context
.
Input
<
LoDTensor
>
(
"ImInfo"
);
auto
*
output_box
=
context
.
Output
<
LoDTensor
>
(
"Output"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CPUDeviceContext
>();
output_box
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
input_box
->
lod
().
size
())
{
PADDLE_ENFORCE_EQ
(
input_box
->
lod
().
size
(),
1UL
,
"Only support 1 level of LoD."
);
}
auto
box_lod
=
input_box
->
lod
().
back
();
int64_t
n
=
static_cast
<
int64_t
>
(
box_lod
.
size
()
-
1
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
Tensor
im_info_slice
=
im_info
->
Slice
(
i
,
i
+
1
);
Tensor
box_slice
=
input_box
->
Slice
(
box_lod
[
i
],
box_lod
[
i
+
1
]);
Tensor
output_slice
=
output_box
->
Slice
(
box_lod
[
i
],
box_lod
[
i
+
1
]);
ClipTiledBoxes
<
T
>
(
dev_ctx
,
im_info_slice
,
box_slice
,
&
output_slice
);
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/detection.py
浏览文件 @
1743d1a5
...
...
@@ -49,6 +49,7 @@ __all__ = [
'box_coder'
,
'polygon_box_transform'
,
'yolov3_loss'
,
'box_clip'
,
'multiclass_nms'
,
]
...
...
@@ -2055,6 +2056,54 @@ def generate_proposals(scores,
return
rpn_rois
,
rpn_roi_probs
def
box_clip
(
input
,
im_info
,
name
=
None
):
"""
Clip the box into the size given by im_info
For each input box, The formula is given as follows:
.. code-block:: text
xmin = max(min(xmin, im_w - 1), 0)
ymin = max(min(ymin, im_h - 1), 0)
xmax = max(min(xmax, im_w - 1), 0)
ymax = max(min(ymax, im_h - 1), 0)
where im_w and im_h are computed from im_info:
.. code-block:: text
im_h = round(height / scale)
im_w = round(weight / scale)
Args:
input(variable): The input box, the last dimension is 4.
im_info(variable): The information of image with shape [N, 3] with
layout (height, width, scale). height and width
is the input size and scale is the ratio of input
size and original size.
name (str): The name of this layer. It is optional.
Returns:
Variable: The cliped tensor variable.
Examples:
.. code-block:: python
boxes = fluid.layers.data(
name='data', shape=[8, 4], dtype='float32', lod_level=1)
im_info = fluid.layers.data(name='im_info', shape=[3])
out = fluid.layers.box_clip(
input=boxes, im_info=im_info, inplace=True)
"""
helper
=
LayerHelper
(
"box_clip"
,
**
locals
())
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
inputs
=
{
"Input"
:
input
,
"ImInfo"
:
im_info
}
helper
.
append_op
(
type
=
"box_clip"
,
inputs
=
inputs
,
outputs
=
{
"Output"
:
output
})
return
output
def
multiclass_nms
(
bboxes
,
scores
,
score_threshold
,
...
...
@@ -2132,9 +2181,11 @@ def multiclass_nms(bboxes,
(After version 1.3, when no boxes detected, the lod is changed
from {0} to {1})
Examples:
.. code-block:: python
boxes = fluid.layers.data(name='bboxes', shape=[81, 4],
dtype='float32', lod_level=1)
scores = fluid.layers.data(name='scores', shape=[81],
...
...
python/paddle/fluid/tests/test_detection.py
浏览文件 @
1743d1a5
...
...
@@ -482,6 +482,17 @@ class TestYoloDetection(unittest.TestCase):
self
.
assertIsNotNone
(
loss
)
class
TestBoxClip
(
unittest
.
TestCase
):
def
test_box_clip
(
self
):
program
=
Program
()
with
program_guard
(
program
):
input_box
=
layers
.
data
(
name
=
'input_box'
,
shape
=
[
7
,
4
],
dtype
=
'float32'
,
lod_level
=
1
)
im_info
=
layers
.
data
(
name
=
'im_info'
,
shape
=
[
3
],
dtype
=
'float32'
)
out
=
layers
.
box_clip
(
input_box
,
im_info
)
self
.
assertIsNotNone
(
out
)
class
TestMulticlassNMS
(
unittest
.
TestCase
):
def
test_multiclass_nms
(
self
):
program
=
Program
()
...
...
python/paddle/fluid/tests/unittests/test_box_clip_op.py
0 → 100644
浏览文件 @
1743d1a5
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
import
math
from
op_test
import
OpTest
import
copy
def
box_clip
(
input_box
,
im_info
,
output_box
):
im_w
=
round
(
im_info
[
1
]
/
im_info
[
2
])
im_h
=
round
(
im_info
[
0
]
/
im_info
[
2
])
output_box
[:,
:,
0
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
0
],
im_w
-
1
),
0
)
output_box
[:,
:,
1
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
1
],
im_h
-
1
),
0
)
output_box
[:,
:,
2
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
2
],
im_w
-
1
),
0
)
output_box
[:,
:,
3
]
=
np
.
maximum
(
np
.
minimum
(
input_box
[:,
:,
3
],
im_h
-
1
),
0
)
def
batch_box_clip
(
input_boxes
,
im_info
,
lod
):
n
=
input_boxes
.
shape
[
0
]
m
=
input_boxes
.
shape
[
1
]
output_boxes
=
np
.
zeros
((
n
,
m
,
4
),
dtype
=
np
.
float32
)
cur_offset
=
0
for
i
in
range
(
len
(
lod
)):
box_clip
(
input_boxes
[
cur_offset
:(
cur_offset
+
lod
[
i
]),
:,
:],
im_info
[
i
,
:],
output_boxes
[
cur_offset
:(
cur_offset
+
lod
[
i
]),
:,
:])
cur_offset
+=
lod
[
i
]
return
output_boxes
class
TestBoxClipOp
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_clip"
lod
=
[[
1
,
2
,
3
]]
input_boxes
=
np
.
random
.
random
((
6
,
10
,
4
))
*
5
im_info
=
np
.
array
([[
5
,
8
,
1.
],
[
6
,
6
,
1.
],
[
7
,
5
,
1.
]])
output_boxes
=
batch_box_clip
(
input_boxes
,
im_info
,
lod
[
0
])
self
.
inputs
=
{
'Input'
:
(
input_boxes
.
astype
(
'float32'
),
lod
),
'ImInfo'
:
im_info
.
astype
(
'float32'
),
}
self
.
outputs
=
{
'Output'
:
output_boxes
}
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录