Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
4630e56b
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4630e56b
编写于
9月 06, 2021
作者:
W
wangxinxin08
提交者:
GitHub
9月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish the code of rbox iou (#4123)
上级
ff96c78d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
475 addition
and
449 deletion
+475
-449
ppdet/ext_op/rbox_iou_op.cc
ppdet/ext_op/rbox_iou_op.cc
+53
-4
ppdet/ext_op/rbox_iou_op.cu
ppdet/ext_op/rbox_iou_op.cu
+2
-391
ppdet/ext_op/rbox_iou_op.h
ppdet/ext_op/rbox_iou_op.h
+353
-0
ppdet/ext_op/setup.py
ppdet/ext_op/setup.py
+12
-4
ppdet/ext_op/test.py
ppdet/ext_op/test.py
+55
-50
未找到文件。
ppdet/ext_op/rbox_iou_op.cc
浏览文件 @
4630e56b
...
...
@@ -11,22 +11,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "rbox_iou_op.h"
#include "paddle/extension.h"
#include <vector>
std
::
vector
<
paddle
::
Tensor
>
RboxIouCPUForward
(
const
paddle
::
Tensor
&
rbox1
,
const
paddle
::
Tensor
&
rbox2
);
template
<
typename
T
>
void
rbox_iou_cpu_kernel
(
const
int
rbox1_num
,
const
int
rbox2_num
,
const
T
*
rbox1_data_ptr
,
const
T
*
rbox2_data_ptr
,
T
*
output_data_ptr
)
{
int
i
,
j
;
for
(
i
=
0
;
i
<
rbox1_num
;
i
++
)
{
for
(
j
=
0
;
j
<
rbox2_num
;
j
++
)
{
int
offset
=
i
*
rbox2_num
+
j
;
output_data_ptr
[
offset
]
=
rbox_iou_single
<
T
>
(
rbox1_data_ptr
+
i
*
5
,
rbox2_data_ptr
+
j
*
5
);
}
}
}
#define CHECK_INPUT_CPU(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
std
::
vector
<
paddle
::
Tensor
>
RboxIouCPUForward
(
const
paddle
::
Tensor
&
rbox1
,
const
paddle
::
Tensor
&
rbox2
)
{
CHECK_INPUT_CPU
(
rbox1
);
CHECK_INPUT_CPU
(
rbox2
);
auto
rbox1_num
=
rbox1
.
shape
()[
0
];
auto
rbox2_num
=
rbox2
.
shape
()[
0
];
auto
output
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kCPU
);
output
.
reshape
({
rbox1_num
,
rbox2_num
});
PD_DISPATCH_FLOATING_TYPES
(
rbox1
.
type
(),
"rbox_iou_cpu_kernel"
,
([
&
]
{
rbox_iou_cpu_kernel
<
data_t
>
(
rbox1_num
,
rbox2_num
,
rbox1
.
data
<
data_t
>
(),
rbox2
.
data
<
data_t
>
(),
output
.
mutable_data
<
data_t
>
());
}));
return
{
output
};
}
#ifdef PADDLE_WITH_CUDA
std
::
vector
<
paddle
::
Tensor
>
RboxIouCUDAForward
(
const
paddle
::
Tensor
&
rbox1
,
const
paddle
::
Tensor
&
rbox2
);
#endif
#define CHECK_INPUT_SAME(x1, x2) PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.")
std
::
vector
<
paddle
::
Tensor
>
RboxIouForward
(
const
paddle
::
Tensor
&
rbox1
,
const
paddle
::
Tensor
&
rbox2
)
{
CHECK_INPUT_SAME
(
rbox1
,
rbox2
);
if
(
rbox1
.
place
()
==
paddle
::
PlaceType
::
kCPU
)
{
return
RboxIouCPUForward
(
rbox1
,
rbox2
);
}
else
if
(
rbox1
.
place
()
==
paddle
::
PlaceType
::
kGPU
)
{
#ifdef PADDLE_WITH_CUDA
}
else
if
(
rbox1
.
place
()
==
paddle
::
PlaceType
::
kGPU
)
{
return
RboxIouCUDAForward
(
rbox1
,
rbox2
);
#endif
}
}
...
...
ppdet/ext_op/rbox_iou_op.cu
浏览文件 @
4630e56b
...
...
@@ -11,350 +11,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cassert>
#include <cmath>
#ifdef __CUDACC__
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
#include "rbox_iou_op.h"
#include "paddle/extension.h"
#include <vector>
namespace
{
template
<
typename
T
>
struct
RotatedBox
{
T
x_ctr
,
y_ctr
,
w
,
h
,
a
;
};
template
<
typename
T
>
struct
Point
{
T
x
,
y
;
HOST_DEVICE_INLINE
Point
(
const
T
&
px
=
0
,
const
T
&
py
=
0
)
:
x
(
px
),
y
(
py
)
{}
HOST_DEVICE_INLINE
Point
operator
+
(
const
Point
&
p
)
const
{
return
Point
(
x
+
p
.
x
,
y
+
p
.
y
);
}
HOST_DEVICE_INLINE
Point
&
operator
+=
(
const
Point
&
p
)
{
x
+=
p
.
x
;
y
+=
p
.
y
;
return
*
this
;
}
HOST_DEVICE_INLINE
Point
operator
-
(
const
Point
&
p
)
const
{
return
Point
(
x
-
p
.
x
,
y
-
p
.
y
);
}
HOST_DEVICE_INLINE
Point
operator
*
(
const
T
coeff
)
const
{
return
Point
(
x
*
coeff
,
y
*
coeff
);
}
};
template
<
typename
T
>
HOST_DEVICE_INLINE
T
dot_2d
(
const
Point
<
T
>&
A
,
const
Point
<
T
>&
B
)
{
return
A
.
x
*
B
.
x
+
A
.
y
*
B
.
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
cross_2d
(
const
Point
<
T
>&
A
,
const
Point
<
T
>&
B
)
{
return
A
.
x
*
B
.
y
-
B
.
x
*
A
.
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
void
get_rotated_vertices
(
const
RotatedBox
<
T
>&
box
,
Point
<
T
>
(
&
pts
)[
4
])
{
// M_PI / 180. == 0.01745329251
//double theta = box.a * 0.01745329251;
//MODIFIED
double
theta
=
box
.
a
;
T
cosTheta2
=
(
T
)
cos
(
theta
)
*
0.5
f
;
T
sinTheta2
=
(
T
)
sin
(
theta
)
*
0.5
f
;
// y: top --> down; x: left --> right
pts
[
0
].
x
=
box
.
x_ctr
-
sinTheta2
*
box
.
h
-
cosTheta2
*
box
.
w
;
pts
[
0
].
y
=
box
.
y_ctr
+
cosTheta2
*
box
.
h
-
sinTheta2
*
box
.
w
;
pts
[
1
].
x
=
box
.
x_ctr
+
sinTheta2
*
box
.
h
-
cosTheta2
*
box
.
w
;
pts
[
1
].
y
=
box
.
y_ctr
-
cosTheta2
*
box
.
h
-
sinTheta2
*
box
.
w
;
pts
[
2
].
x
=
2
*
box
.
x_ctr
-
pts
[
0
].
x
;
pts
[
2
].
y
=
2
*
box
.
y_ctr
-
pts
[
0
].
y
;
pts
[
3
].
x
=
2
*
box
.
x_ctr
-
pts
[
1
].
x
;
pts
[
3
].
y
=
2
*
box
.
y_ctr
-
pts
[
1
].
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
int
get_intersection_points
(
const
Point
<
T
>
(
&
pts1
)[
4
],
const
Point
<
T
>
(
&
pts2
)[
4
],
Point
<
T
>
(
&
intersections
)[
24
])
{
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point
<
T
>
vec1
[
4
],
vec2
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
vec1
[
i
]
=
pts1
[(
i
+
1
)
%
4
]
-
pts1
[
i
];
vec2
[
i
]
=
pts2
[(
i
+
1
)
%
4
]
-
pts2
[
i
];
}
// Line test - test all line combos for intersection
int
num
=
0
;
// number of intersections
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
// Solve for 2x2 Ax=b
T
det
=
cross_2d
<
T
>
(
vec2
[
j
],
vec1
[
i
]);
// This takes care of parallel lines
if
(
fabs
(
det
)
<=
1e-14
)
{
continue
;
}
auto
vec12
=
pts2
[
j
]
-
pts1
[
i
];
T
t1
=
cross_2d
<
T
>
(
vec2
[
j
],
vec12
)
/
det
;
T
t2
=
cross_2d
<
T
>
(
vec1
[
i
],
vec12
)
/
det
;
if
(
t1
>=
0.0
f
&&
t1
<=
1.0
f
&&
t2
>=
0.0
f
&&
t2
<=
1.0
f
)
{
intersections
[
num
++
]
=
pts1
[
i
]
+
vec1
[
i
]
*
t1
;
}
}
}
// Check for vertices of rect1 inside rect2
{
const
auto
&
AB
=
vec2
[
0
];
const
auto
&
DA
=
vec2
[
3
];
auto
ABdotAB
=
dot_2d
<
T
>
(
AB
,
AB
);
auto
ADdotAD
=
dot_2d
<
T
>
(
DA
,
DA
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto
AP
=
pts1
[
i
]
-
pts2
[
0
];
auto
APdotAB
=
dot_2d
<
T
>
(
AP
,
AB
);
auto
APdotAD
=
-
dot_2d
<
T
>
(
AP
,
DA
);
if
((
APdotAB
>=
0
)
&&
(
APdotAD
>=
0
)
&&
(
APdotAB
<=
ABdotAB
)
&&
(
APdotAD
<=
ADdotAD
))
{
intersections
[
num
++
]
=
pts1
[
i
];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const
auto
&
AB
=
vec1
[
0
];
const
auto
&
DA
=
vec1
[
3
];
auto
ABdotAB
=
dot_2d
<
T
>
(
AB
,
AB
);
auto
ADdotAD
=
dot_2d
<
T
>
(
DA
,
DA
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
auto
AP
=
pts2
[
i
]
-
pts1
[
0
];
auto
APdotAB
=
dot_2d
<
T
>
(
AP
,
AB
);
auto
APdotAD
=
-
dot_2d
<
T
>
(
AP
,
DA
);
if
((
APdotAB
>=
0
)
&&
(
APdotAD
>=
0
)
&&
(
APdotAB
<=
ABdotAB
)
&&
(
APdotAD
<=
ADdotAD
))
{
intersections
[
num
++
]
=
pts2
[
i
];
}
}
}
return
num
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
int
convex_hull_graham
(
const
Point
<
T
>
(
&
p
)[
24
],
const
int
&
num_in
,
Point
<
T
>
(
&
q
)[
24
],
bool
shift_to_zero
=
false
)
{
assert
(
num_in
>=
2
);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int
t
=
0
;
for
(
int
i
=
1
;
i
<
num_in
;
i
++
)
{
if
(
p
[
i
].
y
<
p
[
t
].
y
||
(
p
[
i
].
y
==
p
[
t
].
y
&&
p
[
i
].
x
<
p
[
t
].
x
))
{
t
=
i
;
}
}
auto
&
start
=
p
[
t
];
// starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for
(
int
i
=
0
;
i
<
num_in
;
i
++
)
{
q
[
i
]
=
p
[
i
]
-
start
;
}
// Swap the starting point to position 0
auto
tmp
=
q
[
0
];
q
[
0
]
=
q
[
t
];
q
[
t
]
=
tmp
;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T
dist
[
24
];
for
(
int
i
=
0
;
i
<
num_in
;
i
++
)
{
dist
[
i
]
=
dot_2d
<
T
>
(
q
[
i
],
q
[
i
]);
}
#ifdef __CUDACC__
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for
(
int
i
=
1
;
i
<
num_in
-
1
;
i
++
)
{
for
(
int
j
=
i
+
1
;
j
<
num_in
;
j
++
)
{
T
crossProduct
=
cross_2d
<
T
>
(
q
[
i
],
q
[
j
]);
if
((
crossProduct
<
-
1e-6
)
||
(
fabs
(
crossProduct
)
<
1e-6
&&
dist
[
i
]
>
dist
[
j
]))
{
auto
q_tmp
=
q
[
i
];
q
[
i
]
=
q
[
j
];
q
[
j
]
=
q_tmp
;
auto
dist_tmp
=
dist
[
i
];
dist
[
i
]
=
dist
[
j
];
dist
[
j
]
=
dist_tmp
;
}
}
}
#else
// CPU version
std
::
sort
(
q
+
1
,
q
+
num_in
,
[](
const
Point
<
T
>&
A
,
const
Point
<
T
>&
B
)
->
bool
{
T
temp
=
cross_2d
<
T
>
(
A
,
B
);
if
(
fabs
(
temp
)
<
1e-6
)
{
return
dot_2d
<
T
>
(
A
,
A
)
<
dot_2d
<
T
>
(
B
,
B
);
}
else
{
return
temp
>
0
;
}
});
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int
k
;
// index of the non-overlapped second point
for
(
k
=
1
;
k
<
num_in
;
k
++
)
{
if
(
dist
[
k
]
>
1e-8
)
{
break
;
}
}
if
(
k
==
num_in
)
{
// We reach the end, which means the convex hull is just one point
q
[
0
]
=
p
[
t
];
return
1
;
}
q
[
1
]
=
q
[
k
];
int
m
=
2
;
// 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for
(
int
i
=
k
+
1
;
i
<
num_in
;
i
++
)
{
while
(
m
>
1
&&
cross_2d
<
T
>
(
q
[
i
]
-
q
[
m
-
2
],
q
[
m
-
1
]
-
q
[
m
-
2
])
>=
0
)
{
m
--
;
}
q
[
m
++
]
=
q
[
i
];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if
(
!
shift_to_zero
)
{
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
q
[
i
]
+=
start
;
}
}
return
m
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
polygon_area
(
const
Point
<
T
>
(
&
q
)[
24
],
const
int
&
m
)
{
if
(
m
<=
2
)
{
return
0
;
}
T
area
=
0
;
for
(
int
i
=
1
;
i
<
m
-
1
;
i
++
)
{
area
+=
fabs
(
cross_2d
<
T
>
(
q
[
i
]
-
q
[
0
],
q
[
i
+
1
]
-
q
[
0
]));
}
return
area
/
2.0
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
rboxes_intersection
(
const
RotatedBox
<
T
>&
box1
,
const
RotatedBox
<
T
>&
box2
)
{
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point
<
T
>
intersectPts
[
24
],
orderedPts
[
24
];
Point
<
T
>
pts1
[
4
];
Point
<
T
>
pts2
[
4
];
get_rotated_vertices
<
T
>
(
box1
,
pts1
);
get_rotated_vertices
<
T
>
(
box2
,
pts2
);
int
num
=
get_intersection_points
<
T
>
(
pts1
,
pts2
,
intersectPts
);
if
(
num
<=
2
)
{
return
0.0
;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int
num_convex
=
convex_hull_graham
<
T
>
(
intersectPts
,
num
,
orderedPts
,
true
);
return
polygon_area
<
T
>
(
orderedPts
,
num_convex
);
}
}
// namespace
template
<
typename
T
>
HOST_DEVICE_INLINE
T
rbox_iou_single
(
T
const
*
const
box1_raw
,
T
const
*
const
box2_raw
)
{
// shift center to the middle point to achieve higher precision in result
RotatedBox
<
T
>
box1
,
box2
;
auto
center_shift_x
=
(
box1_raw
[
0
]
+
box2_raw
[
0
])
/
2.0
;
auto
center_shift_y
=
(
box1_raw
[
1
]
+
box2_raw
[
1
])
/
2.0
;
box1
.
x_ctr
=
box1_raw
[
0
]
-
center_shift_x
;
box1
.
y_ctr
=
box1_raw
[
1
]
-
center_shift_y
;
box1
.
w
=
box1_raw
[
2
];
box1
.
h
=
box1_raw
[
3
];
box1
.
a
=
box1_raw
[
4
];
box2
.
x_ctr
=
box2_raw
[
0
]
-
center_shift_x
;
box2
.
y_ctr
=
box2_raw
[
1
]
-
center_shift_y
;
box2
.
w
=
box2_raw
[
2
];
box2
.
h
=
box2_raw
[
3
];
box2
.
a
=
box2_raw
[
4
];
const
T
area1
=
box1
.
w
*
box1
.
h
;
const
T
area2
=
box2
.
w
*
box2
.
h
;
if
(
area1
<
1e-14
||
area2
<
1e-14
)
{
return
0.
f
;
}
const
T
intersection
=
rboxes_intersection
<
T
>
(
box1
,
box2
);
const
T
iou
=
intersection
/
(
area1
+
area2
-
intersection
);
return
iou
;
}
// 2D block with 32 * 16 = 512 threads per block
const
int
BLOCK_DIM_X
=
32
;
const
int
BLOCK_DIM_Y
=
16
;
...
...
@@ -362,13 +21,9 @@ const int BLOCK_DIM_Y = 16;
/**
Computes ceil(a / b)
*/
template
<
typename
T
>
__host__
__device__
__forceinline__
T
CeilDiv0
(
T
a
,
T
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
static
inline
int
CeilDiv
(
const
int
a
,
const
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
return
(
a
+
b
-
1
)
/
b
;
}
template
<
typename
T
>
...
...
@@ -461,47 +116,3 @@ std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, cons
}
template
<
typename
T
>
void
rbox_iou_cpu_kernel
(
const
int
rbox1_num
,
const
int
rbox2_num
,
const
T
*
rbox1_data_ptr
,
const
T
*
rbox2_data_ptr
,
T
*
output_data_ptr
)
{
int
i
,
j
;
for
(
i
=
0
;
i
<
rbox1_num
;
i
++
)
{
for
(
j
=
0
;
j
<
rbox2_num
;
j
++
)
{
int
offset
=
i
*
rbox2_num
+
j
;
output_data_ptr
[
offset
]
=
rbox_iou_single
<
T
>
(
rbox1_data_ptr
+
i
*
5
,
rbox2_data_ptr
+
j
*
5
);
}
}
}
#define CHECK_INPUT_CPU(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
std
::
vector
<
paddle
::
Tensor
>
RboxIouCPUForward
(
const
paddle
::
Tensor
&
rbox1
,
const
paddle
::
Tensor
&
rbox2
)
{
CHECK_INPUT_CPU
(
rbox1
);
CHECK_INPUT_CPU
(
rbox2
);
auto
rbox1_num
=
rbox1
.
shape
()[
0
];
auto
rbox2_num
=
rbox2
.
shape
()[
0
];
auto
output
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kCPU
);
output
.
reshape
({
rbox1_num
,
rbox2_num
});
PD_DISPATCH_FLOATING_TYPES
(
rbox1
.
type
(),
"rbox_iou_cpu_kernel"
,
([
&
]
{
rbox_iou_cpu_kernel
<
data_t
>
(
rbox1_num
,
rbox2_num
,
rbox1
.
data
<
data_t
>
(),
rbox2
.
data
<
data_t
>
(),
output
.
mutable_data
<
data_t
>
());
}));
return
{
output
};
}
ppdet/ext_op/rbox_iou_op.h
0 → 100644
浏览文件 @
4630e56b
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cassert>
#include <cmath>
#include <vector>
#ifdef __CUDACC__
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
namespace
{
template
<
typename
T
>
struct
RotatedBox
{
T
x_ctr
,
y_ctr
,
w
,
h
,
a
;
};
template
<
typename
T
>
struct
Point
{
T
x
,
y
;
HOST_DEVICE_INLINE
Point
(
const
T
&
px
=
0
,
const
T
&
py
=
0
)
:
x
(
px
),
y
(
py
)
{}
HOST_DEVICE_INLINE
Point
operator
+
(
const
Point
&
p
)
const
{
return
Point
(
x
+
p
.
x
,
y
+
p
.
y
);
}
HOST_DEVICE_INLINE
Point
&
operator
+=
(
const
Point
&
p
)
{
x
+=
p
.
x
;
y
+=
p
.
y
;
return
*
this
;
}
HOST_DEVICE_INLINE
Point
operator
-
(
const
Point
&
p
)
const
{
return
Point
(
x
-
p
.
x
,
y
-
p
.
y
);
}
HOST_DEVICE_INLINE
Point
operator
*
(
const
T
coeff
)
const
{
return
Point
(
x
*
coeff
,
y
*
coeff
);
}
};
template
<
typename
T
>
HOST_DEVICE_INLINE
T
dot_2d
(
const
Point
<
T
>&
A
,
const
Point
<
T
>&
B
)
{
return
A
.
x
*
B
.
x
+
A
.
y
*
B
.
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
cross_2d
(
const
Point
<
T
>&
A
,
const
Point
<
T
>&
B
)
{
return
A
.
x
*
B
.
y
-
B
.
x
*
A
.
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
void
get_rotated_vertices
(
const
RotatedBox
<
T
>&
box
,
Point
<
T
>
(
&
pts
)[
4
])
{
// M_PI / 180. == 0.01745329251
//double theta = box.a * 0.01745329251;
//MODIFIED
double
theta
=
box
.
a
;
T
cosTheta2
=
(
T
)
cos
(
theta
)
*
0.5
f
;
T
sinTheta2
=
(
T
)
sin
(
theta
)
*
0.5
f
;
// y: top --> down; x: left --> right
pts
[
0
].
x
=
box
.
x_ctr
-
sinTheta2
*
box
.
h
-
cosTheta2
*
box
.
w
;
pts
[
0
].
y
=
box
.
y_ctr
+
cosTheta2
*
box
.
h
-
sinTheta2
*
box
.
w
;
pts
[
1
].
x
=
box
.
x_ctr
+
sinTheta2
*
box
.
h
-
cosTheta2
*
box
.
w
;
pts
[
1
].
y
=
box
.
y_ctr
-
cosTheta2
*
box
.
h
-
sinTheta2
*
box
.
w
;
pts
[
2
].
x
=
2
*
box
.
x_ctr
-
pts
[
0
].
x
;
pts
[
2
].
y
=
2
*
box
.
y_ctr
-
pts
[
0
].
y
;
pts
[
3
].
x
=
2
*
box
.
x_ctr
-
pts
[
1
].
x
;
pts
[
3
].
y
=
2
*
box
.
y_ctr
-
pts
[
1
].
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
int
get_intersection_points
(
const
Point
<
T
>
(
&
pts1
)[
4
],
const
Point
<
T
>
(
&
pts2
)[
4
],
Point
<
T
>
(
&
intersections
)[
24
])
{
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point
<
T
>
vec1
[
4
],
vec2
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
vec1
[
i
]
=
pts1
[(
i
+
1
)
%
4
]
-
pts1
[
i
];
vec2
[
i
]
=
pts2
[(
i
+
1
)
%
4
]
-
pts2
[
i
];
}
// Line test - test all line combos for intersection
int
num
=
0
;
// number of intersections
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
// Solve for 2x2 Ax=b
T
det
=
cross_2d
<
T
>
(
vec2
[
j
],
vec1
[
i
]);
// This takes care of parallel lines
if
(
fabs
(
det
)
<=
1e-14
)
{
continue
;
}
auto
vec12
=
pts2
[
j
]
-
pts1
[
i
];
T
t1
=
cross_2d
<
T
>
(
vec2
[
j
],
vec12
)
/
det
;
T
t2
=
cross_2d
<
T
>
(
vec1
[
i
],
vec12
)
/
det
;
if
(
t1
>=
0.0
f
&&
t1
<=
1.0
f
&&
t2
>=
0.0
f
&&
t2
<=
1.0
f
)
{
intersections
[
num
++
]
=
pts1
[
i
]
+
vec1
[
i
]
*
t1
;
}
}
}
// Check for vertices of rect1 inside rect2
{
const
auto
&
AB
=
vec2
[
0
];
const
auto
&
DA
=
vec2
[
3
];
auto
ABdotAB
=
dot_2d
<
T
>
(
AB
,
AB
);
auto
ADdotAD
=
dot_2d
<
T
>
(
DA
,
DA
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto
AP
=
pts1
[
i
]
-
pts2
[
0
];
auto
APdotAB
=
dot_2d
<
T
>
(
AP
,
AB
);
auto
APdotAD
=
-
dot_2d
<
T
>
(
AP
,
DA
);
if
((
APdotAB
>=
0
)
&&
(
APdotAD
>=
0
)
&&
(
APdotAB
<=
ABdotAB
)
&&
(
APdotAD
<=
ADdotAD
))
{
intersections
[
num
++
]
=
pts1
[
i
];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const
auto
&
AB
=
vec1
[
0
];
const
auto
&
DA
=
vec1
[
3
];
auto
ABdotAB
=
dot_2d
<
T
>
(
AB
,
AB
);
auto
ADdotAD
=
dot_2d
<
T
>
(
DA
,
DA
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
auto
AP
=
pts2
[
i
]
-
pts1
[
0
];
auto
APdotAB
=
dot_2d
<
T
>
(
AP
,
AB
);
auto
APdotAD
=
-
dot_2d
<
T
>
(
AP
,
DA
);
if
((
APdotAB
>=
0
)
&&
(
APdotAD
>=
0
)
&&
(
APdotAB
<=
ABdotAB
)
&&
(
APdotAD
<=
ADdotAD
))
{
intersections
[
num
++
]
=
pts2
[
i
];
}
}
}
return
num
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
int
convex_hull_graham
(
const
Point
<
T
>
(
&
p
)[
24
],
const
int
&
num_in
,
Point
<
T
>
(
&
q
)[
24
],
bool
shift_to_zero
=
false
)
{
assert
(
num_in
>=
2
);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int
t
=
0
;
for
(
int
i
=
1
;
i
<
num_in
;
i
++
)
{
if
(
p
[
i
].
y
<
p
[
t
].
y
||
(
p
[
i
].
y
==
p
[
t
].
y
&&
p
[
i
].
x
<
p
[
t
].
x
))
{
t
=
i
;
}
}
auto
&
start
=
p
[
t
];
// starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for
(
int
i
=
0
;
i
<
num_in
;
i
++
)
{
q
[
i
]
=
p
[
i
]
-
start
;
}
// Swap the starting point to position 0
auto
tmp
=
q
[
0
];
q
[
0
]
=
q
[
t
];
q
[
t
]
=
tmp
;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T
dist
[
24
];
for
(
int
i
=
0
;
i
<
num_in
;
i
++
)
{
dist
[
i
]
=
dot_2d
<
T
>
(
q
[
i
],
q
[
i
]);
}
#ifdef __CUDACC__
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for
(
int
i
=
1
;
i
<
num_in
-
1
;
i
++
)
{
for
(
int
j
=
i
+
1
;
j
<
num_in
;
j
++
)
{
T
crossProduct
=
cross_2d
<
T
>
(
q
[
i
],
q
[
j
]);
if
((
crossProduct
<
-
1e-6
)
||
(
fabs
(
crossProduct
)
<
1e-6
&&
dist
[
i
]
>
dist
[
j
]))
{
auto
q_tmp
=
q
[
i
];
q
[
i
]
=
q
[
j
];
q
[
j
]
=
q_tmp
;
auto
dist_tmp
=
dist
[
i
];
dist
[
i
]
=
dist
[
j
];
dist
[
j
]
=
dist_tmp
;
}
}
}
#else
// CPU version
std
::
sort
(
q
+
1
,
q
+
num_in
,
[](
const
Point
<
T
>&
A
,
const
Point
<
T
>&
B
)
->
bool
{
T
temp
=
cross_2d
<
T
>
(
A
,
B
);
if
(
fabs
(
temp
)
<
1e-6
)
{
return
dot_2d
<
T
>
(
A
,
A
)
<
dot_2d
<
T
>
(
B
,
B
);
}
else
{
return
temp
>
0
;
}
});
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int
k
;
// index of the non-overlapped second point
for
(
k
=
1
;
k
<
num_in
;
k
++
)
{
if
(
dist
[
k
]
>
1e-8
)
{
break
;
}
}
if
(
k
==
num_in
)
{
// We reach the end, which means the convex hull is just one point
q
[
0
]
=
p
[
t
];
return
1
;
}
q
[
1
]
=
q
[
k
];
int
m
=
2
;
// 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for
(
int
i
=
k
+
1
;
i
<
num_in
;
i
++
)
{
while
(
m
>
1
&&
cross_2d
<
T
>
(
q
[
i
]
-
q
[
m
-
2
],
q
[
m
-
1
]
-
q
[
m
-
2
])
>=
0
)
{
m
--
;
}
q
[
m
++
]
=
q
[
i
];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if
(
!
shift_to_zero
)
{
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
q
[
i
]
+=
start
;
}
}
return
m
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
polygon_area
(
const
Point
<
T
>
(
&
q
)[
24
],
const
int
&
m
)
{
if
(
m
<=
2
)
{
return
0
;
}
T
area
=
0
;
for
(
int
i
=
1
;
i
<
m
-
1
;
i
++
)
{
area
+=
fabs
(
cross_2d
<
T
>
(
q
[
i
]
-
q
[
0
],
q
[
i
+
1
]
-
q
[
0
]));
}
return
area
/
2.0
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
rboxes_intersection
(
const
RotatedBox
<
T
>&
box1
,
const
RotatedBox
<
T
>&
box2
)
{
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point
<
T
>
intersectPts
[
24
],
orderedPts
[
24
];
Point
<
T
>
pts1
[
4
];
Point
<
T
>
pts2
[
4
];
get_rotated_vertices
<
T
>
(
box1
,
pts1
);
get_rotated_vertices
<
T
>
(
box2
,
pts2
);
int
num
=
get_intersection_points
<
T
>
(
pts1
,
pts2
,
intersectPts
);
if
(
num
<=
2
)
{
return
0.0
;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int
num_convex
=
convex_hull_graham
<
T
>
(
intersectPts
,
num
,
orderedPts
,
true
);
return
polygon_area
<
T
>
(
orderedPts
,
num_convex
);
}
}
// namespace
template
<
typename
T
>
HOST_DEVICE_INLINE
T
rbox_iou_single
(
T
const
*
const
box1_raw
,
T
const
*
const
box2_raw
)
{
// shift center to the middle point to achieve higher precision in result
RotatedBox
<
T
>
box1
,
box2
;
auto
center_shift_x
=
(
box1_raw
[
0
]
+
box2_raw
[
0
])
/
2.0
;
auto
center_shift_y
=
(
box1_raw
[
1
]
+
box2_raw
[
1
])
/
2.0
;
box1
.
x_ctr
=
box1_raw
[
0
]
-
center_shift_x
;
box1
.
y_ctr
=
box1_raw
[
1
]
-
center_shift_y
;
box1
.
w
=
box1_raw
[
2
];
box1
.
h
=
box1_raw
[
3
];
box1
.
a
=
box1_raw
[
4
];
box2
.
x_ctr
=
box2_raw
[
0
]
-
center_shift_x
;
box2
.
y_ctr
=
box2_raw
[
1
]
-
center_shift_y
;
box2
.
w
=
box2_raw
[
2
];
box2
.
h
=
box2_raw
[
3
];
box2
.
a
=
box2_raw
[
4
];
const
T
area1
=
box1
.
w
*
box1
.
h
;
const
T
area2
=
box2
.
w
*
box2
.
h
;
if
(
area1
<
1e-14
||
area2
<
1e-14
)
{
return
0.
f
;
}
const
T
intersection
=
rboxes_intersection
<
T
>
(
box1
,
box2
);
const
T
iou
=
intersection
/
(
area1
+
area2
-
intersection
);
return
iou
;
}
ppdet/ext_op/setup.py
浏览文件 @
4630e56b
from
paddle.utils.cpp_extension
import
CUDAExtension
,
setup
import
paddle
from
paddle.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
setup
if
__name__
==
"__main__"
:
setup
(
name
=
'rbox_iou_ops'
,
ext_modules
=
CUDAExtension
(
sources
=
[
'rbox_iou_op.cc'
,
'rbox_iou_op.cu'
]))
if
paddle
.
device
.
is_compiled_with_cuda
():
setup
(
name
=
'rbox_iou_ops'
,
ext_modules
=
CUDAExtension
(
sources
=
[
'rbox_iou_op.cc'
,
'rbox_iou_op.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-DPADDLE_WITH_CUDA'
]}))
else
:
setup
(
name
=
'rbox_iou_ops'
,
ext_modules
=
CppExtension
(
sources
=
[
'rbox_iou_op.cc'
]))
ppdet/ext_op/test.py
浏览文件 @
4630e56b
...
...
@@ -3,41 +3,15 @@ import sys
import
time
from
shapely.geometry
import
Polygon
import
paddle
paddle
.
set_device
(
'gpu:0'
)
paddle
.
disable_static
()
import
unittest
try
:
from
rbox_iou_ops
import
rbox_iou
except
Exception
as
e
:
print
(
'import
custom
_ops error'
,
e
)
print
(
'import
rbox_iou
_ops error'
,
e
)
sys
.
exit
(
-
1
)
# generate random data
rbox1
=
np
.
random
.
rand
(
13000
,
5
)
rbox2
=
np
.
random
.
rand
(
7
,
5
)
# x1 y1 w h [0, 0.5]
rbox1
[:,
0
:
4
]
=
rbox1
[:,
0
:
4
]
*
0.45
+
0.001
rbox2
[:,
0
:
4
]
=
rbox2
[:,
0
:
4
]
*
0.45
+
0.001
# generate rbox
rbox1
[:,
4
]
=
rbox1
[:,
4
]
-
0.5
rbox2
[:,
4
]
=
rbox2
[:,
4
]
-
0.5
print
(
'rbox1'
,
rbox1
.
shape
,
'rbox2'
,
rbox2
.
shape
)
# to paddle tensor
pd_rbox1
=
paddle
.
to_tensor
(
rbox1
)
pd_rbox2
=
paddle
.
to_tensor
(
rbox2
)
iou
=
rbox_iou
(
pd_rbox1
,
pd_rbox2
)
start_time
=
time
.
time
()
print
(
'paddle time:'
,
time
.
time
()
-
start_time
)
print
(
'iou is'
,
iou
.
cpu
().
shape
)
# get gt
def
rbox2poly_single
(
rrect
,
get_best_begin_point
=
False
):
"""
rrect:[x_ctr,y_ctr,w,h,angle]
...
...
@@ -54,7 +28,7 @@ def rbox2poly_single(rrect, get_best_begin_point=False):
poly
=
R
.
dot
(
rect
)
x0
,
x1
,
x2
,
x3
=
poly
[
0
,
:
4
]
+
x_ctr
y0
,
y1
,
y2
,
y3
=
poly
[
1
,
:
4
]
+
y_ctr
poly
=
np
.
array
([
x0
,
y0
,
x1
,
y1
,
x2
,
y2
,
x3
,
y3
],
dtype
=
np
.
float
32
)
poly
=
np
.
array
([
x0
,
y0
,
x1
,
y1
,
x2
,
y2
,
x3
,
y3
],
dtype
=
np
.
float
64
)
return
poly
...
...
@@ -87,8 +61,6 @@ def intersection(g, p):
g
=
Polygon
(
g
)
p
=
Polygon
(
p
)
#g = g.buffer(0)
#p = p.buffer(0)
if
not
g
.
is_valid
or
not
p
.
is_valid
:
return
0
...
...
@@ -100,7 +72,6 @@ def intersection(g, p):
return
inter
/
union
# rbox_iou by python
def
rbox_overlaps
(
anchors
,
gt_bboxes
,
use_cv2
=
False
):
"""
...
...
@@ -118,7 +89,7 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
anchors_ploy
=
[
rbox2poly_single
(
e
)
for
e
in
anchors
]
num_gt
,
num_anchors
=
len
(
gt_bboxes_ploy
),
len
(
anchors_ploy
)
iou
=
np
.
zeros
((
num_gt
,
num_anchors
),
dtype
=
np
.
float
32
)
iou
=
np
.
zeros
((
num_gt
,
num_anchors
),
dtype
=
np
.
float
64
)
start_time
=
time
.
time
()
for
i
in
range
(
num_gt
):
...
...
@@ -129,23 +100,57 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
print
(
'cur gt_bboxes_ploy[i]'
,
gt_bboxes_ploy
[
i
],
'anchors_ploy[j]'
,
anchors_ploy
[
j
],
e
)
iou
=
iou
.
T
print
(
'intersection all sp_time'
,
time
.
time
()
-
start_time
)
return
iou
# make coor as int
ploy_rbox1
=
rbox1
ploy_rbox2
=
rbox2
ploy_rbox1
[:,
0
:
4
]
=
rbox1
[:,
0
:
4
]
*
1024
ploy_rbox2
[:,
0
:
4
]
=
rbox2
[:,
0
:
4
]
*
1024
start_time
=
time
.
time
()
iou_py
=
rbox_overlaps
(
ploy_rbox1
,
ploy_rbox2
,
use_cv2
=
False
)
print
(
'rbox time'
,
time
.
time
()
-
start_time
)
print
(
iou_py
.
shape
)
iou_pd
=
iou
.
cpu
().
numpy
()
sum_abs_diff
=
np
.
sum
(
np
.
abs
(
iou_pd
-
iou_py
))
print
(
'sum of abs diff'
,
sum_abs_diff
)
if
sum_abs_diff
<
0.02
:
print
(
"rbox_iou OP compute right!"
)
def
gen_sample
(
n
):
rbox
=
np
.
random
.
rand
(
n
,
5
)
rbox
[:,
0
:
4
]
=
rbox
[:,
0
:
4
]
*
0.45
+
0.001
rbox
[:,
4
]
=
rbox
[:,
4
]
-
0.5
return
rbox
class
RBoxIoUTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
rbox1
=
gen_sample
(
self
.
n
)
self
.
rbox2
=
gen_sample
(
self
.
m
)
def
initTestCase
(
self
):
self
.
n
=
13000
self
.
m
=
7
def
assertAllClose
(
self
,
x
,
y
,
msg
,
atol
=
5e-1
,
rtol
=
1e-2
):
self
.
assertTrue
(
np
.
allclose
(
x
,
y
,
atol
=
atol
,
rtol
=
rtol
),
msg
=
msg
)
def
get_places
(
self
):
places
=
[
paddle
.
CPUPlace
()]
if
paddle
.
device
.
is_compiled_with_cuda
():
places
.
append
(
paddle
.
CUDAPlace
(
0
))
return
places
def
check_output
(
self
,
place
):
paddle
.
disable_static
()
pd_rbox1
=
paddle
.
to_tensor
(
self
.
rbox1
,
place
=
place
)
pd_rbox2
=
paddle
.
to_tensor
(
self
.
rbox2
,
place
=
place
)
actual_t
=
rbox_iou
(
pd_rbox1
,
pd_rbox2
).
numpy
()
poly_rbox1
=
self
.
rbox1
poly_rbox2
=
self
.
rbox2
poly_rbox1
[:,
0
:
4
]
=
self
.
rbox1
[:,
0
:
4
]
*
1024
poly_rbox2
[:,
0
:
4
]
=
self
.
rbox2
[:,
0
:
4
]
*
1024
expect_t
=
rbox_overlaps
(
poly_rbox1
,
poly_rbox2
,
use_cv2
=
False
)
self
.
assertAllClose
(
actual_t
,
expect_t
,
msg
=
"rbox_iou has diff at {}
\n
Expect {}
\n
But got {}"
.
format
(
str
(
place
),
str
(
expect_t
),
str
(
actual_t
)))
def
test_output
(
self
):
places
=
self
.
get_places
()
for
place
in
places
:
self
.
check_output
(
place
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录