Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5262b025
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5262b025
编写于
11月 02, 2020
作者:
W
wangguanzhong
提交者:
GitHub
11月 02, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add generate_proposals_v2 op (#28214)
* add generate_proposals_v2 op
上级
b96869bc
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
1271 addition
and
503 deletion
+1271
-503
paddle/fluid/operators/detection/CMakeLists.txt
paddle/fluid/operators/detection/CMakeLists.txt
+2
-0
paddle/fluid/operators/detection/bbox_util.cu.h
paddle/fluid/operators/detection/bbox_util.cu.h
+285
-0
paddle/fluid/operators/detection/bbox_util.h
paddle/fluid/operators/detection/bbox_util.h
+127
-15
paddle/fluid/operators/detection/generate_proposals_op.cc
paddle/fluid/operators/detection/generate_proposals_op.cc
+4
-233
paddle/fluid/operators/detection/generate_proposals_op.cu
paddle/fluid/operators/detection/generate_proposals_op.cu
+1
-255
paddle/fluid/operators/detection/generate_proposals_v2_op.cc
paddle/fluid/operators/detection/generate_proposals_v2_op.cc
+314
-0
paddle/fluid/operators/detection/generate_proposals_v2_op.cu
paddle/fluid/operators/detection/generate_proposals_v2_op.cu
+229
-0
paddle/fluid/operators/detection/nms_util.h
paddle/fluid/operators/detection/nms_util.h
+69
-0
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+1
-0
python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py
...le/fluid/tests/unittests/test_generate_proposals_v2_op.py
+238
-0
tools/static_mode_white_list.py
tools/static_mode_white_list.py
+1
-0
未找到文件。
paddle/fluid/operators/detection/CMakeLists.txt
浏览文件 @
5262b025
...
@@ -46,10 +46,12 @@ if(WITH_GPU)
...
@@ -46,10 +46,12 @@ if(WITH_GPU)
set
(
TMPDEPS memory cub
)
set
(
TMPDEPS memory cub
)
endif
()
endif
()
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS
${
TMPDEPS
}
)
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS
${
TMPDEPS
}
)
detection_library
(
generate_proposals_v2_op SRCS generate_proposals_v2_op.cc generate_proposals_v2_op.cu DEPS
${
TMPDEPS
}
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS
${
TMPDEPS
}
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS
${
TMPDEPS
}
)
detection_library
(
collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS
${
TMPDEPS
}
)
detection_library
(
collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS
${
TMPDEPS
}
)
else
()
else
()
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc
)
detection_library
(
generate_proposals_op SRCS generate_proposals_op.cc
)
detection_library
(
generate_proposals_v2_op SRCS generate_proposals_v2_op.cc
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc
)
detection_library
(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc
)
detection_library
(
collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc
)
detection_library
(
collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc
)
endif
()
endif
()
...
...
paddle/fluid/operators/detection/bbox_util.cu.h
0 → 100644
浏览文件 @
5262b025
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int
const
kThreadsPerBlock
=
sizeof
(
uint64_t
)
*
8
;
static
const
double
kBBoxClipDefault
=
std
::
log
(
1000.0
/
16.0
);
struct
RangeInitFunctor
{
int
start_
;
int
delta_
;
int
*
out_
;
__device__
void
operator
()(
size_t
i
)
{
out_
[
i
]
=
start_
+
i
*
delta_
;
}
};
template
<
typename
T
>
static
void
SortDescending
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
value
,
Tensor
*
value_out
,
Tensor
*
index_out
)
{
int
num
=
static_cast
<
int
>
(
value
.
numel
());
Tensor
index_in_t
;
int
*
idx_in
=
index_in_t
.
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
ctx
,
num
);
for_range
(
RangeInitFunctor
{
0
,
1
,
idx_in
});
int
*
idx_out
=
index_out
->
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
const
T
*
keys_in
=
value
.
data
<
T
>
();
T
*
keys_out
=
value_out
->
mutable_data
<
T
>
({
num
},
ctx
.
GetPlace
());
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
nullptr
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
// Allocate temporary storage
auto
place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
}
template
<
typename
T
>
struct
BoxDecodeAndClipFunctor
{
const
T
*
anchor
;
const
T
*
deltas
;
const
T
*
var
;
const
int
*
index
;
const
T
*
im_info
;
T
*
proposals
;
BoxDecodeAndClipFunctor
(
const
T
*
anchor
,
const
T
*
deltas
,
const
T
*
var
,
const
int
*
index
,
const
T
*
im_info
,
T
*
proposals
)
:
anchor
(
anchor
),
deltas
(
deltas
),
var
(
var
),
index
(
index
),
im_info
(
im_info
),
proposals
(
proposals
)
{}
T
bbox_clip_default
{
static_cast
<
T
>
(
kBBoxClipDefault
)};
__device__
void
operator
()(
size_t
i
)
{
int
k
=
index
[
i
]
*
4
;
T
axmin
=
anchor
[
k
];
T
aymin
=
anchor
[
k
+
1
];
T
axmax
=
anchor
[
k
+
2
];
T
aymax
=
anchor
[
k
+
3
];
T
w
=
axmax
-
axmin
+
1.0
;
T
h
=
aymax
-
aymin
+
1.0
;
T
cx
=
axmin
+
0.5
*
w
;
T
cy
=
aymin
+
0.5
*
h
;
T
dxmin
=
deltas
[
k
];
T
dymin
=
deltas
[
k
+
1
];
T
dxmax
=
deltas
[
k
+
2
];
T
dymax
=
deltas
[
k
+
3
];
T
d_cx
,
d_cy
,
d_w
,
d_h
;
if
(
var
)
{
d_cx
=
cx
+
dxmin
*
w
*
var
[
k
];
d_cy
=
cy
+
dymin
*
h
*
var
[
k
+
1
];
d_w
=
exp
(
Min
(
dxmax
*
var
[
k
+
2
],
bbox_clip_default
))
*
w
;
d_h
=
exp
(
Min
(
dymax
*
var
[
k
+
3
],
bbox_clip_default
))
*
h
;
}
else
{
d_cx
=
cx
+
dxmin
*
w
;
d_cy
=
cy
+
dymin
*
h
;
d_w
=
exp
(
Min
(
dxmax
,
bbox_clip_default
))
*
w
;
d_h
=
exp
(
Min
(
dymax
,
bbox_clip_default
))
*
h
;
}
T
oxmin
=
d_cx
-
d_w
*
0.5
;
T
oymin
=
d_cy
-
d_h
*
0.5
;
T
oxmax
=
d_cx
+
d_w
*
0.5
-
1.
;
T
oymax
=
d_cy
+
d_h
*
0.5
-
1.
;
proposals
[
i
*
4
]
=
Max
(
Min
(
oxmin
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
1
]
=
Max
(
Min
(
oymin
,
im_info
[
0
]
-
1.
),
0.
);
proposals
[
i
*
4
+
2
]
=
Max
(
Min
(
oxmax
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
3
]
=
Max
(
Min
(
oymax
,
im_info
[
0
]
-
1.
),
0.
);
}
__device__
__forceinline__
T
Min
(
T
a
,
T
b
)
const
{
return
a
>
b
?
b
:
a
;
}
__device__
__forceinline__
T
Max
(
T
a
,
T
b
)
const
{
return
a
>
b
?
a
:
b
;
}
};
template
<
typename
T
,
int
BlockSize
>
static
__global__
void
FilterBBoxes
(
const
T
*
bboxes
,
const
T
*
im_info
,
const
T
min_size
,
const
int
num
,
int
*
keep_num
,
int
*
keep
,
bool
is_scale
=
true
)
{
T
im_h
=
im_info
[
0
];
T
im_w
=
im_info
[
1
];
int
cnt
=
0
;
__shared__
int
keep_index
[
BlockSize
];
CUDA_KERNEL_LOOP
(
i
,
num
)
{
keep_index
[
threadIdx
.
x
]
=
-
1
;
__syncthreads
();
int
k
=
i
*
4
;
T
xmin
=
bboxes
[
k
];
T
ymin
=
bboxes
[
k
+
1
];
T
xmax
=
bboxes
[
k
+
2
];
T
ymax
=
bboxes
[
k
+
3
];
T
w
=
xmax
-
xmin
+
1.0
;
T
h
=
ymax
-
ymin
+
1.0
;
T
cx
=
xmin
+
w
/
2.
;
T
cy
=
ymin
+
h
/
2.
;
if
(
is_scale
)
{
w
=
(
xmax
-
xmin
)
/
im_info
[
2
]
+
1.
;
h
=
(
ymax
-
ymin
)
/
im_info
[
2
]
+
1.
;
}
if
(
w
>=
min_size
&&
h
>=
min_size
&&
cx
<=
im_w
&&
cy
<=
im_h
)
{
keep_index
[
threadIdx
.
x
]
=
i
;
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
int
size
=
(
num
-
i
)
<
BlockSize
?
num
-
i
:
BlockSize
;
for
(
int
j
=
0
;
j
<
size
;
++
j
)
{
if
(
keep_index
[
j
]
>
-
1
)
{
keep
[
cnt
++
]
=
keep_index
[
j
];
}
}
}
__syncthreads
();
}
if
(
threadIdx
.
x
==
0
)
{
keep_num
[
0
]
=
cnt
;
}
}
static
__device__
float
IoU
(
const
float
*
a
,
const
float
*
b
)
{
float
left
=
max
(
a
[
0
],
b
[
0
]),
right
=
min
(
a
[
2
],
b
[
2
]);
float
top
=
max
(
a
[
1
],
b
[
1
]),
bottom
=
min
(
a
[
3
],
b
[
3
]);
float
width
=
max
(
right
-
left
+
1
,
0.
f
),
height
=
max
(
bottom
-
top
+
1
,
0.
f
);
float
inter_s
=
width
*
height
;
float
s_a
=
(
a
[
2
]
-
a
[
0
]
+
1
)
*
(
a
[
3
]
-
a
[
1
]
+
1
);
float
s_b
=
(
b
[
2
]
-
b
[
0
]
+
1
)
*
(
b
[
3
]
-
b
[
1
]
+
1
);
return
inter_s
/
(
s_a
+
s_b
-
inter_s
);
}
static
__global__
void
NMSKernel
(
const
int
n_boxes
,
const
float
nms_overlap_thresh
,
const
float
*
dev_boxes
,
uint64_t
*
dev_mask
)
{
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
const
int
row_size
=
min
(
n_boxes
-
row_start
*
kThreadsPerBlock
,
kThreadsPerBlock
);
const
int
col_size
=
min
(
n_boxes
-
col_start
*
kThreadsPerBlock
,
kThreadsPerBlock
);
__shared__
float
block_boxes
[
kThreadsPerBlock
*
4
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
4
+
0
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
0
];
block_boxes
[
threadIdx
.
x
*
4
+
1
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
1
];
block_boxes
[
threadIdx
.
x
*
4
+
2
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
2
];
block_boxes
[
threadIdx
.
x
*
4
+
3
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
3
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
kThreadsPerBlock
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
dev_boxes
+
cur_box_idx
*
4
;
int
i
=
0
;
uint64_t
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
IoU
(
cur_box
,
block_boxes
+
i
*
4
)
>
nms_overlap_thresh
)
{
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
n_boxes
,
kThreadsPerBlock
);
dev_mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
template
<
typename
T
>
static
void
NMS
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
proposals
,
const
Tensor
&
sorted_indices
,
const
T
nms_threshold
,
Tensor
*
keep_out
)
{
int
boxes_num
=
proposals
.
dims
()[
0
];
const
int
col_blocks
=
DIVUP
(
boxes_num
,
kThreadsPerBlock
);
dim3
blocks
(
DIVUP
(
boxes_num
,
kThreadsPerBlock
),
DIVUP
(
boxes_num
,
kThreadsPerBlock
));
dim3
threads
(
kThreadsPerBlock
);
const
T
*
boxes
=
proposals
.
data
<
T
>
();
auto
place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
framework
::
Vector
<
uint64_t
>
mask
(
boxes_num
*
col_blocks
);
NMSKernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_threshold
,
boxes
,
mask
.
CUDAMutableData
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
())));
std
::
vector
<
uint64_t
>
remv
(
col_blocks
);
memset
(
&
remv
[
0
],
0
,
sizeof
(
uint64_t
)
*
col_blocks
);
std
::
vector
<
int
>
keep_vec
;
int
num_to_keep
=
0
;
for
(
int
i
=
0
;
i
<
boxes_num
;
i
++
)
{
int
nblock
=
i
/
kThreadsPerBlock
;
int
inblock
=
i
%
kThreadsPerBlock
;
if
(
!
(
remv
[
nblock
]
&
(
1ULL
<<
inblock
)))
{
++
num_to_keep
;
keep_vec
.
push_back
(
i
);
uint64_t
*
p
=
&
mask
[
0
]
+
i
*
col_blocks
;
for
(
int
j
=
nblock
;
j
<
col_blocks
;
j
++
)
{
remv
[
j
]
|=
p
[
j
];
}
}
}
int
*
keep
=
keep_out
->
mutable_data
<
int
>
({
num_to_keep
},
ctx
.
GetPlace
());
memory
::
Copy
(
place
,
keep
,
platform
::
CPUPlace
(),
keep_vec
.
data
(),
sizeof
(
int
)
*
num_to_keep
,
ctx
.
stream
());
ctx
.
Wait
();
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/detection/bbox_util.h
浏览文件 @
5262b025
...
@@ -21,6 +21,8 @@ limitations under the License. */
...
@@ -21,6 +21,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
static
const
double
kBBoxClipDefault
=
std
::
log
(
1000.0
/
16.0
);
struct
RangeInitFunctor
{
struct
RangeInitFunctor
{
int
start
;
int
start
;
int
delta
;
int
delta
;
...
@@ -125,17 +127,45 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
...
@@ -125,17 +127,45 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
}
}
}
}
// Calculate max IoU between each box and ground-truth and
// each row represents one box
template
<
typename
T
>
void
MaxIoU
(
const
framework
::
Tensor
&
iou
,
framework
::
Tensor
*
max_iou
)
{
const
T
*
iou_data
=
iou
.
data
<
T
>
();
int
row
=
iou
.
dims
()[
0
];
int
col
=
iou
.
dims
()[
1
];
T
*
max_iou_data
=
max_iou
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
row
;
++
i
)
{
const
T
*
v
=
iou_data
+
i
*
col
;
T
max_v
=
*
std
::
max_element
(
v
,
v
+
col
);
max_iou_data
[
i
]
=
max_v
;
}
}
static
void
AppendProposals
(
framework
::
Tensor
*
dst
,
int64_t
offset
,
const
framework
::
Tensor
&
src
)
{
auto
*
out_data
=
dst
->
data
<
void
>
();
auto
*
to_add_data
=
src
.
data
<
void
>
();
size_t
size_of_t
=
framework
::
SizeOfType
(
src
.
type
());
offset
*=
size_of_t
;
std
::
memcpy
(
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
out_data
)
+
offset
),
to_add_data
,
src
.
numel
()
*
size_of_t
);
}
template
<
class
T
>
template
<
class
T
>
void
ClipTiledBoxes
(
const
platform
::
DeviceContext
&
ctx
,
void
ClipTiledBoxes
(
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
im_info
,
const
framework
::
Tensor
&
im_info
,
const
framework
::
Tensor
&
input_boxes
,
const
framework
::
Tensor
&
input_boxes
,
framework
::
Tensor
*
out
)
{
framework
::
Tensor
*
out
,
bool
is_scale
=
true
)
{
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
const
T
*
input_boxes_data
=
input_boxes
.
data
<
T
>
();
const
T
*
input_boxes_data
=
input_boxes
.
data
<
T
>
();
T
zero
(
0
);
T
zero
(
0
);
T
im_w
=
round
(
im_info_data
[
1
]
/
im_info_data
[
2
]);
T
im_w
=
T
im_h
=
round
(
im_info_data
[
0
]
/
im_info_data
[
2
]);
is_scale
?
round
(
im_info_data
[
1
]
/
im_info_data
[
2
])
:
im_info_data
[
1
];
T
im_h
=
is_scale
?
round
(
im_info_data
[
0
]
/
im_info_data
[
2
])
:
im_info_data
[
0
];
for
(
int64_t
i
=
0
;
i
<
input_boxes
.
numel
();
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
input_boxes
.
numel
();
++
i
)
{
if
(
i
%
4
==
0
)
{
if
(
i
%
4
==
0
)
{
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_w
-
1
),
zero
);
out_data
[
i
]
=
std
::
max
(
std
::
min
(
input_boxes_data
[
i
],
im_w
-
1
),
zero
);
...
@@ -149,19 +179,101 @@ void ClipTiledBoxes(const platform::DeviceContext& ctx,
...
@@ -149,19 +179,101 @@ void ClipTiledBoxes(const platform::DeviceContext& ctx,
}
}
}
}
// Calculate max IoU between each box and ground-truth and
// Filter the box with small area
// each row represents one box
template
<
class
T
>
template
<
typename
T
>
void
FilterBoxes
(
const
platform
::
DeviceContext
&
ctx
,
void
MaxIoU
(
const
framework
::
Tensor
&
iou
,
framework
::
Tensor
*
max_iou
)
{
const
framework
::
Tensor
*
boxes
,
float
min_size
,
const
T
*
iou_data
=
iou
.
data
<
T
>
();
const
framework
::
Tensor
&
im_info
,
bool
is_scale
,
int
row
=
iou
.
dims
()[
0
];
framework
::
Tensor
*
keep
)
{
int
col
=
iou
.
dims
()[
1
];
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
T
*
max_iou_data
=
max_iou
->
data
<
T
>
();
const
T
*
boxes_data
=
boxes
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
row
;
++
i
)
{
keep
->
Resize
({
boxes
->
dims
()[
0
]});
const
T
*
v
=
iou_data
+
i
*
col
;
min_size
=
std
::
max
(
min_size
,
1.0
f
);
T
max_v
=
*
std
::
max_element
(
v
,
v
+
col
);
int
*
keep_data
=
keep
->
mutable_data
<
int
>
(
ctx
.
GetPlace
());
max_iou_data
[
i
]
=
max_v
;
int
keep_len
=
0
;
for
(
int
i
=
0
;
i
<
boxes
->
dims
()[
0
];
++
i
)
{
T
ws
=
boxes_data
[
4
*
i
+
2
]
-
boxes_data
[
4
*
i
]
+
1
;
T
hs
=
boxes_data
[
4
*
i
+
3
]
-
boxes_data
[
4
*
i
+
1
]
+
1
;
T
x_ctr
=
boxes_data
[
4
*
i
]
+
ws
/
2
;
T
y_ctr
=
boxes_data
[
4
*
i
+
1
]
+
hs
/
2
;
if
(
is_scale
)
{
ws
=
(
boxes_data
[
4
*
i
+
2
]
-
boxes_data
[
4
*
i
])
/
im_info_data
[
2
]
+
1
;
hs
=
(
boxes_data
[
4
*
i
+
3
]
-
boxes_data
[
4
*
i
+
1
])
/
im_info_data
[
2
]
+
1
;
}
if
(
ws
>=
min_size
&&
hs
>=
min_size
&&
x_ctr
<=
im_info_data
[
1
]
&&
y_ctr
<=
im_info_data
[
0
])
{
keep_data
[
keep_len
++
]
=
i
;
}
}
keep
->
Resize
({
keep_len
});
}
template
<
class
T
>
static
void
BoxCoder
(
const
platform
::
DeviceContext
&
ctx
,
framework
::
Tensor
*
all_anchors
,
framework
::
Tensor
*
bbox_deltas
,
framework
::
Tensor
*
variances
,
framework
::
Tensor
*
proposals
)
{
T
*
proposals_data
=
proposals
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
row
=
all_anchors
->
dims
()[
0
];
int64_t
len
=
all_anchors
->
dims
()[
1
];
auto
*
bbox_deltas_data
=
bbox_deltas
->
data
<
T
>
();
auto
*
anchor_data
=
all_anchors
->
data
<
T
>
();
const
T
*
variances_data
=
nullptr
;
if
(
variances
)
{
variances_data
=
variances
->
data
<
T
>
();
}
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
T
anchor_width
=
anchor_data
[
i
*
len
+
2
]
-
anchor_data
[
i
*
len
]
+
1.0
;
T
anchor_height
=
anchor_data
[
i
*
len
+
3
]
-
anchor_data
[
i
*
len
+
1
]
+
1.0
;
T
anchor_center_x
=
anchor_data
[
i
*
len
]
+
0.5
*
anchor_width
;
T
anchor_center_y
=
anchor_data
[
i
*
len
+
1
]
+
0.5
*
anchor_height
;
T
bbox_center_x
=
0
,
bbox_center_y
=
0
;
T
bbox_width
=
0
,
bbox_height
=
0
;
if
(
variances
)
{
bbox_center_x
=
variances_data
[
i
*
len
]
*
bbox_deltas_data
[
i
*
len
]
*
anchor_width
+
anchor_center_x
;
bbox_center_y
=
variances_data
[
i
*
len
+
1
]
*
bbox_deltas_data
[
i
*
len
+
1
]
*
anchor_height
+
anchor_center_y
;
bbox_width
=
std
::
exp
(
std
::
min
<
T
>
(
variances_data
[
i
*
len
+
2
]
*
bbox_deltas_data
[
i
*
len
+
2
],
kBBoxClipDefault
))
*
anchor_width
;
bbox_height
=
std
::
exp
(
std
::
min
<
T
>
(
variances_data
[
i
*
len
+
3
]
*
bbox_deltas_data
[
i
*
len
+
3
],
kBBoxClipDefault
))
*
anchor_height
;
}
else
{
bbox_center_x
=
bbox_deltas_data
[
i
*
len
]
*
anchor_width
+
anchor_center_x
;
bbox_center_y
=
bbox_deltas_data
[
i
*
len
+
1
]
*
anchor_height
+
anchor_center_y
;
bbox_width
=
std
::
exp
(
std
::
min
<
T
>
(
bbox_deltas_data
[
i
*
len
+
2
],
kBBoxClipDefault
))
*
anchor_width
;
bbox_height
=
std
::
exp
(
std
::
min
<
T
>
(
bbox_deltas_data
[
i
*
len
+
3
],
kBBoxClipDefault
))
*
anchor_height
;
}
proposals_data
[
i
*
len
]
=
bbox_center_x
-
bbox_width
/
2
;
proposals_data
[
i
*
len
+
1
]
=
bbox_center_y
-
bbox_height
/
2
;
proposals_data
[
i
*
len
+
2
]
=
bbox_center_x
+
bbox_width
/
2
-
1
;
proposals_data
[
i
*
len
+
3
]
=
bbox_center_y
+
bbox_height
/
2
-
1
;
}
}
// return proposals;
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/detection/generate_proposals_op.cc
浏览文件 @
5262b025
...
@@ -18,6 +18,8 @@ limitations under the License. */
...
@@ -18,6 +18,8 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
...
@@ -27,18 +29,6 @@ namespace operators {
...
@@ -27,18 +29,6 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
static
const
double
kBBoxClipDefault
=
std
::
log
(
1000.0
/
16.0
);
static
void
AppendProposals
(
Tensor
*
dst
,
int64_t
offset
,
const
Tensor
&
src
)
{
auto
*
out_data
=
dst
->
data
<
void
>
();
auto
*
to_add_data
=
src
.
data
<
void
>
();
size_t
size_of_t
=
framework
::
SizeOfType
(
src
.
type
());
offset
*=
size_of_t
;
std
::
memcpy
(
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
out_data
)
+
offset
),
to_add_data
,
src
.
numel
()
*
size_of_t
);
}
class
GenerateProposalsOp
:
public
framework
::
OperatorWithKernel
{
class
GenerateProposalsOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
@@ -77,225 +67,6 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
...
@@ -77,225 +67,6 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
}
}
};
};
template
<
class
T
>
static
inline
void
BoxCoder
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
all_anchors
,
Tensor
*
bbox_deltas
,
Tensor
*
variances
,
Tensor
*
proposals
)
{
T
*
proposals_data
=
proposals
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
row
=
all_anchors
->
dims
()[
0
];
int64_t
len
=
all_anchors
->
dims
()[
1
];
auto
*
bbox_deltas_data
=
bbox_deltas
->
data
<
T
>
();
auto
*
anchor_data
=
all_anchors
->
data
<
T
>
();
const
T
*
variances_data
=
nullptr
;
if
(
variances
)
{
variances_data
=
variances
->
data
<
T
>
();
}
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
T
anchor_width
=
anchor_data
[
i
*
len
+
2
]
-
anchor_data
[
i
*
len
]
+
1.0
;
T
anchor_height
=
anchor_data
[
i
*
len
+
3
]
-
anchor_data
[
i
*
len
+
1
]
+
1.0
;
T
anchor_center_x
=
anchor_data
[
i
*
len
]
+
0.5
*
anchor_width
;
T
anchor_center_y
=
anchor_data
[
i
*
len
+
1
]
+
0.5
*
anchor_height
;
T
bbox_center_x
=
0
,
bbox_center_y
=
0
;
T
bbox_width
=
0
,
bbox_height
=
0
;
if
(
variances
)
{
bbox_center_x
=
variances_data
[
i
*
len
]
*
bbox_deltas_data
[
i
*
len
]
*
anchor_width
+
anchor_center_x
;
bbox_center_y
=
variances_data
[
i
*
len
+
1
]
*
bbox_deltas_data
[
i
*
len
+
1
]
*
anchor_height
+
anchor_center_y
;
bbox_width
=
std
::
exp
(
std
::
min
<
T
>
(
variances_data
[
i
*
len
+
2
]
*
bbox_deltas_data
[
i
*
len
+
2
],
kBBoxClipDefault
))
*
anchor_width
;
bbox_height
=
std
::
exp
(
std
::
min
<
T
>
(
variances_data
[
i
*
len
+
3
]
*
bbox_deltas_data
[
i
*
len
+
3
],
kBBoxClipDefault
))
*
anchor_height
;
}
else
{
bbox_center_x
=
bbox_deltas_data
[
i
*
len
]
*
anchor_width
+
anchor_center_x
;
bbox_center_y
=
bbox_deltas_data
[
i
*
len
+
1
]
*
anchor_height
+
anchor_center_y
;
bbox_width
=
std
::
exp
(
std
::
min
<
T
>
(
bbox_deltas_data
[
i
*
len
+
2
],
kBBoxClipDefault
))
*
anchor_width
;
bbox_height
=
std
::
exp
(
std
::
min
<
T
>
(
bbox_deltas_data
[
i
*
len
+
3
],
kBBoxClipDefault
))
*
anchor_height
;
}
proposals_data
[
i
*
len
]
=
bbox_center_x
-
bbox_width
/
2
;
proposals_data
[
i
*
len
+
1
]
=
bbox_center_y
-
bbox_height
/
2
;
proposals_data
[
i
*
len
+
2
]
=
bbox_center_x
+
bbox_width
/
2
-
1
;
proposals_data
[
i
*
len
+
3
]
=
bbox_center_y
+
bbox_height
/
2
-
1
;
}
// return proposals;
}
template
<
class
T
>
static
inline
void
ClipTiledBoxes
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
im_info
,
Tensor
*
boxes
)
{
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
T
zero
(
0
);
for
(
int64_t
i
=
0
;
i
<
boxes
->
numel
();
++
i
)
{
if
(
i
%
4
==
0
)
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
1
]
-
1
),
zero
);
}
else
if
(
i
%
4
==
1
)
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
0
]
-
1
),
zero
);
}
else
if
(
i
%
4
==
2
)
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
1
]
-
1
),
zero
);
}
else
{
boxes_data
[
i
]
=
std
::
max
(
std
::
min
(
boxes_data
[
i
],
im_info_data
[
0
]
-
1
),
zero
);
}
}
}
template
<
class
T
>
static
inline
void
FilterBoxes
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
boxes
,
float
min_size
,
const
Tensor
&
im_info
,
Tensor
*
keep
)
{
const
T
*
im_info_data
=
im_info
.
data
<
T
>
();
T
*
boxes_data
=
boxes
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
im_scale
=
im_info_data
[
2
];
keep
->
Resize
({
boxes
->
dims
()[
0
]});
min_size
=
std
::
max
(
min_size
,
1.0
f
);
int
*
keep_data
=
keep
->
mutable_data
<
int
>
(
ctx
.
GetPlace
());
int
keep_len
=
0
;
for
(
int
i
=
0
;
i
<
boxes
->
dims
()[
0
];
++
i
)
{
T
ws
=
boxes_data
[
4
*
i
+
2
]
-
boxes_data
[
4
*
i
]
+
1
;
T
hs
=
boxes_data
[
4
*
i
+
3
]
-
boxes_data
[
4
*
i
+
1
]
+
1
;
T
ws_origin_scale
=
(
boxes_data
[
4
*
i
+
2
]
-
boxes_data
[
4
*
i
])
/
im_scale
+
1
;
T
hs_origin_scale
=
(
boxes_data
[
4
*
i
+
3
]
-
boxes_data
[
4
*
i
+
1
])
/
im_scale
+
1
;
T
x_ctr
=
boxes_data
[
4
*
i
]
+
ws
/
2
;
T
y_ctr
=
boxes_data
[
4
*
i
+
1
]
+
hs
/
2
;
if
(
ws_origin_scale
>=
min_size
&&
hs_origin_scale
>=
min_size
&&
x_ctr
<=
im_info_data
[
1
]
&&
y_ctr
<=
im_info_data
[
0
])
{
keep_data
[
keep_len
++
]
=
i
;
}
}
keep
->
Resize
({
keep_len
});
}
template
<
class
T
>
static
inline
std
::
vector
<
std
::
pair
<
T
,
int
>>
GetSortedScoreIndex
(
const
std
::
vector
<
T
>
&
scores
)
{
std
::
vector
<
std
::
pair
<
T
,
int
>>
sorted_indices
;
sorted_indices
.
reserve
(
scores
.
size
());
for
(
size_t
i
=
0
;
i
<
scores
.
size
();
++
i
)
{
sorted_indices
.
emplace_back
(
scores
[
i
],
i
);
}
// Sort the score pair according to the scores in descending order
std
::
stable_sort
(
sorted_indices
.
begin
(),
sorted_indices
.
end
(),
[](
const
std
::
pair
<
T
,
int
>
&
a
,
const
std
::
pair
<
T
,
int
>
&
b
)
{
return
a
.
first
<
b
.
first
;
});
return
sorted_indices
;
}
template
<
class
T
>
static
inline
T
BBoxArea
(
const
T
*
box
,
bool
normalized
)
{
if
(
box
[
2
]
<
box
[
0
]
||
box
[
3
]
<
box
[
1
])
{
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return
static_cast
<
T
>
(
0.
);
}
else
{
const
T
w
=
box
[
2
]
-
box
[
0
];
const
T
h
=
box
[
3
]
-
box
[
1
];
if
(
normalized
)
{
return
w
*
h
;
}
else
{
// If coordinate values are not within range [0, 1].
return
(
w
+
1
)
*
(
h
+
1
);
}
}
}
template
<
class
T
>
static
inline
T
JaccardOverlap
(
const
T
*
box1
,
const
T
*
box2
,
bool
normalized
)
{
if
(
box2
[
0
]
>
box1
[
2
]
||
box2
[
2
]
<
box1
[
0
]
||
box2
[
1
]
>
box1
[
3
]
||
box2
[
3
]
<
box1
[
1
])
{
return
static_cast
<
T
>
(
0.
);
}
else
{
const
T
inter_xmin
=
std
::
max
(
box1
[
0
],
box2
[
0
]);
const
T
inter_ymin
=
std
::
max
(
box1
[
1
],
box2
[
1
]);
const
T
inter_xmax
=
std
::
min
(
box1
[
2
],
box2
[
2
]);
const
T
inter_ymax
=
std
::
min
(
box1
[
3
],
box2
[
3
]);
const
T
inter_w
=
std
::
max
(
T
(
0
),
inter_xmax
-
inter_xmin
+
1
);
const
T
inter_h
=
std
::
max
(
T
(
0
),
inter_ymax
-
inter_ymin
+
1
);
const
T
inter_area
=
inter_w
*
inter_h
;
const
T
bbox1_area
=
BBoxArea
<
T
>
(
box1
,
normalized
);
const
T
bbox2_area
=
BBoxArea
<
T
>
(
box2
,
normalized
);
return
inter_area
/
(
bbox1_area
+
bbox2_area
-
inter_area
);
}
}
template
<
typename
T
>
static
inline
Tensor
VectorToTensor
(
const
std
::
vector
<
T
>
&
selected_indices
,
int
selected_num
)
{
Tensor
keep_nms
;
keep_nms
.
Resize
({
selected_num
});
auto
*
keep_data
=
keep_nms
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
selected_num
;
++
i
)
{
keep_data
[
i
]
=
selected_indices
[
i
];
}
return
keep_nms
;
}
template
<
class
T
>
static
inline
Tensor
NMS
(
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
bbox
,
Tensor
*
scores
,
T
nms_threshold
,
float
eta
)
{
int64_t
num_boxes
=
bbox
->
dims
()[
0
];
// 4: [xmin ymin xmax ymax]
int64_t
box_size
=
bbox
->
dims
()[
1
];
std
::
vector
<
T
>
scores_data
(
num_boxes
);
std
::
copy_n
(
scores
->
data
<
T
>
(),
num_boxes
,
scores_data
.
begin
());
std
::
vector
<
std
::
pair
<
T
,
int
>>
sorted_indices
=
GetSortedScoreIndex
<
T
>
(
scores_data
);
std
::
vector
<
int
>
selected_indices
;
int
selected_num
=
0
;
T
adaptive_threshold
=
nms_threshold
;
const
T
*
bbox_data
=
bbox
->
data
<
T
>
();
while
(
sorted_indices
.
size
()
!=
0
)
{
int
idx
=
sorted_indices
.
back
().
second
;
bool
flag
=
true
;
for
(
int
kept_idx
:
selected_indices
)
{
if
(
flag
)
{
T
overlap
=
JaccardOverlap
<
T
>
(
bbox_data
+
idx
*
box_size
,
bbox_data
+
kept_idx
*
box_size
,
false
);
flag
=
(
overlap
<=
adaptive_threshold
);
}
else
{
break
;
}
}
if
(
flag
)
{
selected_indices
.
push_back
(
idx
);
++
selected_num
;
}
sorted_indices
.
erase
(
sorted_indices
.
end
()
-
1
);
if
(
flag
&&
eta
<
1
&&
adaptive_threshold
>
0.5
)
{
adaptive_threshold
*=
eta
;
}
}
return
VectorToTensor
(
selected_indices
,
selected_num
);
}
template
<
typename
T
>
template
<
typename
T
>
class
GenerateProposalsKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GenerateProposalsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -434,10 +205,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
...
@@ -434,10 +205,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
proposals
.
mutable_data
<
T
>
({
index_t
.
numel
(),
4
},
ctx
.
GetPlace
());
proposals
.
mutable_data
<
T
>
({
index_t
.
numel
(),
4
},
ctx
.
GetPlace
());
BoxCoder
<
T
>
(
ctx
,
&
anchor_sel
,
&
bbox_sel
,
&
var_sel
,
&
proposals
);
BoxCoder
<
T
>
(
ctx
,
&
anchor_sel
,
&
bbox_sel
,
&
var_sel
,
&
proposals
);
ClipTiledBoxes
<
T
>
(
ctx
,
im_info_slice
,
&
proposals
);
ClipTiledBoxes
<
T
>
(
ctx
,
im_info_slice
,
proposals
,
&
proposals
,
false
);
Tensor
keep
;
Tensor
keep
;
FilterBoxes
<
T
>
(
ctx
,
&
proposals
,
min_size
,
im_info_slice
,
&
keep
);
FilterBoxes
<
T
>
(
ctx
,
&
proposals
,
min_size
,
im_info_slice
,
true
,
&
keep
);
// Handle the case when there is no keep index left
// Handle the case when there is no keep index left
if
(
keep
.
numel
()
==
0
)
{
if
(
keep
.
numel
()
==
0
)
{
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
set_zero
;
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
set_zero
;
...
...
paddle/fluid/operators/detection/generate_proposals_op.cu
浏览文件 @
5262b025
...
@@ -16,13 +16,11 @@ limitations under the License. */
...
@@ -16,13 +16,11 @@ limitations under the License. */
#include <stdio.h>
#include <stdio.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/
gather
.cu.h"
#include "paddle/fluid/operators/
detection/bbox_util
.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -31,258 +29,6 @@ using Tensor = framework::Tensor;
...
@@ -31,258 +29,6 @@ using Tensor = framework::Tensor;
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
namespace
{
namespace
{
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int
const
kThreadsPerBlock
=
sizeof
(
uint64_t
)
*
8
;
static
const
double
kBBoxClipDefault
=
std
::
log
(
1000.0
/
16.0
);
struct
RangeInitFunctor
{
int
start_
;
int
delta_
;
int
*
out_
;
__device__
void
operator
()(
size_t
i
)
{
out_
[
i
]
=
start_
+
i
*
delta_
;
}
};
template
<
typename
T
>
static
void
SortDescending
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
value
,
Tensor
*
value_out
,
Tensor
*
index_out
)
{
int
num
=
static_cast
<
int
>
(
value
.
numel
());
Tensor
index_in_t
;
int
*
idx_in
=
index_in_t
.
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
ctx
,
num
);
for_range
(
RangeInitFunctor
{
0
,
1
,
idx_in
});
int
*
idx_out
=
index_out
->
mutable_data
<
int
>
({
num
},
ctx
.
GetPlace
());
const
T
*
keys_in
=
value
.
data
<
T
>
();
T
*
keys_out
=
value_out
->
mutable_data
<
T
>
({
num
},
ctx
.
GetPlace
());
// Determine temporary device storage requirements
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
nullptr
,
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
// Allocate temporary storage
auto
place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
auto
d_temp_storage
=
memory
::
Alloc
(
place
,
temp_storage_bytes
);
// Run sorting operation
cub
::
DeviceRadixSort
::
SortPairsDescending
<
T
,
int
>
(
d_temp_storage
->
ptr
(),
temp_storage_bytes
,
keys_in
,
keys_out
,
idx_in
,
idx_out
,
num
);
}
template
<
typename
T
>
struct
BoxDecodeAndClipFunctor
{
const
T
*
anchor
;
const
T
*
deltas
;
const
T
*
var
;
const
int
*
index
;
const
T
*
im_info
;
T
*
proposals
;
BoxDecodeAndClipFunctor
(
const
T
*
anchor
,
const
T
*
deltas
,
const
T
*
var
,
const
int
*
index
,
const
T
*
im_info
,
T
*
proposals
)
:
anchor
(
anchor
),
deltas
(
deltas
),
var
(
var
),
index
(
index
),
im_info
(
im_info
),
proposals
(
proposals
)
{}
T
bbox_clip_default
{
static_cast
<
T
>
(
kBBoxClipDefault
)};
__device__
void
operator
()(
size_t
i
)
{
int
k
=
index
[
i
]
*
4
;
T
axmin
=
anchor
[
k
];
T
aymin
=
anchor
[
k
+
1
];
T
axmax
=
anchor
[
k
+
2
];
T
aymax
=
anchor
[
k
+
3
];
T
w
=
axmax
-
axmin
+
1.0
;
T
h
=
aymax
-
aymin
+
1.0
;
T
cx
=
axmin
+
0.5
*
w
;
T
cy
=
aymin
+
0.5
*
h
;
T
dxmin
=
deltas
[
k
];
T
dymin
=
deltas
[
k
+
1
];
T
dxmax
=
deltas
[
k
+
2
];
T
dymax
=
deltas
[
k
+
3
];
T
d_cx
,
d_cy
,
d_w
,
d_h
;
if
(
var
)
{
d_cx
=
cx
+
dxmin
*
w
*
var
[
k
];
d_cy
=
cy
+
dymin
*
h
*
var
[
k
+
1
];
d_w
=
exp
(
Min
(
dxmax
*
var
[
k
+
2
],
bbox_clip_default
))
*
w
;
d_h
=
exp
(
Min
(
dymax
*
var
[
k
+
3
],
bbox_clip_default
))
*
h
;
}
else
{
d_cx
=
cx
+
dxmin
*
w
;
d_cy
=
cy
+
dymin
*
h
;
d_w
=
exp
(
Min
(
dxmax
,
bbox_clip_default
))
*
w
;
d_h
=
exp
(
Min
(
dymax
,
bbox_clip_default
))
*
h
;
}
T
oxmin
=
d_cx
-
d_w
*
0.5
;
T
oymin
=
d_cy
-
d_h
*
0.5
;
T
oxmax
=
d_cx
+
d_w
*
0.5
-
1.
;
T
oymax
=
d_cy
+
d_h
*
0.5
-
1.
;
proposals
[
i
*
4
]
=
Max
(
Min
(
oxmin
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
1
]
=
Max
(
Min
(
oymin
,
im_info
[
0
]
-
1.
),
0.
);
proposals
[
i
*
4
+
2
]
=
Max
(
Min
(
oxmax
,
im_info
[
1
]
-
1.
),
0.
);
proposals
[
i
*
4
+
3
]
=
Max
(
Min
(
oymax
,
im_info
[
0
]
-
1.
),
0.
);
}
__device__
__forceinline__
T
Min
(
T
a
,
T
b
)
const
{
return
a
>
b
?
b
:
a
;
}
__device__
__forceinline__
T
Max
(
T
a
,
T
b
)
const
{
return
a
>
b
?
a
:
b
;
}
};
template
<
typename
T
,
int
BlockSize
>
static
__global__
void
FilterBBoxes
(
const
T
*
bboxes
,
const
T
*
im_info
,
const
T
min_size
,
const
int
num
,
int
*
keep_num
,
int
*
keep
)
{
T
im_h
=
im_info
[
0
];
T
im_w
=
im_info
[
1
];
T
im_scale
=
im_info
[
2
];
int
cnt
=
0
;
__shared__
int
keep_index
[
BlockSize
];
CUDA_KERNEL_LOOP
(
i
,
num
)
{
keep_index
[
threadIdx
.
x
]
=
-
1
;
__syncthreads
();
int
k
=
i
*
4
;
T
xmin
=
bboxes
[
k
];
T
ymin
=
bboxes
[
k
+
1
];
T
xmax
=
bboxes
[
k
+
2
];
T
ymax
=
bboxes
[
k
+
3
];
T
w
=
xmax
-
xmin
+
1.0
;
T
h
=
ymax
-
ymin
+
1.0
;
T
cx
=
xmin
+
w
/
2.
;
T
cy
=
ymin
+
h
/
2.
;
T
w_s
=
(
xmax
-
xmin
)
/
im_scale
+
1.
;
T
h_s
=
(
ymax
-
ymin
)
/
im_scale
+
1.
;
if
(
w_s
>=
min_size
&&
h_s
>=
min_size
&&
cx
<=
im_w
&&
cy
<=
im_h
)
{
keep_index
[
threadIdx
.
x
]
=
i
;
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
int
size
=
(
num
-
i
)
<
BlockSize
?
num
-
i
:
BlockSize
;
for
(
int
j
=
0
;
j
<
size
;
++
j
)
{
if
(
keep_index
[
j
]
>
-
1
)
{
keep
[
cnt
++
]
=
keep_index
[
j
];
}
}
}
__syncthreads
();
}
if
(
threadIdx
.
x
==
0
)
{
keep_num
[
0
]
=
cnt
;
}
}
static
__device__
inline
float
IoU
(
const
float
*
a
,
const
float
*
b
)
{
float
left
=
max
(
a
[
0
],
b
[
0
]),
right
=
min
(
a
[
2
],
b
[
2
]);
float
top
=
max
(
a
[
1
],
b
[
1
]),
bottom
=
min
(
a
[
3
],
b
[
3
]);
float
width
=
max
(
right
-
left
+
1
,
0.
f
),
height
=
max
(
bottom
-
top
+
1
,
0.
f
);
float
inter_s
=
width
*
height
;
float
s_a
=
(
a
[
2
]
-
a
[
0
]
+
1
)
*
(
a
[
3
]
-
a
[
1
]
+
1
);
float
s_b
=
(
b
[
2
]
-
b
[
0
]
+
1
)
*
(
b
[
3
]
-
b
[
1
]
+
1
);
return
inter_s
/
(
s_a
+
s_b
-
inter_s
);
}
static
__global__
void
NMSKernel
(
const
int
n_boxes
,
const
float
nms_overlap_thresh
,
const
float
*
dev_boxes
,
uint64_t
*
dev_mask
)
{
const
int
row_start
=
blockIdx
.
y
;
const
int
col_start
=
blockIdx
.
x
;
const
int
row_size
=
min
(
n_boxes
-
row_start
*
kThreadsPerBlock
,
kThreadsPerBlock
);
const
int
col_size
=
min
(
n_boxes
-
col_start
*
kThreadsPerBlock
,
kThreadsPerBlock
);
__shared__
float
block_boxes
[
kThreadsPerBlock
*
4
];
if
(
threadIdx
.
x
<
col_size
)
{
block_boxes
[
threadIdx
.
x
*
4
+
0
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
0
];
block_boxes
[
threadIdx
.
x
*
4
+
1
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
1
];
block_boxes
[
threadIdx
.
x
*
4
+
2
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
2
];
block_boxes
[
threadIdx
.
x
*
4
+
3
]
=
dev_boxes
[(
kThreadsPerBlock
*
col_start
+
threadIdx
.
x
)
*
4
+
3
];
}
__syncthreads
();
if
(
threadIdx
.
x
<
row_size
)
{
const
int
cur_box_idx
=
kThreadsPerBlock
*
row_start
+
threadIdx
.
x
;
const
float
*
cur_box
=
dev_boxes
+
cur_box_idx
*
4
;
int
i
=
0
;
uint64_t
t
=
0
;
int
start
=
0
;
if
(
row_start
==
col_start
)
{
start
=
threadIdx
.
x
+
1
;
}
for
(
i
=
start
;
i
<
col_size
;
i
++
)
{
if
(
IoU
(
cur_box
,
block_boxes
+
i
*
4
)
>
nms_overlap_thresh
)
{
t
|=
1ULL
<<
i
;
}
}
const
int
col_blocks
=
DIVUP
(
n_boxes
,
kThreadsPerBlock
);
dev_mask
[
cur_box_idx
*
col_blocks
+
col_start
]
=
t
;
}
}
template
<
typename
T
>
static
void
NMS
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
proposals
,
const
Tensor
&
sorted_indices
,
const
T
nms_threshold
,
Tensor
*
keep_out
)
{
int
boxes_num
=
proposals
.
dims
()[
0
];
const
int
col_blocks
=
DIVUP
(
boxes_num
,
kThreadsPerBlock
);
dim3
blocks
(
DIVUP
(
boxes_num
,
kThreadsPerBlock
),
DIVUP
(
boxes_num
,
kThreadsPerBlock
));
dim3
threads
(
kThreadsPerBlock
);
const
T
*
boxes
=
proposals
.
data
<
T
>
();
auto
place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
framework
::
Vector
<
uint64_t
>
mask
(
boxes_num
*
col_blocks
);
NMSKernel
<<<
blocks
,
threads
>>>
(
boxes_num
,
nms_threshold
,
boxes
,
mask
.
CUDAMutableData
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
())));
std
::
vector
<
uint64_t
>
remv
(
col_blocks
);
memset
(
&
remv
[
0
],
0
,
sizeof
(
uint64_t
)
*
col_blocks
);
std
::
vector
<
int
>
keep_vec
;
int
num_to_keep
=
0
;
for
(
int
i
=
0
;
i
<
boxes_num
;
i
++
)
{
int
nblock
=
i
/
kThreadsPerBlock
;
int
inblock
=
i
%
kThreadsPerBlock
;
if
(
!
(
remv
[
nblock
]
&
(
1ULL
<<
inblock
)))
{
++
num_to_keep
;
keep_vec
.
push_back
(
i
);
uint64_t
*
p
=
&
mask
[
0
]
+
i
*
col_blocks
;
for
(
int
j
=
nblock
;
j
<
col_blocks
;
j
++
)
{
remv
[
j
]
|=
p
[
j
];
}
}
}
int
*
keep
=
keep_out
->
mutable_data
<
int
>
({
num_to_keep
},
ctx
.
GetPlace
());
memory
::
Copy
(
place
,
keep
,
platform
::
CPUPlace
(),
keep_vec
.
data
(),
sizeof
(
int
)
*
num_to_keep
,
ctx
.
stream
());
ctx
.
Wait
();
}
template
<
typename
T
>
template
<
typename
T
>
static
std
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
static
std
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
im_info
,
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
im_info
,
...
...
paddle/fluid/operators/detection/generate_proposals_v2_op.cc
0 → 100644
浏览文件 @
5262b025
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
class
GenerateProposalsV2Op
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Scores"
),
true
,
platform
::
errors
::
NotFound
(
"Input(Scores) shouldn't be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"BboxDeltas"
),
true
,
platform
::
errors
::
NotFound
(
"Input(BboxDeltas) shouldn't be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"ImShape"
),
true
,
platform
::
errors
::
NotFound
(
"Input(ImShape) shouldn't be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Anchors"
),
true
,
platform
::
errors
::
NotFound
(
"Input(Anchors) shouldn't be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Variances"
),
true
,
platform
::
errors
::
NotFound
(
"Input(Variances) shouldn't be null."
));
ctx
->
SetOutputDim
(
"RpnRois"
,
{
-
1
,
4
});
ctx
->
SetOutputDim
(
"RpnRoiProbs"
,
{
-
1
,
1
});
if
(
!
ctx
->
IsRuntime
())
{
ctx
->
SetLoDLevel
(
"RpnRois"
,
std
::
max
(
ctx
->
GetLoDLevel
(
"Scores"
),
1
));
ctx
->
SetLoDLevel
(
"RpnRoiProbs"
,
std
::
max
(
ctx
->
GetLoDLevel
(
"Scores"
),
1
));
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Anchors"
),
ctx
.
device_context
());
}
};
template
<
typename
T
>
class
GenerateProposalsV2Kernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
scores
=
context
.
Input
<
Tensor
>
(
"Scores"
);
auto
*
bbox_deltas
=
context
.
Input
<
Tensor
>
(
"BboxDeltas"
);
auto
*
im_shape
=
context
.
Input
<
Tensor
>
(
"ImShape"
);
auto
anchors
=
GET_DATA_SAFELY
(
context
.
Input
<
Tensor
>
(
"Anchors"
),
"Input"
,
"Anchors"
,
"GenerateProposals"
);
auto
variances
=
GET_DATA_SAFELY
(
context
.
Input
<
Tensor
>
(
"Variances"
),
"Input"
,
"Variances"
,
"GenerateProposals"
);
auto
*
rpn_rois
=
context
.
Output
<
LoDTensor
>
(
"RpnRois"
);
auto
*
rpn_roi_probs
=
context
.
Output
<
LoDTensor
>
(
"RpnRoiProbs"
);
int
pre_nms_top_n
=
context
.
Attr
<
int
>
(
"pre_nms_topN"
);
int
post_nms_top_n
=
context
.
Attr
<
int
>
(
"post_nms_topN"
);
float
nms_thresh
=
context
.
Attr
<
float
>
(
"nms_thresh"
);
float
min_size
=
context
.
Attr
<
float
>
(
"min_size"
);
float
eta
=
context
.
Attr
<
float
>
(
"eta"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CPUDeviceContext
>();
auto
&
scores_dim
=
scores
->
dims
();
int64_t
num
=
scores_dim
[
0
];
int64_t
c_score
=
scores_dim
[
1
];
int64_t
h_score
=
scores_dim
[
2
];
int64_t
w_score
=
scores_dim
[
3
];
auto
&
bbox_dim
=
bbox_deltas
->
dims
();
int64_t
c_bbox
=
bbox_dim
[
1
];
int64_t
h_bbox
=
bbox_dim
[
2
];
int64_t
w_bbox
=
bbox_dim
[
3
];
rpn_rois
->
mutable_data
<
T
>
({
bbox_deltas
->
numel
()
/
4
,
4
},
context
.
GetPlace
());
rpn_roi_probs
->
mutable_data
<
T
>
({
scores
->
numel
(),
1
},
context
.
GetPlace
());
Tensor
bbox_deltas_swap
,
scores_swap
;
bbox_deltas_swap
.
mutable_data
<
T
>
({
num
,
h_bbox
,
w_bbox
,
c_bbox
},
dev_ctx
.
GetPlace
());
scores_swap
.
mutable_data
<
T
>
({
num
,
h_score
,
w_score
,
c_score
},
dev_ctx
.
GetPlace
());
math
::
Transpose
<
platform
::
CPUDeviceContext
,
T
,
4
>
trans
;
std
::
vector
<
int
>
axis
=
{
0
,
2
,
3
,
1
};
trans
(
dev_ctx
,
*
bbox_deltas
,
&
bbox_deltas_swap
,
axis
);
trans
(
dev_ctx
,
*
scores
,
&
scores_swap
,
axis
);
framework
::
LoD
lod
;
lod
.
resize
(
1
);
auto
&
lod0
=
lod
[
0
];
lod0
.
push_back
(
0
);
anchors
.
Resize
({
anchors
.
numel
()
/
4
,
4
});
variances
.
Resize
({
variances
.
numel
()
/
4
,
4
});
std
::
vector
<
int
>
tmp_num
;
int64_t
num_proposals
=
0
;
for
(
int64_t
i
=
0
;
i
<
num
;
++
i
)
{
Tensor
im_shape_slice
=
im_shape
->
Slice
(
i
,
i
+
1
);
Tensor
bbox_deltas_slice
=
bbox_deltas_swap
.
Slice
(
i
,
i
+
1
);
Tensor
scores_slice
=
scores_swap
.
Slice
(
i
,
i
+
1
);
bbox_deltas_slice
.
Resize
({
h_bbox
*
w_bbox
*
c_bbox
/
4
,
4
});
scores_slice
.
Resize
({
h_score
*
w_score
*
c_score
,
1
});
std
::
pair
<
Tensor
,
Tensor
>
tensor_pair
=
ProposalForOneImage
(
dev_ctx
,
im_shape_slice
,
anchors
,
variances
,
bbox_deltas_slice
,
scores_slice
,
pre_nms_top_n
,
post_nms_top_n
,
nms_thresh
,
min_size
,
eta
);
Tensor
&
proposals
=
tensor_pair
.
first
;
Tensor
&
scores
=
tensor_pair
.
second
;
AppendProposals
(
rpn_rois
,
4
*
num_proposals
,
proposals
);
AppendProposals
(
rpn_roi_probs
,
num_proposals
,
scores
);
num_proposals
+=
proposals
.
dims
()[
0
];
lod0
.
push_back
(
num_proposals
);
tmp_num
.
push_back
(
proposals
.
dims
()[
0
]);
}
if
(
context
.
HasOutput
(
"RpnRoisNum"
))
{
auto
*
rpn_rois_num
=
context
.
Output
<
Tensor
>
(
"RpnRoisNum"
);
rpn_rois_num
->
mutable_data
<
int
>
({
num
},
context
.
GetPlace
());
int
*
num_data
=
rpn_rois_num
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
num
;
i
++
)
{
num_data
[
i
]
=
tmp_num
[
i
];
}
rpn_rois_num
->
Resize
({
num
});
}
rpn_rois
->
set_lod
(
lod
);
rpn_roi_probs
->
set_lod
(
lod
);
rpn_rois
->
Resize
({
num_proposals
,
4
});
rpn_roi_probs
->
Resize
({
num_proposals
,
1
});
}
std
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
const
platform
::
CPUDeviceContext
&
ctx
,
const
Tensor
&
im_shape_slice
,
const
Tensor
&
anchors
,
const
Tensor
&
variances
,
const
Tensor
&
bbox_deltas_slice
,
// [M, 4]
const
Tensor
&
scores_slice
,
// [N, 1]
int
pre_nms_top_n
,
int
post_nms_top_n
,
float
nms_thresh
,
float
min_size
,
float
eta
)
const
{
auto
*
scores_data
=
scores_slice
.
data
<
T
>
();
// Sort index
Tensor
index_t
;
index_t
.
Resize
({
scores_slice
.
numel
()});
int
*
index
=
index_t
.
mutable_data
<
int
>
(
ctx
.
GetPlace
());
for
(
int
i
=
0
;
i
<
scores_slice
.
numel
();
++
i
)
{
index
[
i
]
=
i
;
}
auto
compare
=
[
scores_data
](
const
int64_t
&
i
,
const
int64_t
&
j
)
{
return
scores_data
[
i
]
>
scores_data
[
j
];
};
if
(
pre_nms_top_n
<=
0
||
pre_nms_top_n
>=
scores_slice
.
numel
())
{
std
::
sort
(
index
,
index
+
scores_slice
.
numel
(),
compare
);
}
else
{
std
::
nth_element
(
index
,
index
+
pre_nms_top_n
,
index
+
scores_slice
.
numel
(),
compare
);
index_t
.
Resize
({
pre_nms_top_n
});
}
Tensor
scores_sel
,
bbox_sel
,
anchor_sel
,
var_sel
;
scores_sel
.
mutable_data
<
T
>
({
index_t
.
numel
(),
1
},
ctx
.
GetPlace
());
bbox_sel
.
mutable_data
<
T
>
({
index_t
.
numel
(),
4
},
ctx
.
GetPlace
());
anchor_sel
.
mutable_data
<
T
>
({
index_t
.
numel
(),
4
},
ctx
.
GetPlace
());
var_sel
.
mutable_data
<
T
>
({
index_t
.
numel
(),
4
},
ctx
.
GetPlace
());
CPUGather
<
T
>
(
ctx
,
scores_slice
,
index_t
,
&
scores_sel
);
CPUGather
<
T
>
(
ctx
,
bbox_deltas_slice
,
index_t
,
&
bbox_sel
);
CPUGather
<
T
>
(
ctx
,
anchors
,
index_t
,
&
anchor_sel
);
CPUGather
<
T
>
(
ctx
,
variances
,
index_t
,
&
var_sel
);
Tensor
proposals
;
proposals
.
mutable_data
<
T
>
({
index_t
.
numel
(),
4
},
ctx
.
GetPlace
());
BoxCoder
<
T
>
(
ctx
,
&
anchor_sel
,
&
bbox_sel
,
&
var_sel
,
&
proposals
);
ClipTiledBoxes
<
T
>
(
ctx
,
im_shape_slice
,
proposals
,
&
proposals
,
false
);
Tensor
keep
;
FilterBoxes
<
T
>
(
ctx
,
&
proposals
,
min_size
,
im_shape_slice
,
false
,
&
keep
);
// Handle the case when there is no keep index left
if
(
keep
.
numel
()
==
0
)
{
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
set_zero
;
bbox_sel
.
mutable_data
<
T
>
({
1
,
4
},
ctx
.
GetPlace
());
set_zero
(
ctx
,
&
bbox_sel
,
static_cast
<
T
>
(
0
));
Tensor
scores_filter
;
scores_filter
.
mutable_data
<
T
>
({
1
,
1
},
ctx
.
GetPlace
());
set_zero
(
ctx
,
&
scores_filter
,
static_cast
<
T
>
(
0
));
return
std
::
make_pair
(
bbox_sel
,
scores_filter
);
}
Tensor
scores_filter
;
bbox_sel
.
mutable_data
<
T
>
({
keep
.
numel
(),
4
},
ctx
.
GetPlace
());
scores_filter
.
mutable_data
<
T
>
({
keep
.
numel
(),
1
},
ctx
.
GetPlace
());
CPUGather
<
T
>
(
ctx
,
proposals
,
keep
,
&
bbox_sel
);
CPUGather
<
T
>
(
ctx
,
scores_sel
,
keep
,
&
scores_filter
);
if
(
nms_thresh
<=
0
)
{
return
std
::
make_pair
(
bbox_sel
,
scores_filter
);
}
Tensor
keep_nms
=
NMS
<
T
>
(
ctx
,
&
bbox_sel
,
&
scores_filter
,
nms_thresh
,
eta
);
if
(
post_nms_top_n
>
0
&&
post_nms_top_n
<
keep_nms
.
numel
())
{
keep_nms
.
Resize
({
post_nms_top_n
});
}
proposals
.
mutable_data
<
T
>
({
keep_nms
.
numel
(),
4
},
ctx
.
GetPlace
());
scores_sel
.
mutable_data
<
T
>
({
keep_nms
.
numel
(),
1
},
ctx
.
GetPlace
());
CPUGather
<
T
>
(
ctx
,
bbox_sel
,
keep_nms
,
&
proposals
);
CPUGather
<
T
>
(
ctx
,
scores_filter
,
keep_nms
,
&
scores_sel
);
return
std
::
make_pair
(
proposals
,
scores_sel
);
}
};
class
GenerateProposalsV2OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Scores"
,
"(Tensor) The scores from conv is in shape (N, A, H, W), "
"N is batch size, A is number of anchors, "
"H and W are height and width of the feature map"
);
AddInput
(
"BboxDeltas"
,
"(Tensor) Bounding box deltas from conv is in "
"shape (N, 4*A, H, W)."
);
AddInput
(
"ImShape"
,
"(Tensor) Image shape in shape (N, 2), "
"in format (height, width)"
);
AddInput
(
"Anchors"
,
"(Tensor) Bounding box anchors from anchor_generator_op "
"is in shape (A, H, W, 4)."
);
AddInput
(
"Variances"
,
"(Tensor) Bounding box variances with same shape as `Anchors`."
);
AddOutput
(
"RpnRois"
,
"(LoDTensor), Output proposals with shape (rois_num, 4)."
);
AddOutput
(
"RpnRoiProbs"
,
"(LoDTensor) Scores of proposals with shape (rois_num, 1)."
);
AddOutput
(
"RpnRoisNum"
,
"(Tensor), The number of Rpn RoIs in each image"
)
.
AsDispensable
();
AddAttr
<
int
>
(
"pre_nms_topN"
,
"Number of top scoring RPN proposals to keep before "
"applying NMS."
);
AddAttr
<
int
>
(
"post_nms_topN"
,
"Number of top scoring RPN proposals to keep after "
"applying NMS"
);
AddAttr
<
float
>
(
"nms_thresh"
,
"NMS threshold used on RPN proposals."
);
AddAttr
<
float
>
(
"min_size"
,
"Proposal height and width both need to be greater "
"than this min_size."
);
AddAttr
<
float
>
(
"eta"
,
"The parameter for adaptive NMS."
);
AddComment
(
R"DOC(
This operator is the second version of generate_proposals op to generate
bounding box proposals for Faster RCNN.
The proposals are generated for a list of images based on image
score 'Scores', bounding box regression result 'BboxDeltas' as
well as predefined bounding box shapes 'anchors'. Greedy
non-maximum suppression is applied to generate the final bounding
boxes.
The difference between this version and the first version is that the image
scale is no long needed now, so the input requires im_shape instead of im_info.
The change aims to unify the input for all kinds of objective detection
such as YOLO-v3 and Faster R-CNN. As a result, the min_size represents the
size on input image instead of original image which is slightly different
to before and will not effect the result.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
generate_proposals_v2
,
ops
::
GenerateProposalsV2Op
,
ops
::
GenerateProposalsV2OpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
generate_proposals_v2
,
ops
::
GenerateProposalsV2Kernel
<
float
>
,
ops
::
GenerateProposalsV2Kernel
<
double
>
);
paddle/fluid/operators/detection/generate_proposals_v2_op.cu
0 → 100644
浏览文件 @
5262b025
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/bbox_util.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
namespace
{
template
<
typename
T
>
static
std
::
pair
<
Tensor
,
Tensor
>
ProposalForOneImage
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
im_shape
,
const
Tensor
&
anchors
,
const
Tensor
&
variances
,
const
Tensor
&
bbox_deltas
,
// [M, 4]
const
Tensor
&
scores
,
// [N, 1]
int
pre_nms_top_n
,
int
post_nms_top_n
,
float
nms_thresh
,
float
min_size
,
float
eta
)
{
// 1. pre nms
Tensor
scores_sort
,
index_sort
;
SortDescending
<
T
>
(
ctx
,
scores
,
&
scores_sort
,
&
index_sort
);
int
num
=
scores
.
numel
();
int
pre_nms_num
=
(
pre_nms_top_n
<=
0
||
pre_nms_top_n
>
num
)
?
scores
.
numel
()
:
pre_nms_top_n
;
scores_sort
.
Resize
({
pre_nms_num
,
1
});
index_sort
.
Resize
({
pre_nms_num
,
1
});
// 2. box decode and clipping
Tensor
proposals
;
proposals
.
mutable_data
<
T
>
({
pre_nms_num
,
4
},
ctx
.
GetPlace
());
{
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
ctx
,
pre_nms_num
);
for_range
(
BoxDecodeAndClipFunctor
<
T
>
{
anchors
.
data
<
T
>
(),
bbox_deltas
.
data
<
T
>
(),
variances
.
data
<
T
>
(),
index_sort
.
data
<
int
>
(),
im_shape
.
data
<
T
>
(),
proposals
.
data
<
T
>
()});
}
// 3. filter
Tensor
keep_index
,
keep_num_t
;
keep_index
.
mutable_data
<
int
>
({
pre_nms_num
},
ctx
.
GetPlace
());
keep_num_t
.
mutable_data
<
int
>
({
1
},
ctx
.
GetPlace
());
min_size
=
std
::
max
(
min_size
,
1.0
f
);
auto
stream
=
ctx
.
stream
();
FilterBBoxes
<
T
,
512
><<<
1
,
512
,
0
,
stream
>>>
(
proposals
.
data
<
T
>
(),
im_shape
.
data
<
T
>
(),
min_size
,
pre_nms_num
,
keep_num_t
.
data
<
int
>
(),
keep_index
.
data
<
int
>
(),
false
);
int
keep_num
;
const
auto
gpu_place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
());
memory
::
Copy
(
platform
::
CPUPlace
(),
&
keep_num
,
gpu_place
,
keep_num_t
.
data
<
int
>
(),
sizeof
(
int
),
ctx
.
stream
());
ctx
.
Wait
();
keep_index
.
Resize
({
keep_num
});
Tensor
scores_filter
,
proposals_filter
;
// Handle the case when there is no keep index left
if
(
keep_num
==
0
)
{
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
set_zero
;
proposals_filter
.
mutable_data
<
T
>
({
1
,
4
},
ctx
.
GetPlace
());
scores_filter
.
mutable_data
<
T
>
({
1
,
1
},
ctx
.
GetPlace
());
set_zero
(
ctx
,
&
proposals_filter
,
static_cast
<
T
>
(
0
));
set_zero
(
ctx
,
&
scores_filter
,
static_cast
<
T
>
(
0
));
return
std
::
make_pair
(
proposals_filter
,
scores_filter
);
}
proposals_filter
.
mutable_data
<
T
>
({
keep_num
,
4
},
ctx
.
GetPlace
());
scores_filter
.
mutable_data
<
T
>
({
keep_num
,
1
},
ctx
.
GetPlace
());
GPUGather
<
T
>
(
ctx
,
proposals
,
keep_index
,
&
proposals_filter
);
GPUGather
<
T
>
(
ctx
,
scores_sort
,
keep_index
,
&
scores_filter
);
if
(
nms_thresh
<=
0
)
{
return
std
::
make_pair
(
proposals_filter
,
scores_filter
);
}
// 4. nms
Tensor
keep_nms
;
NMS
<
T
>
(
ctx
,
proposals_filter
,
keep_index
,
nms_thresh
,
&
keep_nms
);
if
(
post_nms_top_n
>
0
&&
post_nms_top_n
<
keep_nms
.
numel
())
{
keep_nms
.
Resize
({
post_nms_top_n
});
}
Tensor
scores_nms
,
proposals_nms
;
proposals_nms
.
mutable_data
<
T
>
({
keep_nms
.
numel
(),
4
},
ctx
.
GetPlace
());
scores_nms
.
mutable_data
<
T
>
({
keep_nms
.
numel
(),
1
},
ctx
.
GetPlace
());
GPUGather
<
T
>
(
ctx
,
proposals_filter
,
keep_nms
,
&
proposals_nms
);
GPUGather
<
T
>
(
ctx
,
scores_filter
,
keep_nms
,
&
scores_nms
);
return
std
::
make_pair
(
proposals_nms
,
scores_nms
);
}
}
// namespace
template
<
typename
DeviceContext
,
typename
T
>
class
CUDAGenerateProposalsV2Kernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
scores
=
context
.
Input
<
Tensor
>
(
"Scores"
);
auto
*
bbox_deltas
=
context
.
Input
<
Tensor
>
(
"BboxDeltas"
);
auto
*
im_shape
=
context
.
Input
<
Tensor
>
(
"ImShape"
);
auto
anchors
=
GET_DATA_SAFELY
(
context
.
Input
<
Tensor
>
(
"Anchors"
),
"Input"
,
"Anchors"
,
"GenerateProposals"
);
auto
variances
=
GET_DATA_SAFELY
(
context
.
Input
<
Tensor
>
(
"Variances"
),
"Input"
,
"Variances"
,
"GenerateProposals"
);
auto
*
rpn_rois
=
context
.
Output
<
LoDTensor
>
(
"RpnRois"
);
auto
*
rpn_roi_probs
=
context
.
Output
<
LoDTensor
>
(
"RpnRoiProbs"
);
int
pre_nms_top_n
=
context
.
Attr
<
int
>
(
"pre_nms_topN"
);
int
post_nms_top_n
=
context
.
Attr
<
int
>
(
"post_nms_topN"
);
float
nms_thresh
=
context
.
Attr
<
float
>
(
"nms_thresh"
);
float
min_size
=
context
.
Attr
<
float
>
(
"min_size"
);
float
eta
=
context
.
Attr
<
float
>
(
"eta"
);
PADDLE_ENFORCE_GE
(
eta
,
1.
,
platform
::
errors
::
InvalidArgument
(
"Not support adaptive NMS. The attribute 'eta' "
"should not less than 1. But received eta=[%d]"
,
eta
));
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
scores_dim
=
scores
->
dims
();
int64_t
num
=
scores_dim
[
0
];
int64_t
c_score
=
scores_dim
[
1
];
int64_t
h_score
=
scores_dim
[
2
];
int64_t
w_score
=
scores_dim
[
3
];
auto
bbox_dim
=
bbox_deltas
->
dims
();
int64_t
c_bbox
=
bbox_dim
[
1
];
int64_t
h_bbox
=
bbox_dim
[
2
];
int64_t
w_bbox
=
bbox_dim
[
3
];
Tensor
bbox_deltas_swap
,
scores_swap
;
bbox_deltas_swap
.
mutable_data
<
T
>
({
num
,
h_bbox
,
w_bbox
,
c_bbox
},
dev_ctx
.
GetPlace
());
scores_swap
.
mutable_data
<
T
>
({
num
,
h_score
,
w_score
,
c_score
},
dev_ctx
.
GetPlace
());
math
::
Transpose
<
DeviceContext
,
T
,
4
>
trans
;
std
::
vector
<
int
>
axis
=
{
0
,
2
,
3
,
1
};
trans
(
dev_ctx
,
*
bbox_deltas
,
&
bbox_deltas_swap
,
axis
);
trans
(
dev_ctx
,
*
scores
,
&
scores_swap
,
axis
);
anchors
.
Resize
({
anchors
.
numel
()
/
4
,
4
});
variances
.
Resize
({
variances
.
numel
()
/
4
,
4
});
rpn_rois
->
mutable_data
<
T
>
({
bbox_deltas
->
numel
()
/
4
,
4
},
context
.
GetPlace
());
rpn_roi_probs
->
mutable_data
<
T
>
({
scores
->
numel
(),
1
},
context
.
GetPlace
());
T
*
rpn_rois_data
=
rpn_rois
->
data
<
T
>
();
T
*
rpn_roi_probs_data
=
rpn_roi_probs
->
data
<
T
>
();
auto
place
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
dev_ctx
.
GetPlace
());
auto
cpu_place
=
platform
::
CPUPlace
();
int64_t
num_proposals
=
0
;
std
::
vector
<
size_t
>
offset
(
1
,
0
);
std
::
vector
<
int
>
tmp_num
;
for
(
int64_t
i
=
0
;
i
<
num
;
++
i
)
{
Tensor
im_shape_slice
=
im_shape
->
Slice
(
i
,
i
+
1
);
Tensor
bbox_deltas_slice
=
bbox_deltas_swap
.
Slice
(
i
,
i
+
1
);
Tensor
scores_slice
=
scores_swap
.
Slice
(
i
,
i
+
1
);
bbox_deltas_slice
.
Resize
({
h_bbox
*
w_bbox
*
c_bbox
/
4
,
4
});
scores_slice
.
Resize
({
h_score
*
w_score
*
c_score
,
1
});
std
::
pair
<
Tensor
,
Tensor
>
box_score_pair
=
ProposalForOneImage
<
T
>
(
dev_ctx
,
im_shape_slice
,
anchors
,
variances
,
bbox_deltas_slice
,
scores_slice
,
pre_nms_top_n
,
post_nms_top_n
,
nms_thresh
,
min_size
,
eta
);
Tensor
&
proposals
=
box_score_pair
.
first
;
Tensor
&
scores
=
box_score_pair
.
second
;
memory
::
Copy
(
place
,
rpn_rois_data
+
num_proposals
*
4
,
place
,
proposals
.
data
<
T
>
(),
sizeof
(
T
)
*
proposals
.
numel
(),
dev_ctx
.
stream
());
memory
::
Copy
(
place
,
rpn_roi_probs_data
+
num_proposals
,
place
,
scores
.
data
<
T
>
(),
sizeof
(
T
)
*
scores
.
numel
(),
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
num_proposals
+=
proposals
.
dims
()[
0
];
offset
.
emplace_back
(
num_proposals
);
tmp_num
.
push_back
(
proposals
.
dims
()[
0
]);
}
if
(
context
.
HasOutput
(
"RpnRoisNum"
))
{
auto
*
rpn_rois_num
=
context
.
Output
<
Tensor
>
(
"RpnRoisNum"
);
rpn_rois_num
->
mutable_data
<
int
>
({
num
},
context
.
GetPlace
());
int
*
num_data
=
rpn_rois_num
->
data
<
int
>
();
memory
::
Copy
(
place
,
num_data
,
cpu_place
,
&
tmp_num
[
0
],
sizeof
(
int
)
*
num
,
dev_ctx
.
stream
());
rpn_rois_num
->
Resize
({
num
});
}
framework
::
LoD
lod
;
lod
.
emplace_back
(
offset
);
rpn_rois
->
set_lod
(
lod
);
rpn_roi_probs
->
set_lod
(
lod
);
rpn_rois
->
Resize
({
num_proposals
,
4
});
rpn_roi_probs
->
Resize
({
num_proposals
,
1
});
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
generate_proposals_v2
,
ops
::
CUDAGenerateProposalsV2Kernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
paddle/fluid/operators/detection/nms_util.h
浏览文件 @
5262b025
...
@@ -99,5 +99,74 @@ T PolyIoU(const T* box1, const T* box2, const size_t box_size,
...
@@ -99,5 +99,74 @@ T PolyIoU(const T* box1, const T* box2, const size_t box_size,
}
}
}
}
template
<
class
T
>
static
inline
std
::
vector
<
std
::
pair
<
T
,
int
>>
GetSortedScoreIndex
(
const
std
::
vector
<
T
>&
scores
)
{
std
::
vector
<
std
::
pair
<
T
,
int
>>
sorted_indices
;
sorted_indices
.
reserve
(
scores
.
size
());
for
(
size_t
i
=
0
;
i
<
scores
.
size
();
++
i
)
{
sorted_indices
.
emplace_back
(
scores
[
i
],
i
);
}
// Sort the score pair according to the scores in descending order
std
::
stable_sort
(
sorted_indices
.
begin
(),
sorted_indices
.
end
(),
[](
const
std
::
pair
<
T
,
int
>&
a
,
const
std
::
pair
<
T
,
int
>&
b
)
{
return
a
.
first
<
b
.
first
;
});
return
sorted_indices
;
}
template
<
typename
T
>
static
inline
framework
::
Tensor
VectorToTensor
(
const
std
::
vector
<
T
>&
selected_indices
,
int
selected_num
)
{
framework
::
Tensor
keep_nms
;
keep_nms
.
Resize
({
selected_num
});
auto
*
keep_data
=
keep_nms
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
selected_num
;
++
i
)
{
keep_data
[
i
]
=
selected_indices
[
i
];
}
return
keep_nms
;
}
template
<
class
T
>
framework
::
Tensor
NMS
(
const
platform
::
DeviceContext
&
ctx
,
framework
::
Tensor
*
bbox
,
framework
::
Tensor
*
scores
,
T
nms_threshold
,
float
eta
)
{
int64_t
num_boxes
=
bbox
->
dims
()[
0
];
// 4: [xmin ymin xmax ymax]
int64_t
box_size
=
bbox
->
dims
()[
1
];
std
::
vector
<
T
>
scores_data
(
num_boxes
);
std
::
copy_n
(
scores
->
data
<
T
>
(),
num_boxes
,
scores_data
.
begin
());
std
::
vector
<
std
::
pair
<
T
,
int
>>
sorted_indices
=
GetSortedScoreIndex
<
T
>
(
scores_data
);
std
::
vector
<
int
>
selected_indices
;
int
selected_num
=
0
;
T
adaptive_threshold
=
nms_threshold
;
const
T
*
bbox_data
=
bbox
->
data
<
T
>
();
while
(
sorted_indices
.
size
()
!=
0
)
{
int
idx
=
sorted_indices
.
back
().
second
;
bool
flag
=
true
;
for
(
int
kept_idx
:
selected_indices
)
{
if
(
flag
)
{
T
overlap
=
JaccardOverlap
<
T
>
(
bbox_data
+
idx
*
box_size
,
bbox_data
+
kept_idx
*
box_size
,
false
);
flag
=
(
overlap
<=
adaptive_threshold
);
}
else
{
break
;
}
}
if
(
flag
)
{
selected_indices
.
push_back
(
idx
);
++
selected_num
;
}
sorted_indices
.
erase
(
sorted_indices
.
end
()
-
1
);
if
(
flag
&&
eta
<
1
&&
adaptive_threshold
>
0.5
)
{
adaptive_threshold
*=
eta
;
}
}
return
VectorToTensor
(
selected_indices
,
selected_num
);
}
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
5262b025
...
@@ -81,6 +81,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
...
@@ -81,6 +81,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{
"MultiFpnRois"
,
"RestoreIndex"
,
"MultiLevelRoIsNum"
}},
{
"MultiFpnRois"
,
"RestoreIndex"
,
"MultiLevelRoIsNum"
}},
{
"moving_average_abs_max_scale"
,
{
"OutScale"
,
"OutAccum"
,
"OutState"
}},
{
"moving_average_abs_max_scale"
,
{
"OutScale"
,
"OutAccum"
,
"OutState"
}},
{
"multiclass_nms3"
,
{
"Out"
,
"NmsRoisNum"
}},
{
"multiclass_nms3"
,
{
"Out"
,
"NmsRoisNum"
}},
{
"generate_proposals_v2"
,
{
"RpnRois"
,
"RpnRoiProbs"
,
"RpnRoisNum"
}},
};
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
...
...
python/paddle/fluid/tests/unittests/test_generate_proposals_v2_op.py
0 → 100644
浏览文件 @
5262b025
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
import
math
import
paddle
import
paddle.fluid
as
fluid
from
op_test
import
OpTest
from
test_multiclass_nms_op
import
nms
from
test_anchor_generator_op
import
anchor_generator_in_python
import
copy
from
test_generate_proposals_op
import
clip_tiled_boxes
,
box_coder
,
nms
def
generate_proposals_v2_in_python
(
scores
,
bbox_deltas
,
im_shape
,
anchors
,
variances
,
pre_nms_topN
,
post_nms_topN
,
nms_thresh
,
min_size
,
eta
):
all_anchors
=
anchors
.
reshape
(
-
1
,
4
)
rois
=
np
.
empty
((
0
,
5
),
dtype
=
np
.
float32
)
roi_probs
=
np
.
empty
((
0
,
1
),
dtype
=
np
.
float32
)
rpn_rois
=
[]
rpn_roi_probs
=
[]
rois_num
=
[]
num_images
=
scores
.
shape
[
0
]
for
img_idx
in
range
(
num_images
):
img_i_boxes
,
img_i_probs
=
proposal_for_one_image
(
im_shape
[
img_idx
,
:],
all_anchors
,
variances
,
bbox_deltas
[
img_idx
,
:,
:,
:],
scores
[
img_idx
,
:,
:,
:],
pre_nms_topN
,
post_nms_topN
,
nms_thresh
,
min_size
,
eta
)
rois_num
.
append
(
img_i_probs
.
shape
[
0
])
rpn_rois
.
append
(
img_i_boxes
)
rpn_roi_probs
.
append
(
img_i_probs
)
return
rpn_rois
,
rpn_roi_probs
,
rois_num
def
proposal_for_one_image
(
im_shape
,
all_anchors
,
variances
,
bbox_deltas
,
scores
,
pre_nms_topN
,
post_nms_topN
,
nms_thresh
,
min_size
,
eta
):
# Transpose and reshape predicted bbox transformations to get them
# into the same order as the anchors:
# - bbox deltas will be (4 * A, H, W) format from conv output
# - transpose to (H, W, 4 * A)
# - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
# in slowest to fastest order to match the enumerated anchors
bbox_deltas
=
bbox_deltas
.
transpose
((
1
,
2
,
0
)).
reshape
(
-
1
,
4
)
all_anchors
=
all_anchors
.
reshape
(
-
1
,
4
)
variances
=
variances
.
reshape
(
-
1
,
4
)
# Same story for the scores:
# - scores are (A, H, W) format from conv output
# - transpose to (H, W, A)
# - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
# to match the order of anchors and bbox_deltas
scores
=
scores
.
transpose
((
1
,
2
,
0
)).
reshape
(
-
1
,
1
)
# sort all (proposal, score) pairs by score from highest to lowest
# take top pre_nms_topN (e.g. 6000)
if
pre_nms_topN
<=
0
or
pre_nms_topN
>=
len
(
scores
):
order
=
np
.
argsort
(
-
scores
.
squeeze
())
else
:
# Avoid sorting possibly large arrays;
# First partition to get top K unsorted
# and then sort just those
inds
=
np
.
argpartition
(
-
scores
.
squeeze
(),
pre_nms_topN
)[:
pre_nms_topN
]
order
=
np
.
argsort
(
-
scores
[
inds
].
squeeze
())
order
=
inds
[
order
]
scores
=
scores
[
order
,
:]
bbox_deltas
=
bbox_deltas
[
order
,
:]
all_anchors
=
all_anchors
[
order
,
:]
proposals
=
box_coder
(
all_anchors
,
bbox_deltas
,
variances
)
# clip proposals to image (may result in proposals with zero area
# that will be removed in the next step)
proposals
=
clip_tiled_boxes
(
proposals
,
im_shape
)
# remove predicted boxes with height or width < min_size
keep
=
filter_boxes
(
proposals
,
min_size
,
im_shape
)
if
len
(
keep
)
==
0
:
proposals
=
np
.
zeros
((
1
,
4
)).
astype
(
'float32'
)
scores
=
np
.
zeros
((
1
,
1
)).
astype
(
'float32'
)
return
proposals
,
scores
proposals
=
proposals
[
keep
,
:]
scores
=
scores
[
keep
,
:]
# apply loose nms (e.g. threshold = 0.7)
# take post_nms_topN (e.g. 1000)
# return the top proposals
if
nms_thresh
>
0
:
keep
=
nms
(
boxes
=
proposals
,
scores
=
scores
,
nms_threshold
=
nms_thresh
,
eta
=
eta
)
if
post_nms_topN
>
0
and
post_nms_topN
<
len
(
keep
):
keep
=
keep
[:
post_nms_topN
]
proposals
=
proposals
[
keep
,
:]
scores
=
scores
[
keep
,
:]
return
proposals
,
scores
def
filter_boxes
(
boxes
,
min_size
,
im_shape
):
"""Only keep boxes with both sides >= min_size and center within the image.
"""
# Scale min_size to match image scale
min_size
=
max
(
min_size
,
1.0
)
ws
=
boxes
[:,
2
]
-
boxes
[:,
0
]
+
1
hs
=
boxes
[:,
3
]
-
boxes
[:,
1
]
+
1
x_ctr
=
boxes
[:,
0
]
+
ws
/
2.
y_ctr
=
boxes
[:,
1
]
+
hs
/
2.
keep
=
np
.
where
((
ws
>=
min_size
)
&
(
hs
>=
min_size
)
&
(
x_ctr
<
im_shape
[
1
])
&
(
y_ctr
<
im_shape
[
0
]))[
0
]
return
keep
class
TestGenerateProposalsV2Op
(
OpTest
):
def
set_data
(
self
):
self
.
init_test_params
()
self
.
init_test_input
()
self
.
init_test_output
()
self
.
inputs
=
{
'Scores'
:
self
.
scores
,
'BboxDeltas'
:
self
.
bbox_deltas
,
'ImShape'
:
self
.
im_shape
.
astype
(
np
.
float32
),
'Anchors'
:
self
.
anchors
,
'Variances'
:
self
.
variances
}
self
.
attrs
=
{
'pre_nms_topN'
:
self
.
pre_nms_topN
,
'post_nms_topN'
:
self
.
post_nms_topN
,
'nms_thresh'
:
self
.
nms_thresh
,
'min_size'
:
self
.
min_size
,
'eta'
:
self
.
eta
}
self
.
outputs
=
{
'RpnRois'
:
(
self
.
rpn_rois
[
0
],
[
self
.
rois_num
]),
'RpnRoiProbs'
:
(
self
.
rpn_roi_probs
[
0
],
[
self
.
rois_num
]),
}
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"generate_proposals_v2"
self
.
set_data
()
def
init_test_params
(
self
):
self
.
pre_nms_topN
=
12000
# train 12000, test 2000
self
.
post_nms_topN
=
5000
# train 6000, test 1000
self
.
nms_thresh
=
0.7
self
.
min_size
=
3.0
self
.
eta
=
1.
def
init_test_input
(
self
):
batch_size
=
1
input_channels
=
20
layer_h
=
16
layer_w
=
16
input_feat
=
np
.
random
.
random
(
(
batch_size
,
input_channels
,
layer_h
,
layer_w
)).
astype
(
'float32'
)
self
.
anchors
,
self
.
variances
=
anchor_generator_in_python
(
input_feat
=
input_feat
,
anchor_sizes
=
[
16.
,
32.
],
aspect_ratios
=
[
0.5
,
1.0
],
variances
=
[
1.0
,
1.0
,
1.0
,
1.0
],
stride
=
[
16.0
,
16.0
],
offset
=
0.5
)
self
.
im_shape
=
np
.
array
([[
64
,
64
]]).
astype
(
'float32'
)
num_anchors
=
self
.
anchors
.
shape
[
2
]
self
.
scores
=
np
.
random
.
random
(
(
batch_size
,
num_anchors
,
layer_h
,
layer_w
)).
astype
(
'float32'
)
self
.
bbox_deltas
=
np
.
random
.
random
(
(
batch_size
,
num_anchors
*
4
,
layer_h
,
layer_w
)).
astype
(
'float32'
)
def
init_test_output
(
self
):
self
.
rpn_rois
,
self
.
rpn_roi_probs
,
self
.
rois_num
=
generate_proposals_v2_in_python
(
self
.
scores
,
self
.
bbox_deltas
,
self
.
im_shape
,
self
.
anchors
,
self
.
variances
,
self
.
pre_nms_topN
,
self
.
post_nms_topN
,
self
.
nms_thresh
,
self
.
min_size
,
self
.
eta
)
class
TestGenerateProposalsV2OutLodOp
(
TestGenerateProposalsV2Op
):
def
set_data
(
self
):
self
.
init_test_params
()
self
.
init_test_input
()
self
.
init_test_output
()
self
.
inputs
=
{
'Scores'
:
self
.
scores
,
'BboxDeltas'
:
self
.
bbox_deltas
,
'ImShape'
:
self
.
im_shape
.
astype
(
np
.
float32
),
'Anchors'
:
self
.
anchors
,
'Variances'
:
self
.
variances
}
self
.
attrs
=
{
'pre_nms_topN'
:
self
.
pre_nms_topN
,
'post_nms_topN'
:
self
.
post_nms_topN
,
'nms_thresh'
:
self
.
nms_thresh
,
'min_size'
:
self
.
min_size
,
'eta'
:
self
.
eta
,
'return_rois_num'
:
True
}
self
.
outputs
=
{
'RpnRois'
:
(
self
.
rpn_rois
[
0
],
[
self
.
rois_num
]),
'RpnRoiProbs'
:
(
self
.
rpn_roi_probs
[
0
],
[
self
.
rois_num
]),
'RpnRoisNum'
:
(
np
.
asarray
(
self
.
rois_num
,
dtype
=
np
.
int32
))
}
class
TestGenerateProposalsV2OpNoBoxLeft
(
TestGenerateProposalsV2Op
):
def
init_test_params
(
self
):
self
.
pre_nms_topN
=
12000
# train 12000, test 2000
self
.
post_nms_topN
=
5000
# train 6000, test 1000
self
.
nms_thresh
=
0.7
self
.
min_size
=
1000.0
self
.
eta
=
1.
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
tools/static_mode_white_list.py
浏览文件 @
5262b025
...
@@ -673,4 +673,5 @@ STATIC_MODE_TESTING_LIST = [
...
@@ -673,4 +673,5 @@ STATIC_MODE_TESTING_LIST = [
'test_sgd_op_xpu'
,
'test_sgd_op_xpu'
,
'test_shape_op_xpu'
,
'test_shape_op_xpu'
,
'test_slice_op_xpu'
,
'test_slice_op_xpu'
,
'test_generate_proposals_v2_op'
,
]
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录