Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
04825be6
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
04825be6
编写于
8月 01, 2022
作者:
W
wangxinxin08
提交者:
GitHub
8月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor ext op and add matched rbox iou (#6530)
* refactor ext op and add matched rbox iou * replace rbox_iou_ops with ext_op
上级
3e2330fb
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
425 addition
and
126 deletion
+425
-126
ppdet/ext_op/README.md
ppdet/ext_op/README.md
+6
-9
ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cc
ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cc
+90
-0
ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cu
ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cu
+63
-0
ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cc
ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cc
+0
-0
ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cu
ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cu
+34
-40
ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.h
ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.h
+39
-47
ppdet/ext_op/setup.py
ppdet/ext_op/setup.py
+28
-9
ppdet/ext_op/unittest/test_matched_rbox_iou.py
ppdet/ext_op/unittest/test_matched_rbox_iou.py
+149
-0
ppdet/ext_op/unittest/test_rbox_iou.py
ppdet/ext_op/unittest/test_rbox_iou.py
+8
-13
ppdet/metrics/map_utils.py
ppdet/metrics/map_utils.py
+2
-2
ppdet/modeling/heads/s2anet_head.py
ppdet/modeling/heads/s2anet_head.py
+4
-4
ppdet/modeling/proposal_generator/target_layer.py
ppdet/modeling/proposal_generator/target_layer.py
+2
-2
未找到文件。
ppdet/ext_op/README.md
浏览文件 @
04825be6
# 自定义OP编译
旋转框IOU计算OP是参考
[
自定义外部算子
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/
07_new_op/new_custom_op
.html
)
。
旋转框IOU计算OP是参考
[
自定义外部算子
](
https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/
custom_op/new_cpp_op_cn
.html
)
。
## 1. 环境依赖
-
Paddle >= 2.0.1
...
...
@@ -7,13 +7,13 @@
## 2. 安装
```
python
3.7
setup.py install
python setup.py install
```
按照如下方式使用
编译完成后即可使用,以下为
`rbox_iou`
的使用示例
```
# 引入自定义op
from
rbox_iou_ops
import rbox_iou
from
ext_op
import rbox_iou
paddle.set_device('gpu:0')
paddle.disable_static()
...
...
@@ -29,10 +29,7 @@ print('iou', iou)
```
## 3. 单元测试
单元测试
`test.py`
文件中,通过对比python实现的结果和测试自定义op结果。
由于python计算细节与cpp计算细节略有区别,误差区间设置为0.02。
可以通过执行单元测试来确认自定义算子功能的正确性,执行单元测试的示例如下所示:
```
python
3.7 test
.py
python
unittest/test_matched_rbox_iou
.py
```
提示
`rbox_iou OP compute right!`
说明OP测试通过。
ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cc
0 → 100644
浏览文件 @
04825be6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "paddle/extension.h"
#include "rbox_iou_op.h"
// Host-side "matched" kernel: for every index k in [0, rbox_num) stores
// rbox_iou_single<T>(box k of rbox1, box k of rbox2) into output_data_ptr[k].
// Both inputs are read with a stride of 5 values per box.
template <typename T>
void matched_rbox_iou_cpu_kernel(const int rbox_num,
                                 const T *rbox1_data_ptr,
                                 const T *rbox2_data_ptr,
                                 T *output_data_ptr) {
  const T *lhs_box = rbox1_data_ptr;
  const T *rhs_box = rbox2_data_ptr;
  // Walk both box arrays in lock-step, one 5-value box per iteration.
  for (int pair = 0; pair < rbox_num; ++pair, lhs_box += 5, rhs_box += 5) {
    output_data_ptr[pair] = rbox_iou_single<T>(lhs_box, rhs_box);
  }
}
#define CHECK_INPUT_CPU(x) \
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")

// CPU forward of the matched rotated-box IoU op.
//
// rbox1 / rbox2: CPU tensors of shape [N, 5] with the same floating dtype.
// Returns a single tensor of shape [N] whose k-th entry is computed from the
// k-th box of rbox1 and the k-th box of rbox2.
//
// Fails a PD_CHECK when a tensor is not on the CPU, is not [N, 5], or when the
// two inputs disagree on N or dtype.
std::vector<paddle::Tensor> MatchedRboxIouCPUForward(
    const paddle::Tensor &rbox1, const paddle::Tensor &rbox2) {
  CHECK_INPUT_CPU(rbox1);
  CHECK_INPUT_CPU(rbox2);

  const auto rbox1_shape = rbox1.shape();
  const auto rbox2_shape = rbox2.shape();
  // The kernel reads 5 consecutive values per box, so anything other than
  // [N, 5] would be read out of bounds.
  PD_CHECK(rbox1_shape.size() == 2 && rbox1_shape[1] == 5,
           "rbox1 must have shape [N, 5]");
  PD_CHECK(rbox2_shape.size() == 2 && rbox2_shape[1] == 5,
           "rbox2 must have shape [N, 5]");
  PD_CHECK(rbox1_shape[0] == rbox2_shape[0], "inputs must be same dim");
  // Both buffers are accessed through the dtype dispatched on rbox1.
  PD_CHECK(rbox1.type() == rbox2.type(), "inputs must have the same dtype");

  auto rbox_num = rbox1_shape[0];

  auto output = paddle::Tensor(paddle::PlaceType::kCPU, {rbox_num});

  PD_DISPATCH_FLOATING_TYPES(
      rbox1.type(), "matched_rbox_iou_cpu_kernel", ([&] {
        // The kernel takes the box count as int; make the narrowing explicit.
        matched_rbox_iou_cpu_kernel<data_t>(static_cast<int>(rbox_num),
                                            rbox1.data<data_t>(),
                                            rbox2.data<data_t>(),
                                            output.mutable_data<data_t>());
      }));

  return {output};
}
#ifdef PADDLE_WITH_CUDA
// Defined in matched_rbox_iou_op.cu; only available in CUDA builds.
std::vector<paddle::Tensor> MatchedRboxIouCUDAForward(
    const paddle::Tensor &rbox1, const paddle::Tensor &rbox2);
#endif

#define CHECK_INPUT_SAME(x1, x2) \
  PD_CHECK(x1.place() == x2.place(), "input must be same place.")

// Kernel entry point registered for the matched_rbox_iou op.
//
// Requires both inputs to live on the same place and forwards to the CPU or
// (in CUDA builds) the GPU implementation. Any other place raises an error
// instead of falling off the end of the function without a return value.
std::vector<paddle::Tensor> MatchedRboxIouForward(const paddle::Tensor &rbox1,
                                                  const paddle::Tensor &rbox2) {
  CHECK_INPUT_SAME(rbox1, rbox2);
  if (rbox1.place() == paddle::PlaceType::kCPU) {
    return MatchedRboxIouCPUForward(rbox1, rbox2);
  }
#ifdef PADDLE_WITH_CUDA
  if (rbox1.place() == paddle::PlaceType::kGPU) {
    return MatchedRboxIouCUDAForward(rbox1, rbox2);
  }
#endif
  PD_THROW("matched_rbox_iou: unsupported place for the input tensors.");
}
// Shape inference for matched_rbox_iou: the single output is one-dimensional
// and its length is the first dimension of rbox1_shape. rbox2_shape is not
// consulted.
std::vector<std::vector<int64_t>> MatchedRboxIouInferShape(
    std::vector<int64_t> rbox1_shape, std::vector<int64_t> rbox2_shape) {
  const int64_t matched_pairs = rbox1_shape[0];
  std::vector<std::vector<int64_t>> output_shapes;
  output_shapes.push_back(std::vector<int64_t>{matched_pairs});
  return output_shapes;
}
// Dtype inference for matched_rbox_iou: the single output takes the dtype of
// the first input (t1). t2 is not consulted.
std::vector<paddle::DataType> MatchedRboxIouInferDtype(paddle::DataType t1,
                                                       paddle::DataType t2) {
  std::vector<paddle::DataType> output_dtypes(1, t1);
  return output_dtypes;
}
// Registers the custom operator `matched_rbox_iou`.
// Inputs are named "RBOX1" and "RBOX2", the single output is named "Output".
// MatchedRboxIouForward is the kernel; MatchedRboxIouInferShape and
// MatchedRboxIouInferDtype supply the output shape and dtype.
PD_BUILD_OP(matched_rbox_iou)
    .Inputs({"RBOX1", "RBOX2"})
    .Outputs({"Output"})
    .SetKernelFn(PD_KERNEL(MatchedRboxIouForward))
    .SetInferShapeFn(PD_INFER_SHAPE(MatchedRboxIouInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MatchedRboxIouInferDtype));
ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cu
0 → 100644
浏览文件 @
04825be6
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "paddle/extension.h"
#include "rbox_iou_op.h"
/**
 * Integer ceiling division: evaluates (a + b - 1) / b with truncating integer
 * division, which equals ceil(a / b) for a >= 0 and b > 0.
 */
static inline int CeilDiv(const int a, const int b) {
  const int rounded_up = a + b - 1;
  return rounded_up / b;
}
// Device-side "matched" kernel: output_data_ptr[k] receives
// rbox_iou_single<T>(box k of rbox1, box k of rbox2) for k in [0, rbox_num).
// Each thread starts at its global 1-D index and advances by the total number
// of launched threads, so any grid size covers all rbox_num pairs.
template <typename T>
__global__ void matched_rbox_iou_cuda_kernel(const int rbox_num,
                                             const T *rbox1_data_ptr,
                                             const T *rbox2_data_ptr,
                                             T *output_data_ptr) {
  const int first_pair = blockIdx.x * blockDim.x + threadIdx.x;
  const int total_threads = blockDim.x * gridDim.x;
  for (int pair = first_pair; pair < rbox_num; pair += total_threads) {
    // Every box occupies 5 consecutive values.
    const T *lhs_box = rbox1_data_ptr + pair * 5;
    const T *rhs_box = rbox2_data_ptr + pair * 5;
    output_data_ptr[pair] = rbox_iou_single<T>(lhs_box, rhs_box);
  }
}
#define CHECK_INPUT_GPU(x) \
  PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")

// GPU forward of the matched rotated-box IoU op.
//
// rbox1 / rbox2: GPU tensors of shape [N, 5] with the same floating dtype.
// Returns a single GPU tensor of shape [N]; entry k is computed from the k-th
// box of rbox1 and the k-th box of rbox2. The kernel is launched on
// rbox1.stream().
//
// Fails a PD_CHECK when a tensor is not on the GPU, is not [N, 5], or when the
// two inputs disagree on N or dtype.
std::vector<paddle::Tensor> MatchedRboxIouCUDAForward(
    const paddle::Tensor &rbox1, const paddle::Tensor &rbox2) {
  CHECK_INPUT_GPU(rbox1);
  CHECK_INPUT_GPU(rbox2);

  const auto rbox1_shape = rbox1.shape();
  const auto rbox2_shape = rbox2.shape();
  // The kernel reads 5 consecutive values per box, so anything other than
  // [N, 5] would be read out of bounds.
  PD_CHECK(rbox1_shape.size() == 2 && rbox1_shape[1] == 5,
           "rbox1 must have shape [N, 5]");
  PD_CHECK(rbox2_shape.size() == 2 && rbox2_shape[1] == 5,
           "rbox2 must have shape [N, 5]");
  PD_CHECK(rbox1_shape[0] == rbox2_shape[0], "inputs must be same dim");
  // Both buffers are accessed through the dtype dispatched on rbox1.
  PD_CHECK(rbox1.type() == rbox2.type(), "inputs must have the same dtype");

  auto rbox_num = rbox1_shape[0];

  auto output = paddle::Tensor(paddle::PlaceType::kGPU, {rbox_num});

  // The kernel and CeilDiv work on int; make the narrowing explicit.
  const int pair_count = static_cast<int>(rbox_num);
  const int thread_per_block = 512;
  const int needed_blocks = CeilDiv(pair_count, thread_per_block);
  // A grid dimension of 0 is an invalid launch configuration. With no boxes we
  // still launch one block; its threads find nothing to do and exit.
  const int block_per_grid = needed_blocks > 0 ? needed_blocks : 1;

  PD_DISPATCH_FLOATING_TYPES(
      rbox1.type(), "matched_rbox_iou_cuda_kernel", ([&] {
        matched_rbox_iou_cuda_kernel<
            data_t><<<block_per_grid, thread_per_block, 0, rbox1.stream()>>>(
            pair_count,
            rbox1.data<data_t>(),
            rbox2.data<data_t>(),
            output.mutable_data<data_t>());
      }));

  return {output};
}
ppdet/ext_op/rbox_iou_op.cc
→
ppdet/ext_op/
csrc/rbox_iou/
rbox_iou_op.cc
浏览文件 @
04825be6
文件已移动
ppdet/ext_op/rbox_iou_op.cu
→
ppdet/ext_op/
csrc/rbox_iou/
rbox_iou_op.cu
浏览文件 @
04825be6
...
...
@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
// The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "rbox_iou_op.h"
#include "paddle/extension.h"
#include "rbox_iou_op.h"
// 2D block with 32 * 16 = 512 threads per block
const
int
BLOCK_DIM_X
=
32
;
...
...
@@ -25,17 +26,13 @@ const int BLOCK_DIM_Y = 16;
Computes ceil(a / b)
*/
static
inline
int
CeilDiv
(
const
int
a
,
const
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
static
inline
int
CeilDiv
(
const
int
a
,
const
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
template
<
typename
T
>
__global__
void
rbox_iou_cuda_kernel
(
const
int
rbox1_num
,
const
int
rbox2_num
,
const
T
*
rbox1_data_ptr
,
const
T
*
rbox2_data_ptr
,
T
*
output_data_ptr
)
{
__global__
void
rbox_iou_cuda_kernel
(
const
int
rbox1_num
,
const
int
rbox2_num
,
const
T
*
rbox1_data_ptr
,
const
T
*
rbox2_data_ptr
,
T
*
output_data_ptr
)
{
// get row_start and col_start
const
int
rbox1_block_idx
=
blockIdx
.
x
*
blockDim
.
x
;
...
...
@@ -47,7 +44,6 @@ __global__ void rbox_iou_cuda_kernel(
__shared__
T
block_boxes1
[
BLOCK_DIM_X
*
5
];
__shared__
T
block_boxes2
[
BLOCK_DIM_Y
*
5
];
// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
if
(
threadIdx
.
x
<
rbox1_thread_num
&&
threadIdx
.
y
==
0
)
{
block_boxes1
[
threadIdx
.
x
*
5
+
0
]
=
...
...
@@ -62,7 +58,8 @@ __global__ void rbox_iou_cuda_kernel(
rbox1_data_ptr
[(
rbox1_block_idx
+
threadIdx
.
x
)
*
5
+
4
];
}
// threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as above: threadIdx.y == 0
// threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as
// above: threadIdx.y == 0
if
(
threadIdx
.
x
<
rbox2_thread_num
&&
threadIdx
.
y
==
0
)
{
block_boxes2
[
threadIdx
.
x
*
5
+
0
]
=
rbox2_data_ptr
[(
rbox2_block_idx
+
threadIdx
.
x
)
*
5
+
0
];
...
...
@@ -80,41 +77,38 @@ __global__ void rbox_iou_cuda_kernel(
__syncthreads
();
if
(
threadIdx
.
x
<
rbox1_thread_num
&&
threadIdx
.
y
<
rbox2_thread_num
)
{
int
offset
=
(
rbox1_block_idx
+
threadIdx
.
x
)
*
rbox2_num
+
rbox2_block_idx
+
threadIdx
.
y
;
output_data_ptr
[
offset
]
=
rbox_iou_single
<
T
>
(
block_boxes1
+
threadIdx
.
x
*
5
,
block_boxes2
+
threadIdx
.
y
*
5
);
int
offset
=
(
rbox1_block_idx
+
threadIdx
.
x
)
*
rbox2_num
+
rbox2_block_idx
+
threadIdx
.
y
;
output_data_ptr
[
offset
]
=
rbox_iou_single
<
T
>
(
block_boxes1
+
threadIdx
.
x
*
5
,
block_boxes2
+
threadIdx
.
y
*
5
);
}
}
#define CHECK_INPUT_GPU(x) PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
#define CHECK_INPUT_GPU(x) \
PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
std
::
vector
<
paddle
::
Tensor
>
RboxIouCUDAForward
(
const
paddle
::
Tensor
&
rbox1
,
const
paddle
::
Tensor
&
rbox2
)
{
CHECK_INPUT_GPU
(
rbox1
);
CHECK_INPUT_GPU
(
rbox2
);
std
::
vector
<
paddle
::
Tensor
>
RboxIouCUDAForward
(
const
paddle
::
Tensor
&
rbox1
,
const
paddle
::
Tensor
&
rbox2
)
{
CHECK_INPUT_GPU
(
rbox1
);
CHECK_INPUT_GPU
(
rbox2
);
auto
rbox1_num
=
rbox1
.
shape
()[
0
];
auto
rbox2_num
=
rbox2
.
shape
()[
0
];
auto
rbox1_num
=
rbox1
.
shape
()[
0
];
auto
rbox2_num
=
rbox2
.
shape
()[
0
];
auto
output
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kGPU
,
{
rbox1_num
,
rbox2_num
});
auto
output
=
paddle
::
Tensor
(
paddle
::
PlaceType
::
kGPU
,
{
rbox1_num
,
rbox2_num
});
const
int
blocks_x
=
CeilDiv
(
rbox1_num
,
BLOCK_DIM_X
);
const
int
blocks_y
=
CeilDiv
(
rbox2_num
,
BLOCK_DIM_Y
);
const
int
blocks_x
=
CeilDiv
(
rbox1_num
,
BLOCK_DIM_X
);
const
int
blocks_y
=
CeilDiv
(
rbox2_num
,
BLOCK_DIM_Y
);
dim3
blocks
(
blocks_x
,
blocks_y
);
dim3
threads
(
BLOCK_DIM_X
,
BLOCK_DIM_Y
);
dim3
blocks
(
blocks_x
,
blocks_y
);
dim3
threads
(
BLOCK_DIM_X
,
BLOCK_DIM_Y
);
PD_DISPATCH_FLOATING_TYPES
(
rbox1
.
type
(),
"rbox_iou_cuda_kernel"
,
([
&
]
{
rbox_iou_cuda_kernel
<
data_t
><<<
blocks
,
threads
,
0
,
rbox1
.
stream
()
>>>
(
rbox1_num
,
rbox2_num
,
rbox1
.
data
<
data_t
>
(),
rbox2
.
data
<
data_t
>
(),
output
.
mutable_data
<
data_t
>
());
}));
PD_DISPATCH_FLOATING_TYPES
(
rbox1
.
type
(),
"rbox_iou_cuda_kernel"
,
([
&
]
{
rbox_iou_cuda_kernel
<
data_t
><<<
blocks
,
threads
,
0
,
rbox1
.
stream
()
>>>
(
rbox1_num
,
rbox2_num
,
rbox1
.
data
<
data_t
>
(),
rbox2
.
data
<
data_t
>
(),
output
.
mutable_data
<
data_t
>
());
}));
return
{
output
};
return
{
output
};
}
ppdet/ext_op/rbox_iou_op.h
→
ppdet/ext_op/
csrc/rbox_iou/
rbox_iou_op.h
浏览文件 @
04825be6
...
...
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
// The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#pragma once
...
...
@@ -32,24 +33,20 @@
namespace
{
template
<
typename
T
>
struct
RotatedBox
{
T
x_ctr
,
y_ctr
,
w
,
h
,
a
;
};
template
<
typename
T
>
struct
RotatedBox
{
T
x_ctr
,
y_ctr
,
w
,
h
,
a
;
};
template
<
typename
T
>
struct
Point
{
template
<
typename
T
>
struct
Point
{
T
x
,
y
;
HOST_DEVICE_INLINE
Point
(
const
T
&
px
=
0
,
const
T
&
py
=
0
)
:
x
(
px
),
y
(
py
)
{}
HOST_DEVICE_INLINE
Point
operator
+
(
const
Point
&
p
)
const
{
HOST_DEVICE_INLINE
Point
(
const
T
&
px
=
0
,
const
T
&
py
=
0
)
:
x
(
px
),
y
(
py
)
{}
HOST_DEVICE_INLINE
Point
operator
+
(
const
Point
&
p
)
const
{
return
Point
(
x
+
p
.
x
,
y
+
p
.
y
);
}
HOST_DEVICE_INLINE
Point
&
operator
+=
(
const
Point
&
p
)
{
HOST_DEVICE_INLINE
Point
&
operator
+=
(
const
Point
&
p
)
{
x
+=
p
.
x
;
y
+=
p
.
y
;
return
*
this
;
}
HOST_DEVICE_INLINE
Point
operator
-
(
const
Point
&
p
)
const
{
HOST_DEVICE_INLINE
Point
operator
-
(
const
Point
&
p
)
const
{
return
Point
(
x
-
p
.
x
,
y
-
p
.
y
);
}
HOST_DEVICE_INLINE
Point
operator
*
(
const
T
coeff
)
const
{
...
...
@@ -58,22 +55,21 @@ struct Point {
};
template
<
typename
T
>
HOST_DEVICE_INLINE
T
dot_2d
(
const
Point
<
T
>
&
A
,
const
Point
<
T
>&
B
)
{
HOST_DEVICE_INLINE
T
dot_2d
(
const
Point
<
T
>
&
A
,
const
Point
<
T
>
&
B
)
{
return
A
.
x
*
B
.
x
+
A
.
y
*
B
.
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
cross_2d
(
const
Point
<
T
>
&
A
,
const
Point
<
T
>&
B
)
{
HOST_DEVICE_INLINE
T
cross_2d
(
const
Point
<
T
>
&
A
,
const
Point
<
T
>
&
B
)
{
return
A
.
x
*
B
.
y
-
B
.
x
*
A
.
y
;
}
template
<
typename
T
>
HOST_DEVICE_INLINE
void
get_rotated_vertices
(
const
RotatedBox
<
T
>&
box
,
Point
<
T
>
(
&
pts
)[
4
])
{
HOST_DEVICE_INLINE
void
get_rotated_vertices
(
const
RotatedBox
<
T
>
&
box
,
Point
<
T
>
(
&
pts
)[
4
])
{
// M_PI / 180. == 0.01745329251
//double theta = box.a * 0.01745329251;
//MODIFIED
//
double theta = box.a * 0.01745329251;
//
MODIFIED
double
theta
=
box
.
a
;
T
cosTheta2
=
(
T
)
cos
(
theta
)
*
0.5
f
;
T
sinTheta2
=
(
T
)
sin
(
theta
)
*
0.5
f
;
...
...
@@ -90,10 +86,9 @@ HOST_DEVICE_INLINE void get_rotated_vertices(
}
template
<
typename
T
>
HOST_DEVICE_INLINE
int
get_intersection_points
(
const
Point
<
T
>
(
&
pts1
)[
4
],
const
Point
<
T
>
(
&
pts2
)[
4
],
Point
<
T
>
(
&
intersections
)[
24
])
{
HOST_DEVICE_INLINE
int
get_intersection_points
(
const
Point
<
T
>
(
&
pts1
)[
4
],
const
Point
<
T
>
(
&
pts2
)[
4
],
Point
<
T
>
(
&
intersections
)[
24
])
{
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point
<
T
>
vec1
[
4
],
vec2
[
4
];
...
...
@@ -127,8 +122,8 @@ HOST_DEVICE_INLINE int get_intersection_points(
// Check for vertices of rect1 inside rect2
{
const
auto
&
AB
=
vec2
[
0
];
const
auto
&
DA
=
vec2
[
3
];
const
auto
&
AB
=
vec2
[
0
];
const
auto
&
DA
=
vec2
[
3
];
auto
ABdotAB
=
dot_2d
<
T
>
(
AB
,
AB
);
auto
ADdotAD
=
dot_2d
<
T
>
(
DA
,
DA
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
...
@@ -150,8 +145,8 @@ HOST_DEVICE_INLINE int get_intersection_points(
// Reverse the check - check for vertices of rect2 inside rect1
{
const
auto
&
AB
=
vec1
[
0
];
const
auto
&
DA
=
vec1
[
3
];
const
auto
&
AB
=
vec1
[
0
];
const
auto
&
DA
=
vec1
[
3
];
auto
ABdotAB
=
dot_2d
<
T
>
(
AB
,
AB
);
auto
ADdotAD
=
dot_2d
<
T
>
(
DA
,
DA
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
...
@@ -171,11 +166,9 @@ HOST_DEVICE_INLINE int get_intersection_points(
}
template
<
typename
T
>
HOST_DEVICE_INLINE
int
convex_hull_graham
(
const
Point
<
T
>
(
&
p
)[
24
],
const
int
&
num_in
,
Point
<
T
>
(
&
q
)[
24
],
bool
shift_to_zero
=
false
)
{
HOST_DEVICE_INLINE
int
convex_hull_graham
(
const
Point
<
T
>
(
&
p
)[
24
],
const
int
&
num_in
,
Point
<
T
>
(
&
q
)[
24
],
bool
shift_to_zero
=
false
)
{
assert
(
num_in
>=
2
);
// Step 1:
...
...
@@ -188,7 +181,7 @@ HOST_DEVICE_INLINE int convex_hull_graham(
t
=
i
;
}
}
auto
&
start
=
p
[
t
];
// starting point
auto
&
start
=
p
[
t
];
// starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
...
...
@@ -230,15 +223,15 @@ HOST_DEVICE_INLINE int convex_hull_graham(
}
#else
// CPU version
std
::
sort
(
q
+
1
,
q
+
num_in
,
[](
const
Point
<
T
>&
A
,
const
Point
<
T
>&
B
)
->
bool
{
T
temp
=
cross_2d
<
T
>
(
A
,
B
);
if
(
fabs
(
temp
)
<
1e-6
)
{
return
dot_2d
<
T
>
(
A
,
A
)
<
dot_2d
<
T
>
(
B
,
B
);
}
else
{
return
temp
>
0
;
}
});
std
::
sort
(
q
+
1
,
q
+
num_in
,
[](
const
Point
<
T
>
&
A
,
const
Point
<
T
>
&
B
)
->
bool
{
T
temp
=
cross_2d
<
T
>
(
A
,
B
);
if
(
fabs
(
temp
)
<
1e-6
)
{
return
dot_2d
<
T
>
(
A
,
A
)
<
dot_2d
<
T
>
(
B
,
B
);
}
else
{
return
temp
>
0
;
}
});
#endif
// Step 4:
...
...
@@ -286,7 +279,7 @@ HOST_DEVICE_INLINE int convex_hull_graham(
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
polygon_area
(
const
Point
<
T
>
(
&
q
)[
24
],
const
int
&
m
)
{
HOST_DEVICE_INLINE
T
polygon_area
(
const
Point
<
T
>
(
&
q
)[
24
],
const
int
&
m
)
{
if
(
m
<=
2
)
{
return
0
;
}
...
...
@@ -300,9 +293,8 @@ HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
}
template
<
typename
T
>
HOST_DEVICE_INLINE
T
rboxes_intersection
(
const
RotatedBox
<
T
>&
box1
,
const
RotatedBox
<
T
>&
box2
)
{
HOST_DEVICE_INLINE
T
rboxes_intersection
(
const
RotatedBox
<
T
>
&
box1
,
const
RotatedBox
<
T
>
&
box2
)
{
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point
<
T
>
intersectPts
[
24
],
orderedPts
[
24
];
...
...
@@ -327,8 +319,8 @@ HOST_DEVICE_INLINE T rboxes_intersection(
}
// namespace
template
<
typename
T
>
HOST_DEVICE_INLINE
T
rbox_iou_single
(
T
const
*
const
box1_raw
,
T
const
*
const
box2_raw
)
{
HOST_DEVICE_INLINE
T
rbox_iou_single
(
T
const
*
const
box1_raw
,
T
const
*
const
box2_raw
)
{
// shift center to the middle point to achieve higher precision in result
RotatedBox
<
T
>
box1
,
box2
;
auto
center_shift_x
=
(
box1_raw
[
0
]
+
box2_raw
[
0
])
/
2.0
;
...
...
ppdet/ext_op/setup.py
浏览文件 @
04825be6
import
os
import
glob
import
paddle
from
paddle.utils.cpp_extension
import
CppExtension
,
CUDAExtension
,
setup
if
__name__
==
"__main__"
:
def
get_extensions
():
root_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
ext_root_dir
=
os
.
path
.
join
(
root_dir
,
'csrc'
)
sources
=
[]
for
ext_name
in
os
.
listdir
(
ext_root_dir
):
ext_dir
=
os
.
path
.
join
(
ext_root_dir
,
ext_name
)
source
=
glob
.
glob
(
os
.
path
.
join
(
ext_dir
,
'*.cc'
))
kwargs
=
dict
()
if
paddle
.
device
.
is_compiled_with_cuda
():
source
+=
glob
.
glob
(
os
.
path
.
join
(
ext_dir
,
'*.cu'
))
if
not
source
:
continue
sources
+=
source
if
paddle
.
device
.
is_compiled_with_cuda
():
setup
(
name
=
'rbox_iou_ops'
,
ext_modules
=
CUDAExtension
(
sources
=
[
'rbox_iou_op.cc'
,
'rbox_iou_op.cu'
],
extra_compile_args
=
{
'cxx'
:
[
'-DPADDLE_WITH_CUDA'
]}))
extension
=
CUDAExtension
(
sources
,
extra_compile_args
=
{
'cxx'
:
[
'-DPADDLE_WITH_CUDA'
]})
else
:
setup
(
name
=
'rbox_iou_ops'
,
ext_modules
=
CppExtension
(
sources
=
[
'rbox_iou_op.cc'
]))
extension
=
CppExtension
(
sources
)
return
extension
if
__name__
==
"__main__"
:
setup
(
name
=
'ext_op'
,
ext_modules
=
get_extensions
())
ppdet/ext_op/unittest/test_matched_rbox_iou.py
0 → 100644
浏览文件 @
04825be6
import
numpy
as
np
import
sys
import
time
from
shapely.geometry
import
Polygon
import
paddle
import
unittest
from
ext_op
import
matched_rbox_iou
def rbox2poly_single(rrect, get_best_begin_point=False):
    """
    Convert one rotated box into the flat list of its four corners.

    Args:
        rrect: sequence whose first five values are
            [x_ctr, y_ctr, w, h, angle]; the angle is passed straight to
            np.cos / np.sin.
        get_best_begin_point: accepted for signature compatibility, unused.
    Returns:
        np.ndarray of dtype float64 laid out as [x0, y0, x1, y1, x2, y2, x3, y3].
    """
    x_ctr, y_ctr, width, height, angle = rrect[:5]
    half_w, half_h = width / 2, height / 2
    # axis-aligned corners around the origin, shape 2x4 (row 0: x, row 1: y)
    corners = np.array([[-half_w, half_w, half_w, -half_w],
                        [-half_h, -half_h, half_h, half_h]])
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    # rotate around the origin, then translate to the box centre
    rotated = rotation.dot(corners)
    xs = rotated[0, :4] + x_ctr
    ys = rotated[1, :4] + y_ctr
    # interleave x and y coordinates
    poly = np.empty(8, dtype=np.float64)
    poly[0::2] = xs
    poly[1::2] = ys
    return poly
def intersection(g, p):
    """
    Intersection-over-union of two quadrilaterals.

    Args:
        g, p: arrays whose first eight values are [x0, y0, ..., x3, y3].
    Returns:
        inter_area / union_area, or 0 when a quick bounding-box test rejects
        the pair, a polygon is invalid, or the union area is zero.
    """
    quad_g = g[:8].reshape((4, 2))
    quad_p = p[:8].reshape((4, 2))

    g_xs, g_ys = quad_g[:, 0], quad_g[:, 1]
    p_xs, p_ys = quad_p[:, 0], quad_p[:, 1]

    # quick reject 1: the axis-aligned bounding boxes do not overlap
    overlap_x1 = np.maximum(np.min(g_xs), np.min(p_xs))
    overlap_x2 = np.minimum(np.max(g_xs), np.max(p_xs))
    overlap_y1 = np.maximum(np.min(g_ys), np.min(p_ys))
    overlap_y2 = np.minimum(np.max(g_ys), np.max(p_ys))
    if overlap_x1 >= overlap_x2 or overlap_y1 >= overlap_y2:
        return 0.

    # quick reject 2: the joint bounding box is degenerate or narrower than 2
    joint_x1 = np.minimum(np.min(g_xs), np.min(p_xs))
    joint_x2 = np.maximum(np.max(g_xs), np.max(p_xs))
    joint_y1 = np.minimum(np.min(g_ys), np.min(p_ys))
    joint_y2 = np.maximum(np.max(g_ys), np.max(p_ys))
    too_small = (joint_x2 - joint_x1) < 2 or (joint_y2 - joint_y1) < 2
    if joint_x1 >= joint_x2 or joint_y1 >= joint_y2 or too_small:
        return 0.

    poly_g = Polygon(quad_g)
    poly_p = Polygon(quad_p)
    if not (poly_g.is_valid and poly_p.is_valid):
        return 0

    inter = poly_g.intersection(poly_p).area
    union = poly_g.area + poly_p.area - inter
    if union == 0:
        return 0
    return inter / union
def matched_rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
    """
    Reference (pure Python) element-wise IoU between paired rotated boxes.

    Args:
        anchors: [M, 5] x_ctr,y_ctr,w,h,angle
        gt_bboxes: [M, 5] x_ctr,y_ctr,w,h,angle
        use_cv2: accepted for signature compatibility, unused.
    Returns:
        matched_iou: [M] float64 array; entry i is the IoU of gt_bboxes[i]
            with anchors[i]. An entry stays 0 when computing it raised.
    """
    assert anchors.shape[1] == 5
    assert gt_bboxes.shape[1] == 5

    gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes]
    anchors_ploy = [rbox2poly_single(e) for e in anchors]

    num = len(anchors_ploy)
    iou = np.zeros((num, ), dtype=np.float64)

    for i in range(num):
        try:
            iou[i] = intersection(gt_bboxes_ploy[i], anchors_ploy[i])
        except Exception as e:
            # best effort: report the offending pair and leave iou[i] at 0
            print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i],
                  'anchors_ploy[i]', anchors_ploy[i], e)
    return iou
def gen_sample(n):
    """
    Draw n random rotated boxes as an [n, 5] float64 array.

    Columns 0-3 are uniform samples mapped to [0.001, 0.451); column 4 (the
    angle) is a uniform sample shifted to [-0.5, 0.5).
    """
    sample = np.random.rand(n, 5)
    # scale then offset the four geometry columns
    sample[:, 0:4] *= 0.45
    sample[:, 0:4] += 0.001
    # centre the angle column around zero
    sample[:, 4] -= 0.5
    return sample
class
MatchedRBoxIoUTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
rbox1
=
gen_sample
(
self
.
n
)
self
.
rbox2
=
gen_sample
(
self
.
n
)
def
initTestCase
(
self
):
self
.
n
=
1000
def
assertAllClose
(
self
,
x
,
y
,
msg
,
atol
=
5e-1
,
rtol
=
1e-2
):
self
.
assertTrue
(
np
.
allclose
(
x
,
y
,
atol
=
atol
,
rtol
=
rtol
),
msg
=
msg
)
def
get_places
(
self
):
places
=
[
paddle
.
CPUPlace
()]
if
paddle
.
device
.
is_compiled_with_cuda
():
places
.
append
(
paddle
.
CUDAPlace
(
0
))
return
places
def
check_output
(
self
,
place
):
paddle
.
disable_static
()
pd_rbox1
=
paddle
.
to_tensor
(
self
.
rbox1
,
place
=
place
)
pd_rbox2
=
paddle
.
to_tensor
(
self
.
rbox2
,
place
=
place
)
actual_t
=
matched_rbox_iou
(
pd_rbox1
,
pd_rbox2
).
numpy
()
poly_rbox1
=
self
.
rbox1
poly_rbox2
=
self
.
rbox2
poly_rbox1
[:,
0
:
4
]
=
self
.
rbox1
[:,
0
:
4
]
*
1024
poly_rbox2
[:,
0
:
4
]
=
self
.
rbox2
[:,
0
:
4
]
*
1024
expect_t
=
matched_rbox_overlaps
(
poly_rbox1
,
poly_rbox2
,
use_cv2
=
False
)
self
.
assertAllClose
(
actual_t
,
expect_t
,
msg
=
"rbox_iou has diff at {}
\n
Expect {}
\n
But got {}"
.
format
(
str
(
place
),
str
(
expect_t
),
str
(
actual_t
)))
def
test_output
(
self
):
places
=
self
.
get_places
()
for
place
in
places
:
self
.
check_output
(
place
)
# Run the unit tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
ppdet/ext_op/
test
.py
→
ppdet/ext_op/
unittest/test_rbox_iou
.py
浏览文件 @
04825be6
...
...
@@ -5,11 +5,7 @@ from shapely.geometry import Polygon
import
paddle
import
unittest
try
:
from
rbox_iou_ops
import
rbox_iou
except
Exception
as
e
:
print
(
'import rbox_iou_ops error'
,
e
)
sys
.
exit
(
-
1
)
from
ext_op
import
rbox_iou
def
rbox2poly_single
(
rrect
,
get_best_begin_point
=
False
):
...
...
@@ -80,7 +76,7 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
gt_bboxes: [M, 5] x1,y1,x2,y2,angle
Returns:
iou: [NA, M]
"""
assert
anchors
.
shape
[
1
]
==
5
assert
gt_bboxes
.
shape
[
1
]
==
5
...
...
@@ -89,17 +85,16 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
anchors_ploy
=
[
rbox2poly_single
(
e
)
for
e
in
anchors
]
num_gt
,
num_anchors
=
len
(
gt_bboxes_ploy
),
len
(
anchors_ploy
)
iou
=
np
.
zeros
((
num_
gt
,
num_anchors
),
dtype
=
np
.
float64
)
iou
=
np
.
zeros
((
num_
anchors
,
num_gt
),
dtype
=
np
.
float64
)
start_time
=
time
.
time
()
for
i
in
range
(
num_
gt
):
for
j
in
range
(
num_
anchors
):
for
i
in
range
(
num_
anchors
):
for
j
in
range
(
num_
gt
):
try
:
iou
[
i
,
j
]
=
intersection
(
gt_bboxes_ploy
[
i
],
anchor
s_ploy
[
j
])
iou
[
i
,
j
]
=
intersection
(
anchors_ploy
[
i
],
gt_bboxe
s_ploy
[
j
])
except
Exception
as
e
:
print
(
'cur gt_bboxes_ploy[i]'
,
gt_bboxes_ploy
[
i
],
'anchors_ploy[j]'
,
anchors_ploy
[
j
],
e
)
iou
=
iou
.
T
print
(
'cur anchors_ploy[i]'
,
anchors_ploy
[
i
],
'gt_bboxes_ploy[j]'
,
gt_bboxes_ploy
[
j
],
e
)
return
iou
...
...
ppdet/metrics/map_utils.py
浏览文件 @
04825be6
...
...
@@ -121,9 +121,9 @@ def calc_rbox_iou(pred, gt_rbox):
pred_rbox
=
pred_rbox
.
reshape
(
-
1
,
5
)
pred_rbox
=
pred_rbox
.
reshape
(
-
1
,
5
)
try
:
from
rbox_iou_ops
import
rbox_iou
from
ext_op
import
rbox_iou
except
Exception
as
e
:
print
(
"import custom_ops error, try install
rbox_iou_ops
"
\
print
(
"import custom_ops error, try install
ext_op
"
\
"following ppdet/ext_op/README.md"
,
e
)
sys
.
stdout
.
flush
()
sys
.
exit
(
-
1
)
...
...
ppdet/modeling/heads/s2anet_head.py
浏览文件 @
04825be6
...
...
@@ -601,9 +601,9 @@ class S2ANetHead(nn.Layer):
fam_bbox
=
paddle
.
sum
(
fam_bbox
,
axis
=-
1
)
feat_bbox_weights
=
paddle
.
sum
(
feat_bbox_weights
,
axis
=-
1
)
try
:
from
rbox_iou_ops
import
rbox_iou
from
ext_op
import
rbox_iou
except
Exception
as
e
:
print
(
"import custom_ops error, try install
rbox_iou_ops
"
\
print
(
"import custom_ops error, try install
ext_op
"
\
"following ppdet/ext_op/README.md"
,
e
)
sys
.
stdout
.
flush
()
sys
.
exit
(
-
1
)
...
...
@@ -716,9 +716,9 @@ class S2ANetHead(nn.Layer):
odm_bbox
=
paddle
.
sum
(
odm_bbox
,
axis
=-
1
)
feat_bbox_weights
=
paddle
.
sum
(
feat_bbox_weights
,
axis
=-
1
)
try
:
from
rbox_iou_ops
import
rbox_iou
from
ext_op
import
rbox_iou
except
Exception
as
e
:
print
(
"import custom_ops error, try install
rbox_iou_ops
"
\
print
(
"import custom_ops error, try install
ext_op
"
\
"following ppdet/ext_op/README.md"
,
e
)
sys
.
stdout
.
flush
()
sys
.
exit
(
-
1
)
...
...
ppdet/modeling/proposal_generator/target_layer.py
浏览文件 @
04825be6
...
...
@@ -392,9 +392,9 @@ class RBoxAssigner(object):
gt_bboxes_xc_yc
=
paddle
.
to_tensor
(
gt_bboxes_xc_yc
)
try
:
from
rbox_iou_ops
import
rbox_iou
from
ext_op
import
rbox_iou
except
Exception
as
e
:
print
(
"import custom_ops error, try install
rbox_iou_ops
"
\
print
(
"import custom_ops error, try install
ext_op
"
\
"following ppdet/ext_op/README.md"
,
e
)
sys
.
stdout
.
flush
()
sys
.
exit
(
-
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录