Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
11f1baa4
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
11f1baa4
编写于
1月 23, 2019
作者:
J
jerrywgz
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code, test=develop
上级
57e5f61e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
37 addition
and
33 deletion
+37
-33
paddle/fluid/operators/detection/box_clip_op.cc
paddle/fluid/operators/detection/box_clip_op.cc
+9
-11
paddle/fluid/operators/detection/box_clip_op.cu
paddle/fluid/operators/detection/box_clip_op.cu
+6
-6
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+22
-16
未找到文件。
paddle/fluid/operators/detection/box_clip_op.cc
浏览文件 @
11f1baa4
...
...
@@ -41,14 +41,6 @@ class BoxClipOp : public framework::OperatorWithKernel {
ctx
->
ShareDim
(
"Input"
,
/*->*/
"Output"
);
ctx
->
ShareLoD
(
"Input"
,
/*->*/
"Output"
);
}
/*
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Input"));
return framework::OpKernelType(data_type, platform::CPUPlace());
}
*/
};
class
BoxClipOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -68,11 +60,17 @@ class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment
(
R"DOC(
This operator clips input boxes to original input images.
The formula is given as follows:
For each input box,
The formula is given as follows:
$$height_out = \max(\min(height_loc, im_h), 0)$$
$$width_out = \max(\min(width_loc, im_w), 0)$$
$$xmin = \max(\min(xmin, im_w - 1), 0)$$
$$ymin = \max(\min(ymin, im_h - 1), 0)$$
$$xmax = \max(\min(xmax, im_w - 1), 0)$$
$$ymax = \max(\min(ymax, im_h - 1), 0)$$
where im_w and im_h are computed from ImInfo, the formula is given as follows:
$$im_w = \round(width / im_scale)$$
$$im_h = \round(height / im_scale)$$
)DOC"
);
}
};
...
...
paddle/fluid/operators/detection/box_clip_op.cu
浏览文件 @
11f1baa4
...
...
@@ -30,13 +30,13 @@ template <typename T, int BlockSize>
static
__global__
void
GPUBoxClip
(
const
T
*
input
,
const
size_t
*
lod
,
const
size_t
width
,
const
T
*
im_info
,
T
*
output
)
{
T
im_w
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
+
1
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
T
im_h
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
for
(
int
i
=
threadIdx
.
x
;
i
<
(
lod
[
blockIdx
.
x
+
1
]
-
lod
[
blockIdx
.
x
])
*
width
;
i
+=
BlockSize
)
{
int
idx
=
lod
[
blockIdx
.
x
]
*
width
+
i
;
T
im_w
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
+
1
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
T
im_h
=
round
(
im_info
[
blockIdx
.
x
*
ImInfoSize
]
/
im_info
[
blockIdx
.
x
*
ImInfoSize
+
2
]);
T
im_size
=
(
idx
%
2
==
0
)
?
im_w
:
im_h
;
output
[
idx
]
=
max
(
min
(
input
[
idx
],
im_size
-
1
),
T
(
0.
));
}
...
...
@@ -57,9 +57,9 @@ class GPUBoxClipKernel : public framework::OpKernel<T> {
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
stream
=
dev_ctx
.
stream
();
const
size_t
num_lod
=
lod
.
back
().
size
()
-
1
;
const
size_t
batch_size
=
lod
.
back
().
size
()
-
1
;
T
*
output_data
=
output
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
GPUBoxClip
<
T
,
512
><<<
num_lod
,
512
,
0
,
stream
>>>
(
GPUBoxClip
<
T
,
512
><<<
batch_size
,
512
,
0
,
stream
>>>
(
input
->
data
<
T
>
(),
abs_offset_lod
[
0
].
CUDAMutableData
(
dev_ctx
.
GetPlace
()),
bbox_width
,
im_info
->
data
<
T
>
(),
output_data
);
}
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
11f1baa4
...
...
@@ -1816,26 +1816,35 @@ def generate_proposals(scores,
def
box_clip
(
input
,
im_info
,
inplace
=
False
,
name
=
None
):
"""
Clip the box into the size given by im_info
The formula is given as follows:
For each input box,
The formula is given as follows:
.. code-block:: text
height_out = max(min(height_loc, im_h), 0)
width_out = max(min(width_loc, im_w), 0)
xmin = max(min(xmin, im_w - 1), 0)
ymin = max(min(ymin, im_h - 1), 0)
xmax = max(min(xmax, im_w - 1), 0)
ymax = max(min(ymax, im_h - 1), 0)
where im_w and im_h are computed from im_info:
.. code-block:: text
im_h = round(height / scale)
im_w = round(weight / scale)
Args:
input
_box
(variable): The input box, the last dimension is 4.
input(variable): The input box, the last dimension is 4.
im_info(variable): The information of image with shape [N, 3] with
layout (height, width, scale). height and width
is the input size and scale is the ratio of input
size and original size.
inplace(bool): Must use :attr:`False` if :attr:`input
_box
` is used in
inplace(bool): Must use :attr:`False` if :attr:`input` is used in
multiple operators. If this flag is set :attr:`True`,
reuse input :attr:`input
_box
` to clip, which will
change the value of tensor variable :attr:`input
_box
`
and might cause errors when :attr:`input
_box
` is used
reuse input :attr:`input` to clip, which will
change the value of tensor variable :attr:`input`
and might cause errors when :attr:`input` is used
in multiple operators. If :attr:`False`, preserve the
value pf :attr:`input
_box
` and create a new output
value pf :attr:`input` and create a new output
tensor variable whose data is copied from input x but
cliped.
name (str): The name of this layer. It is optional.
...
...
@@ -1850,16 +1859,13 @@ def box_clip(input, im_info, inplace=False, name=None):
name='data', shape=[8, 4], dtype='float32', lod_level=1)
im_info = fluid.layers.data(name='im_info', shape=[3])
out = fluid.layers.box_clip(
input
_box
=boxes, im_info=im_info, inplace=True)
input=boxes, im_info=im_info, inplace=True)
"""
helper
=
LayerHelper
(
"box_clip"
,
**
locals
())
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
output
=
x
if
inplace
else
helper
.
create_variable_for_type_inference
(
\
dtype
=
input
.
dtype
)
inputs
=
{
"Input"
:
input
,
"ImInfo"
:
im_info
}
helper
.
append_op
(
type
=
"box_clip"
,
inputs
=
inputs
,
attrs
=
{
"inplace:"
:
inplace
},
outputs
=
{
"Output"
:
output
})
helper
.
append_op
(
type
=
"box_clip"
,
inputs
=
inputs
,
outputs
=
{
"Output"
:
output
})
return
output
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录