Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
be2d9dc2
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
be2d9dc2
编写于
7月 11, 2018
作者:
B
baiyf
提交者:
qingqing01
7月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add prior_box output order control (#12032)
* Add flag to set prior_box output order.
上级
8e4b225f
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
161 addition
and
46 deletion
+161
-46
paddle/fluid/operators/detection/prior_box_op.cc
paddle/fluid/operators/detection/prior_box_op.cc
+7
-0
paddle/fluid/operators/detection/prior_box_op.cu
paddle/fluid/operators/detection/prior_box_op.cu
+26
-10
paddle/fluid/operators/detection/prior_box_op.h
paddle/fluid/operators/detection/prior_box_op.h
+50
-15
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+18
-4
python/paddle/fluid/tests/unittests/test_prior_box_op.py
python/paddle/fluid/tests/unittests/test_prior_box_op.py
+60
-17
未找到文件。
paddle/fluid/operators/detection/prior_box_op.cc
浏览文件 @
be2d9dc2
...
...
@@ -149,6 +149,13 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"(float) "
"Prior boxes center offset."
)
.
SetDefault
(
0.5
);
AddAttr
<
bool
>
(
"min_max_aspect_ratios_order"
,
"(bool) If set True, the output prior box is in order of"
"[min, max, aspect_ratios], which is consistent with Caffe."
"Please note, this order affects the weights order of convolution layer"
"followed by and does not affect the final detection results."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Prior box operator
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
...
...
paddle/fluid/operators/detection/prior_box_op.cu
浏览文件 @
be2d9dc2
...
...
@@ -28,8 +28,8 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
const
int
im_width
,
const
int
as_num
,
const
T
offset
,
const
T
step_width
,
const
T
step_height
,
const
T
*
min_sizes
,
const
T
*
max_sizes
,
const
int
min_num
,
bool
is_clip
)
{
const
T
*
max_sizes
,
const
int
min_num
,
bool
is_clip
,
bool
min_max_aspect_ratios_order
)
{
int
num_priors
=
max_sizes
?
as_num
*
min_num
+
min_num
:
as_num
*
min_num
;
int
box_num
=
height
*
width
*
num_priors
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
box_num
;
...
...
@@ -44,6 +44,7 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
T
min_size
=
min_sizes
[
m
];
if
(
max_sizes
)
{
int
s
=
p
%
(
as_num
+
1
);
if
(
!
min_max_aspect_ratios_order
)
{
if
(
s
<
as_num
)
{
T
ar
=
aspect_ratios
[
s
];
bw
=
min_size
*
sqrt
(
ar
)
/
2.
;
...
...
@@ -53,6 +54,19 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
bw
=
sqrt
(
min_size
*
max_size
)
/
2.
;
bh
=
bw
;
}
}
else
{
if
(
s
==
0
)
{
bw
=
bh
=
min_size
/
2.
;
}
else
if
(
s
==
1
)
{
T
max_size
=
max_sizes
[
m
];
bw
=
sqrt
(
min_size
*
max_size
)
/
2.
;
bh
=
bw
;
}
else
{
T
ar
=
aspect_ratios
[
s
-
1
];
bw
=
min_size
*
sqrt
(
ar
)
/
2.
;
bh
=
min_size
/
sqrt
(
ar
)
/
2.
;
}
}
}
else
{
int
s
=
p
%
as_num
;
T
ar
=
aspect_ratios
[
s
];
...
...
@@ -94,6 +108,8 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
auto
variances
=
ctx
.
Attr
<
std
::
vector
<
float
>>
(
"variances"
);
auto
flip
=
ctx
.
Attr
<
bool
>
(
"flip"
);
auto
clip
=
ctx
.
Attr
<
bool
>
(
"clip"
);
auto
min_max_aspect_ratios_order
=
ctx
.
Attr
<
bool
>
(
"min_max_aspect_ratios_order"
);
std
::
vector
<
float
>
aspect_ratios
;
ExpandAspectRatios
(
input_aspect_ratio
,
flip
,
&
aspect_ratios
);
...
...
@@ -149,7 +165,7 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
GenPriorBox
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
boxes
->
data
<
T
>
(),
r
.
data
<
T
>
(),
height
,
width
,
im_height
,
im_width
,
aspect_ratios
.
size
(),
offset
,
step_width
,
step_height
,
min
.
data
<
T
>
(),
max_data
,
min_num
,
clip
);
max_data
,
min_num
,
clip
,
min_max_aspect_ratios_order
);
framework
::
Tensor
v
;
framework
::
TensorFromVector
(
variances
,
ctx
.
device_context
(),
&
v
);
...
...
paddle/fluid/operators/detection/prior_box_op.h
浏览文件 @
be2d9dc2
...
...
@@ -68,6 +68,8 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
auto
variances
=
ctx
.
Attr
<
std
::
vector
<
float
>>
(
"variances"
);
auto
flip
=
ctx
.
Attr
<
bool
>
(
"flip"
);
auto
clip
=
ctx
.
Attr
<
bool
>
(
"clip"
);
auto
min_max_aspect_ratios_order
=
ctx
.
Attr
<
bool
>
(
"min_max_aspect_ratios_order"
);
std
::
vector
<
float
>
aspect_ratios
;
ExpandAspectRatios
(
input_aspect_ratio
,
flip
,
&
aspect_ratios
);
...
...
@@ -108,6 +110,38 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
int
idx
=
0
;
for
(
size_t
s
=
0
;
s
<
min_sizes
.
size
();
++
s
)
{
auto
min_size
=
min_sizes
[
s
];
if
(
min_max_aspect_ratios_order
)
{
box_width
=
box_height
=
min_size
/
2.
;
e_boxes
(
h
,
w
,
idx
,
0
)
=
(
center_x
-
box_width
)
/
img_width
;
e_boxes
(
h
,
w
,
idx
,
1
)
=
(
center_y
-
box_height
)
/
img_height
;
e_boxes
(
h
,
w
,
idx
,
2
)
=
(
center_x
+
box_width
)
/
img_width
;
e_boxes
(
h
,
w
,
idx
,
3
)
=
(
center_y
+
box_height
)
/
img_height
;
idx
++
;
if
(
max_sizes
.
size
()
>
0
)
{
auto
max_size
=
max_sizes
[
s
];
// square prior with size sqrt(minSize * maxSize)
box_width
=
box_height
=
sqrt
(
min_size
*
max_size
)
/
2.
;
e_boxes
(
h
,
w
,
idx
,
0
)
=
(
center_x
-
box_width
)
/
img_width
;
e_boxes
(
h
,
w
,
idx
,
1
)
=
(
center_y
-
box_height
)
/
img_height
;
e_boxes
(
h
,
w
,
idx
,
2
)
=
(
center_x
+
box_width
)
/
img_width
;
e_boxes
(
h
,
w
,
idx
,
3
)
=
(
center_y
+
box_height
)
/
img_height
;
idx
++
;
}
// priors with different aspect ratios
for
(
size_t
r
=
0
;
r
<
aspect_ratios
.
size
();
++
r
)
{
float
ar
=
aspect_ratios
[
r
];
if
(
fabs
(
ar
-
1.
)
<
1e-6
)
{
continue
;
}
box_width
=
min_size
*
sqrt
(
ar
)
/
2.
;
box_height
=
min_size
/
sqrt
(
ar
)
/
2.
;
e_boxes
(
h
,
w
,
idx
,
0
)
=
(
center_x
-
box_width
)
/
img_width
;
e_boxes
(
h
,
w
,
idx
,
1
)
=
(
center_y
-
box_height
)
/
img_height
;
e_boxes
(
h
,
w
,
idx
,
2
)
=
(
center_x
+
box_width
)
/
img_width
;
e_boxes
(
h
,
w
,
idx
,
3
)
=
(
center_y
+
box_height
)
/
img_height
;
idx
++
;
}
}
else
{
// priors with different aspect ratios
for
(
size_t
r
=
0
;
r
<
aspect_ratios
.
size
();
++
r
)
{
float
ar
=
aspect_ratios
[
r
];
...
...
@@ -132,6 +166,7 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
}
}
}
}
if
(
clip
)
{
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
be2d9dc2
...
...
@@ -789,7 +789,8 @@ def prior_box(input,
clip
=
False
,
steps
=
[
0.0
,
0.0
],
offset
=
0.5
,
name
=
None
):
name
=
None
,
min_max_aspect_ratios_order
=
False
):
"""
**Prior Box Operator**
...
...
@@ -818,6 +819,11 @@ def prior_box(input,
Default: [0., 0.]
offset(float): Prior boxes center offset. Default: 0.5
name(str): Name of the prior box op. Default: None.
min_max_aspect_ratios_order(bool): If set True, the output prior box is
in order of [min, max, aspect_ratios], which is consistent with
Caffe. Please note, this order affects the weights order of
convolution layer followed by and does not affect the final
detection results. Default: False.
Returns:
tuple: A tuple with two Variable (boxes, variances)
...
...
@@ -871,7 +877,8 @@ def prior_box(input,
'clip'
:
clip
,
'step_w'
:
steps
[
0
],
'step_h'
:
steps
[
1
],
'offset'
:
offset
'offset'
:
offset
,
'min_max_aspect_ratios_order'
:
min_max_aspect_ratios_order
}
if
max_sizes
is
not
None
and
len
(
max_sizes
)
>
0
and
max_sizes
[
0
]
>
0
:
if
not
_is_list_or_tuple_
(
max_sizes
):
...
...
@@ -911,7 +918,8 @@ def multi_box_head(inputs,
kernel_size
=
1
,
pad
=
0
,
stride
=
1
,
name
=
None
):
name
=
None
,
min_max_aspect_ratios_order
=
False
):
"""
Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the
...
...
@@ -954,6 +962,11 @@ def multi_box_head(inputs,
pad(int|list|tuple): The padding of conv2d. Default:0.
stride(int|list|tuple): The stride of conv2d. Default:1,
name(str): Name of the prior box layer. Default: None.
min_max_aspect_ratios_order(bool): If set True, the output prior box is
in order of [min, max, aspect_ratios], which is consistent with
Caffe. Please note, this order affects the weights order of
convolution layer followed by and does not affect the fininal
detection results. Default: False.
Returns:
tuple: A tuple with four Variables. (mbox_loc, mbox_conf, boxes, variances)
...
...
@@ -1068,7 +1081,8 @@ def multi_box_head(inputs,
step
=
[
step_w
[
i
]
if
step_w
else
0.0
,
step_h
[
i
]
if
step_w
else
0.0
]
box
,
var
=
prior_box
(
input
,
image
,
min_size
,
max_size
,
aspect_ratio
,
variance
,
flip
,
clip
,
step
,
offset
)
variance
,
flip
,
clip
,
step
,
offset
,
None
,
min_max_aspect_ratios_order
)
box_results
.
append
(
box
)
var_results
.
append
(
var
)
...
...
python/paddle/fluid/tests/unittests/test_prior_box_op.py
浏览文件 @
be2d9dc2
...
...
@@ -32,6 +32,7 @@ class TestPriorBoxOp(OpTest):
'variances'
:
self
.
variances
,
'flip'
:
self
.
flip
,
'clip'
:
self
.
clip
,
'min_max_aspect_ratios_order'
:
self
.
min_max_aspect_ratios_order
,
'step_w'
:
self
.
step_w
,
'step_h'
:
self
.
step_h
,
'offset'
:
self
.
offset
...
...
@@ -52,6 +53,9 @@ class TestPriorBoxOp(OpTest):
max_sizes
=
[
5
,
10
]
self
.
max_sizes
=
np
.
array
(
max_sizes
).
astype
(
'float32'
).
tolist
()
def
set_min_max_aspect_ratios_order
(
self
):
self
.
min_max_aspect_ratios_order
=
False
def
init_test_params
(
self
):
self
.
layer_w
=
32
self
.
layer_h
=
32
...
...
@@ -71,6 +75,7 @@ class TestPriorBoxOp(OpTest):
self
.
set_max_sizes
()
self
.
aspect_ratios
=
[
2.0
,
3.0
]
self
.
flip
=
True
self
.
set_min_max_aspect_ratios_order
()
self
.
real_aspect_ratios
=
[
1
,
2.0
,
1.0
/
2.0
,
3.0
,
1.0
/
3.0
]
self
.
aspect_ratios
=
np
.
array
(
self
.
aspect_ratios
,
dtype
=
np
.
float
).
flatten
()
...
...
@@ -78,7 +83,6 @@ class TestPriorBoxOp(OpTest):
self
.
variances
=
np
.
array
(
self
.
variances
,
dtype
=
np
.
float
).
flatten
()
self
.
clip
=
True
self
.
num_priors
=
len
(
self
.
real_aspect_ratios
)
*
len
(
self
.
min_sizes
)
if
len
(
self
.
max_sizes
)
>
0
:
self
.
num_priors
+=
len
(
self
.
max_sizes
)
...
...
@@ -106,26 +110,60 @@ class TestPriorBoxOp(OpTest):
idx
=
0
for
s
in
range
(
len
(
self
.
min_sizes
)):
min_size
=
self
.
min_sizes
[
s
]
if
not
self
.
min_max_aspect_ratios_order
:
# rest of priors
for
r
in
range
(
len
(
self
.
real_aspect_ratios
)):
ar
=
self
.
real_aspect_ratios
[
r
]
c_w
=
min_size
*
math
.
sqrt
(
ar
)
/
2
c_h
=
(
min_size
/
math
.
sqrt
(
ar
))
/
2
out_boxes
[
h
,
w
,
idx
,
:]
=
[(
c_x
-
c_w
)
/
self
.
image_w
,
(
c_y
-
c_h
)
/
self
.
image_h
,
(
c_x
+
c_w
)
/
self
.
image_w
,
(
c_y
+
c_h
)
/
self
.
image_h
]
out_boxes
[
h
,
w
,
idx
,
:]
=
[
(
c_x
-
c_w
)
/
self
.
image_w
,
(
c_y
-
c_h
)
/
self
.
image_h
,
(
c_x
+
c_w
)
/
self
.
image_w
,
(
c_y
+
c_h
)
/
self
.
image_h
]
idx
+=
1
if
len
(
self
.
max_sizes
)
>
0
:
max_size
=
self
.
max_sizes
[
s
]
# second prior: aspect_ratio = 1,
c_w
=
c_h
=
math
.
sqrt
(
min_size
*
max_size
)
/
2
out_boxes
[
h
,
w
,
idx
,
:]
=
[
(
c_x
-
c_w
)
/
self
.
image_w
,
(
c_y
-
c_h
)
/
self
.
image_h
,
(
c_x
+
c_w
)
/
self
.
image_w
,
(
c_y
+
c_h
)
/
self
.
image_h
]
idx
+=
1
else
:
c_w
=
c_h
=
min_size
/
2.
out_boxes
[
h
,
w
,
idx
,
:]
=
[(
c_x
-
c_w
)
/
self
.
image_w
,
(
c_y
-
c_h
)
/
self
.
image_h
,
(
c_x
+
c_w
)
/
self
.
image_w
,
(
c_y
+
c_h
)
/
self
.
image_h
]
idx
+=
1
if
len
(
self
.
max_sizes
)
>
0
:
max_size
=
self
.
max_sizes
[
s
]
# second prior: aspect_ratio = 1,
c_w
=
c_h
=
math
.
sqrt
(
min_size
*
max_size
)
/
2
out_boxes
[
h
,
w
,
idx
,
:]
=
[
(
c_x
-
c_w
)
/
self
.
image_w
,
(
c_y
-
c_h
)
/
self
.
image_h
,
(
c_x
+
c_w
)
/
self
.
image_w
,
(
c_y
+
c_h
)
/
self
.
image_h
]
idx
+=
1
# rest of priors
for
r
in
range
(
len
(
self
.
real_aspect_ratios
)):
ar
=
self
.
real_aspect_ratios
[
r
]
if
abs
(
ar
-
1.
)
<
1e-6
:
continue
c_w
=
min_size
*
math
.
sqrt
(
ar
)
/
2
c_h
=
(
min_size
/
math
.
sqrt
(
ar
))
/
2
out_boxes
[
h
,
w
,
idx
,
:]
=
[
(
c_x
-
c_w
)
/
self
.
image_w
,
(
c_y
-
c_h
)
/
self
.
image_h
,
(
c_x
+
c_w
)
/
self
.
image_w
,
(
c_y
+
c_h
)
/
self
.
image_h
]
idx
+=
1
# clip the prior's coordidate such that it is within[0, 1]
if
self
.
clip
:
...
...
@@ -137,10 +175,15 @@ class TestPriorBoxOp(OpTest):
self
.
out_var
=
out_var
.
astype
(
'float32'
)
class
TestPriorBoxOpWithMaxSize
(
TestPriorBoxOp
):
class
TestPriorBoxOpWith
out
MaxSize
(
TestPriorBoxOp
):
def
set_max_sizes
(
self
):
self
.
max_sizes
=
[]
class
TestPriorBoxOpWithSpecifiedOutOrder
(
TestPriorBoxOp
):
def
set_min_max_aspect_ratios_order
(
self
):
self
.
min_max_aspect_ratios_order
=
True
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录