未验证 提交 04825be6 编写于 作者: W wangxinxin08 提交者: GitHub

refactor ext op and add matched rbox iou (#6530)

* refactor ext op and add matched rbox iou

* replace rbox_iou_ops with ext_op
上级 3e2330fb
# 自定义OP编译 # 自定义OP编译
旋转框IOU计算OP是参考[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/new_custom_op.html) 旋转框IOU计算OP是参考[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html)
## 1. 环境依赖 ## 1. 环境依赖
- Paddle >= 2.0.1 - Paddle >= 2.0.1
...@@ -7,13 +7,13 @@ ...@@ -7,13 +7,13 @@
## 2. 安装 ## 2. 安装
``` ```
python3.7 setup.py install python setup.py install
``` ```
按照如下方式使用 编译完成后即可使用,以下为`rbox_iou`的使用示例
``` ```
# 引入自定义op # 引入自定义op
from rbox_iou_ops import rbox_iou from ext_op import rbox_iou
paddle.set_device('gpu:0') paddle.set_device('gpu:0')
paddle.disable_static() paddle.disable_static()
...@@ -29,10 +29,7 @@ print('iou', iou) ...@@ -29,10 +29,7 @@ print('iou', iou)
``` ```
## 3. 单元测试 ## 3. 单元测试
单元测试`test.py`文件中,通过对比python实现的结果和测试自定义op结果。 可以通过执行单元测试来确认自定义算子功能的正确性,执行单元测试的示例如下所示:
由于python计算细节与cpp计算细节略有区别,误差区间设置为0.02。
``` ```
python3.7 test.py python unittest/test_matched_rbox_iou.py
``` ```
提示`rbox_iou OP compute right!`说明OP测试通过。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "paddle/extension.h"
#include "rbox_iou_op.h"
template <typename T>
void matched_rbox_iou_cpu_kernel(const int rbox_num, const T *rbox1_data_ptr,
const T *rbox2_data_ptr, T *output_data_ptr) {
int i;
for (i = 0; i < rbox_num; i++) {
output_data_ptr[i] =
rbox_iou_single<T>(rbox1_data_ptr + i * 5, rbox2_data_ptr + i * 5);
}
}
#define CHECK_INPUT_CPU(x) \
PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
std::vector<paddle::Tensor> MatchedRboxIouCPUForward(const paddle::Tensor &rbox1,
const paddle::Tensor &rbox2) {
CHECK_INPUT_CPU(rbox1);
CHECK_INPUT_CPU(rbox2);
PD_CHECK(rbox1.shape()[0] == rbox2.shape()[0], "inputs must be same dim");
auto rbox_num = rbox1.shape()[0];
auto output = paddle::Tensor(paddle::PlaceType::kCPU, {rbox_num});
PD_DISPATCH_FLOATING_TYPES(rbox1.type(), "rotated_iou_cpu_kernel", ([&] {
matched_rbox_iou_cpu_kernel<data_t>(
rbox_num, rbox1.data<data_t>(),
rbox2.data<data_t>(),
output.mutable_data<data_t>());
}));
return {output};
}
#ifdef PADDLE_WITH_CUDA
std::vector<paddle::Tensor> MatchedRboxIouCUDAForward(const paddle::Tensor &rbox1,
const paddle::Tensor &rbox2);
#endif
#define CHECK_INPUT_SAME(x1, x2) \
PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.")
std::vector<paddle::Tensor> MatchedRboxIouForward(const paddle::Tensor &rbox1,
const paddle::Tensor &rbox2) {
CHECK_INPUT_SAME(rbox1, rbox2);
if (rbox1.place() == paddle::PlaceType::kCPU) {
return MatchedRboxIouCPUForward(rbox1, rbox2);
#ifdef PADDLE_WITH_CUDA
} else if (rbox1.place() == paddle::PlaceType::kGPU) {
return MatchedRboxIouCUDAForward(rbox1, rbox2);
#endif
}
}
std::vector<std::vector<int64_t>>
MatchedRboxIouInferShape(std::vector<int64_t> rbox1_shape,
std::vector<int64_t> rbox2_shape) {
return {{rbox1_shape[0]}};
}
std::vector<paddle::DataType> MatchedRboxIouInferDtype(paddle::DataType t1,
paddle::DataType t2) {
return {t1};
}
PD_BUILD_OP(matched_rbox_iou)
.Inputs({"RBOX1", "RBOX2"})
.Outputs({"Output"})
.SetKernelFn(PD_KERNEL(MatchedRboxIouForward))
.SetInferShapeFn(PD_INFER_SHAPE(MatchedRboxIouInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MatchedRboxIouInferDtype));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "paddle/extension.h"
#include "rbox_iou_op.h"
/**
Computes ceil(a / b)
*/
static inline int CeilDiv(const int a, const int b) { return (a + b - 1) / b; }
template <typename T>
__global__ void
matched_rbox_iou_cuda_kernel(const int rbox_num, const T *rbox1_data_ptr,
const T *rbox2_data_ptr, T *output_data_ptr) {
for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < rbox_num;
tid += blockDim.x * gridDim.x) {
output_data_ptr[tid] =
rbox_iou_single<T>(rbox1_data_ptr + tid * 5, rbox2_data_ptr + tid * 5);
}
}
#define CHECK_INPUT_GPU(x) \
PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
std::vector<paddle::Tensor> MatchedRboxIouCUDAForward(const paddle::Tensor &rbox1,
const paddle::Tensor &rbox2) {
CHECK_INPUT_GPU(rbox1);
CHECK_INPUT_GPU(rbox2);
PD_CHECK(rbox1.shape()[0] == rbox2.shape()[0], "inputs must be same dim");
auto rbox_num = rbox1.shape()[0];
auto output = paddle::Tensor(paddle::PlaceType::kGPU, {rbox_num});
const int thread_per_block = 512;
const int block_per_grid = CeilDiv(rbox_num, thread_per_block);
PD_DISPATCH_FLOATING_TYPES(
rbox1.type(), "matched_rbox_iou_cuda_kernel", ([&] {
matched_rbox_iou_cuda_kernel<
data_t><<<block_per_grid, thread_per_block, 0, rbox1.stream()>>>(
rbox_num, rbox1.data<data_t>(), rbox2.data<data_t>(),
output.mutable_data<data_t>());
}));
return {output};
}
...@@ -12,10 +12,11 @@ ...@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// //
// The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated // The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "rbox_iou_op.h"
#include "paddle/extension.h" #include "paddle/extension.h"
#include "rbox_iou_op.h"
// 2D block with 32 * 16 = 512 threads per block // 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32; const int BLOCK_DIM_X = 32;
...@@ -25,17 +26,13 @@ const int BLOCK_DIM_Y = 16; ...@@ -25,17 +26,13 @@ const int BLOCK_DIM_Y = 16;
Computes ceil(a / b) Computes ceil(a / b)
*/ */
static inline int CeilDiv(const int a, const int b) { static inline int CeilDiv(const int a, const int b) { return (a + b - 1) / b; }
return (a + b - 1) / b;
}
template <typename T> template <typename T>
__global__ void rbox_iou_cuda_kernel( __global__ void rbox_iou_cuda_kernel(const int rbox1_num, const int rbox2_num,
const int rbox1_num, const T *rbox1_data_ptr,
const int rbox2_num, const T *rbox2_data_ptr,
const T* rbox1_data_ptr, T *output_data_ptr) {
const T* rbox2_data_ptr,
T* output_data_ptr) {
// get row_start and col_start // get row_start and col_start
const int rbox1_block_idx = blockIdx.x * blockDim.x; const int rbox1_block_idx = blockIdx.x * blockDim.x;
...@@ -47,7 +44,6 @@ __global__ void rbox_iou_cuda_kernel( ...@@ -47,7 +44,6 @@ __global__ void rbox_iou_cuda_kernel(
__shared__ T block_boxes1[BLOCK_DIM_X * 5]; __shared__ T block_boxes1[BLOCK_DIM_X * 5];
__shared__ T block_boxes2[BLOCK_DIM_Y * 5]; __shared__ T block_boxes2[BLOCK_DIM_Y * 5];
// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) { if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) {
block_boxes1[threadIdx.x * 5 + 0] = block_boxes1[threadIdx.x * 5 + 0] =
...@@ -62,7 +58,8 @@ __global__ void rbox_iou_cuda_kernel( ...@@ -62,7 +58,8 @@ __global__ void rbox_iou_cuda_kernel(
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4]; rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4];
} }
// threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as above: threadIdx.y == 0 // threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as
// above: threadIdx.y == 0
if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) { if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) {
block_boxes2[threadIdx.x * 5 + 0] = block_boxes2[threadIdx.x * 5 + 0] =
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0]; rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0];
...@@ -80,41 +77,38 @@ __global__ void rbox_iou_cuda_kernel( ...@@ -80,41 +77,38 @@ __global__ void rbox_iou_cuda_kernel(
__syncthreads(); __syncthreads();
if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) { if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) {
int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx + threadIdx.y; int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx +
output_data_ptr[offset] = rbox_iou_single<T>(block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5); threadIdx.y;
output_data_ptr[offset] = rbox_iou_single<T>(
block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
} }
} }
#define CHECK_INPUT_GPU(x) PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.") #define CHECK_INPUT_GPU(x) \
PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) { std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor &rbox1,
CHECK_INPUT_GPU(rbox1); const paddle::Tensor &rbox2) {
CHECK_INPUT_GPU(rbox2); CHECK_INPUT_GPU(rbox1);
CHECK_INPUT_GPU(rbox2);
auto rbox1_num = rbox1.shape()[0]; auto rbox1_num = rbox1.shape()[0];
auto rbox2_num = rbox2.shape()[0]; auto rbox2_num = rbox2.shape()[0];
auto output = paddle::Tensor(paddle::PlaceType::kGPU, {rbox1_num, rbox2_num}); auto output = paddle::Tensor(paddle::PlaceType::kGPU, {rbox1_num, rbox2_num});
const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X); const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X);
const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y); const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y);
dim3 blocks(blocks_x, blocks_y); dim3 blocks(blocks_x, blocks_y);
dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y); dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
PD_DISPATCH_FLOATING_TYPES( PD_DISPATCH_FLOATING_TYPES(
rbox1.type(), rbox1.type(), "rbox_iou_cuda_kernel", ([&] {
"rbox_iou_cuda_kernel", rbox_iou_cuda_kernel<data_t><<<blocks, threads, 0, rbox1.stream()>>>(
([&] { rbox1_num, rbox2_num, rbox1.data<data_t>(), rbox2.data<data_t>(),
rbox_iou_cuda_kernel<data_t><<<blocks, threads, 0, rbox1.stream()>>>( output.mutable_data<data_t>());
rbox1_num, }));
rbox2_num,
rbox1.data<data_t>(),
rbox2.data<data_t>(),
output.mutable_data<data_t>());
}));
return {output}; return {output};
} }
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// //
// The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated // The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#pragma once #pragma once
...@@ -32,24 +33,20 @@ ...@@ -32,24 +33,20 @@
namespace { namespace {
template <typename T> template <typename T> struct RotatedBox { T x_ctr, y_ctr, w, h, a; };
struct RotatedBox {
T x_ctr, y_ctr, w, h, a;
};
template <typename T> template <typename T> struct Point {
struct Point {
T x, y; T x, y;
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {} HOST_DEVICE_INLINE Point(const T &px = 0, const T &py = 0) : x(px), y(py) {}
HOST_DEVICE_INLINE Point operator+(const Point& p) const { HOST_DEVICE_INLINE Point operator+(const Point &p) const {
return Point(x + p.x, y + p.y); return Point(x + p.x, y + p.y);
} }
HOST_DEVICE_INLINE Point& operator+=(const Point& p) { HOST_DEVICE_INLINE Point &operator+=(const Point &p) {
x += p.x; x += p.x;
y += p.y; y += p.y;
return *this; return *this;
} }
HOST_DEVICE_INLINE Point operator-(const Point& p) const { HOST_DEVICE_INLINE Point operator-(const Point &p) const {
return Point(x - p.x, y - p.y); return Point(x - p.x, y - p.y);
} }
HOST_DEVICE_INLINE Point operator*(const T coeff) const { HOST_DEVICE_INLINE Point operator*(const T coeff) const {
...@@ -58,22 +55,21 @@ struct Point { ...@@ -58,22 +55,21 @@ struct Point {
}; };
template <typename T> template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) { HOST_DEVICE_INLINE T dot_2d(const Point<T> &A, const Point<T> &B) {
return A.x * B.x + A.y * B.y; return A.x * B.x + A.y * B.y;
} }
template <typename T> template <typename T>
HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) { HOST_DEVICE_INLINE T cross_2d(const Point<T> &A, const Point<T> &B) {
return A.x * B.y - B.x * A.y; return A.x * B.y - B.x * A.y;
} }
template <typename T> template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices( HOST_DEVICE_INLINE void get_rotated_vertices(const RotatedBox<T> &box,
const RotatedBox<T>& box, Point<T> (&pts)[4]) {
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251 // M_PI / 180. == 0.01745329251
//double theta = box.a * 0.01745329251; // double theta = box.a * 0.01745329251;
//MODIFIED // MODIFIED
double theta = box.a; double theta = box.a;
T cosTheta2 = (T)cos(theta) * 0.5f; T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f; T sinTheta2 = (T)sin(theta) * 0.5f;
...@@ -90,10 +86,9 @@ HOST_DEVICE_INLINE void get_rotated_vertices( ...@@ -90,10 +86,9 @@ HOST_DEVICE_INLINE void get_rotated_vertices(
} }
template <typename T> template <typename T>
HOST_DEVICE_INLINE int get_intersection_points( HOST_DEVICE_INLINE int get_intersection_points(const Point<T> (&pts1)[4],
const Point<T> (&pts1)[4], const Point<T> (&pts2)[4],
const Point<T> (&pts2)[4], Point<T> (&intersections)[24]) {
Point<T> (&intersections)[24]) {
// Line vector // Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4]; Point<T> vec1[4], vec2[4];
...@@ -127,8 +122,8 @@ HOST_DEVICE_INLINE int get_intersection_points( ...@@ -127,8 +122,8 @@ HOST_DEVICE_INLINE int get_intersection_points(
// Check for vertices of rect1 inside rect2 // Check for vertices of rect1 inside rect2
{ {
const auto& AB = vec2[0]; const auto &AB = vec2[0];
const auto& DA = vec2[3]; const auto &DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB); auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA); auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
...@@ -150,8 +145,8 @@ HOST_DEVICE_INLINE int get_intersection_points( ...@@ -150,8 +145,8 @@ HOST_DEVICE_INLINE int get_intersection_points(
// Reverse the check - check for vertices of rect2 inside rect1 // Reverse the check - check for vertices of rect2 inside rect1
{ {
const auto& AB = vec1[0]; const auto &AB = vec1[0];
const auto& DA = vec1[3]; const auto &DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB); auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA); auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
...@@ -171,11 +166,9 @@ HOST_DEVICE_INLINE int get_intersection_points( ...@@ -171,11 +166,9 @@ HOST_DEVICE_INLINE int get_intersection_points(
} }
template <typename T> template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham( HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
const Point<T> (&p)[24], const int &num_in, Point<T> (&q)[24],
const int& num_in, bool shift_to_zero = false) {
Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2); assert(num_in >= 2);
// Step 1: // Step 1:
...@@ -188,7 +181,7 @@ HOST_DEVICE_INLINE int convex_hull_graham( ...@@ -188,7 +181,7 @@ HOST_DEVICE_INLINE int convex_hull_graham(
t = i; t = i;
} }
} }
auto& start = p[t]; // starting point auto &start = p[t]; // starting point
// Step 2: // Step 2:
// Subtract starting point from every points (for sorting in the next step) // Subtract starting point from every points (for sorting in the next step)
...@@ -230,15 +223,15 @@ HOST_DEVICE_INLINE int convex_hull_graham( ...@@ -230,15 +223,15 @@ HOST_DEVICE_INLINE int convex_hull_graham(
} }
#else #else
// CPU version // CPU version
std::sort( std::sort(q + 1, q + num_in,
q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool { [](const Point<T> &A, const Point<T> &B) -> bool {
T temp = cross_2d<T>(A, B); T temp = cross_2d<T>(A, B);
if (fabs(temp) < 1e-6) { if (fabs(temp) < 1e-6) {
return dot_2d<T>(A, A) < dot_2d<T>(B, B); return dot_2d<T>(A, A) < dot_2d<T>(B, B);
} else { } else {
return temp > 0; return temp > 0;
} }
}); });
#endif #endif
// Step 4: // Step 4:
...@@ -286,7 +279,7 @@ HOST_DEVICE_INLINE int convex_hull_graham( ...@@ -286,7 +279,7 @@ HOST_DEVICE_INLINE int convex_hull_graham(
} }
template <typename T> template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) { HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int &m) {
if (m <= 2) { if (m <= 2) {
return 0; return 0;
} }
...@@ -300,9 +293,8 @@ HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) { ...@@ -300,9 +293,8 @@ HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
} }
template <typename T> template <typename T>
HOST_DEVICE_INLINE T rboxes_intersection( HOST_DEVICE_INLINE T rboxes_intersection(const RotatedBox<T> &box1,
const RotatedBox<T>& box1, const RotatedBox<T> &box2) {
const RotatedBox<T>& box2) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts // from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24]; Point<T> intersectPts[24], orderedPts[24];
...@@ -327,8 +319,8 @@ HOST_DEVICE_INLINE T rboxes_intersection( ...@@ -327,8 +319,8 @@ HOST_DEVICE_INLINE T rboxes_intersection(
} // namespace } // namespace
template <typename T> template <typename T>
HOST_DEVICE_INLINE T HOST_DEVICE_INLINE T rbox_iou_single(T const *const box1_raw,
rbox_iou_single(T const* const box1_raw, T const* const box2_raw) { T const *const box2_raw) {
// shift center to the middle point to achieve higher precision in result // shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2; RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
......
import os
import glob
import paddle import paddle
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
if __name__ == "__main__":
def get_extensions():
root_dir = os.path.dirname(os.path.abspath(__file__))
ext_root_dir = os.path.join(root_dir, 'csrc')
sources = []
for ext_name in os.listdir(ext_root_dir):
ext_dir = os.path.join(ext_root_dir, ext_name)
source = glob.glob(os.path.join(ext_dir, '*.cc'))
kwargs = dict()
if paddle.device.is_compiled_with_cuda():
source += glob.glob(os.path.join(ext_dir, '*.cu'))
if not source:
continue
sources += source
if paddle.device.is_compiled_with_cuda(): if paddle.device.is_compiled_with_cuda():
setup( extension = CUDAExtension(
name='rbox_iou_ops', sources, extra_compile_args={'cxx': ['-DPADDLE_WITH_CUDA']})
ext_modules=CUDAExtension(
sources=['rbox_iou_op.cc', 'rbox_iou_op.cu'],
extra_compile_args={'cxx': ['-DPADDLE_WITH_CUDA']}))
else: else:
setup( extension = CppExtension(sources)
name='rbox_iou_ops',
ext_modules=CppExtension(sources=['rbox_iou_op.cc'])) return extension
if __name__ == "__main__":
setup(name='ext_op', ext_modules=get_extensions())
import numpy as np
import sys
import time
from shapely.geometry import Polygon
import paddle
import unittest
from ext_op import matched_rbox_iou
def rbox2poly_single(rrect, get_best_begin_point=False):
"""
rrect:[x_ctr,y_ctr,w,h,angle]
to
poly:[x0,y0,x1,y1,x2,y2,x3,y3]
"""
x_ctr, y_ctr, width, height, angle = rrect[:5]
tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
# rect 2x4
rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
R = np.array([[np.cos(angle), -np.sin(angle)],
[np.sin(angle), np.cos(angle)]])
# poly
poly = R.dot(rect)
x0, x1, x2, x3 = poly[0, :4] + x_ctr
y0, y1, y2, y3 = poly[1, :4] + y_ctr
poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float64)
return poly
def intersection(g, p):
"""
Intersection.
"""
g = g[:8].reshape((4, 2))
p = p[:8].reshape((4, 2))
a = g
b = p
use_filter = True
if use_filter:
# step1:
inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0]))
inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0]))
inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1]))
inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1]))
if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
return 0.
x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0]))
x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0]))
y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1]))
y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1]))
if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2:
return 0.
g = Polygon(g)
p = Polygon(p)
if not g.is_valid or not p.is_valid:
return 0
inter = Polygon(g).intersection(Polygon(p)).area
union = g.area + p.area - inter
if union == 0:
return 0
else:
return inter / union
def matched_rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
"""
Args:
anchors: [M, 5] x1,y1,x2,y2,angle
gt_bboxes: [M, 5] x1,y1,x2,y2,angle
Returns:
macthed_iou: [M]
"""
assert anchors.shape[1] == 5
assert gt_bboxes.shape[1] == 5
gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes]
anchors_ploy = [rbox2poly_single(e) for e in anchors]
num = len(anchors_ploy)
iou = np.zeros((num, ), dtype=np.float64)
start_time = time.time()
for i in range(num):
try:
iou[i] = intersection(gt_bboxes_ploy[i], anchors_ploy[i])
except Exception as e:
print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i],
'anchors_ploy[j]', anchors_ploy[i], e)
return iou
def gen_sample(n):
rbox = np.random.rand(n, 5)
rbox[:, 0:4] = rbox[:, 0:4] * 0.45 + 0.001
rbox[:, 4] = rbox[:, 4] - 0.5
return rbox
class MatchedRBoxIoUTest(unittest.TestCase):
def setUp(self):
self.initTestCase()
self.rbox1 = gen_sample(self.n)
self.rbox2 = gen_sample(self.n)
def initTestCase(self):
self.n = 1000
def assertAllClose(self, x, y, msg, atol=5e-1, rtol=1e-2):
self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg)
def get_places(self):
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
return places
def check_output(self, place):
paddle.disable_static()
pd_rbox1 = paddle.to_tensor(self.rbox1, place=place)
pd_rbox2 = paddle.to_tensor(self.rbox2, place=place)
actual_t = matched_rbox_iou(pd_rbox1, pd_rbox2).numpy()
poly_rbox1 = self.rbox1
poly_rbox2 = self.rbox2
poly_rbox1[:, 0:4] = self.rbox1[:, 0:4] * 1024
poly_rbox2[:, 0:4] = self.rbox2[:, 0:4] * 1024
expect_t = matched_rbox_overlaps(poly_rbox1, poly_rbox2, use_cv2=False)
self.assertAllClose(
actual_t,
expect_t,
msg="rbox_iou has diff at {} \nExpect {}\nBut got {}".format(
str(place), str(expect_t), str(actual_t)))
def test_output(self):
places = self.get_places()
for place in places:
self.check_output(place)
if __name__ == "__main__":
unittest.main()
...@@ -5,11 +5,7 @@ from shapely.geometry import Polygon ...@@ -5,11 +5,7 @@ from shapely.geometry import Polygon
import paddle import paddle
import unittest import unittest
try: from ext_op import rbox_iou
from rbox_iou_ops import rbox_iou
except Exception as e:
print('import rbox_iou_ops error', e)
sys.exit(-1)
def rbox2poly_single(rrect, get_best_begin_point=False): def rbox2poly_single(rrect, get_best_begin_point=False):
...@@ -80,7 +76,7 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False): ...@@ -80,7 +76,7 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
gt_bboxes: [M, 5] x1,y1,x2,y2,angle gt_bboxes: [M, 5] x1,y1,x2,y2,angle
Returns: Returns:
iou: [NA, M]
""" """
assert anchors.shape[1] == 5 assert anchors.shape[1] == 5
assert gt_bboxes.shape[1] == 5 assert gt_bboxes.shape[1] == 5
...@@ -89,17 +85,16 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False): ...@@ -89,17 +85,16 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
anchors_ploy = [rbox2poly_single(e) for e in anchors] anchors_ploy = [rbox2poly_single(e) for e in anchors]
num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy) num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy)
iou = np.zeros((num_gt, num_anchors), dtype=np.float64) iou = np.zeros((num_anchors, num_gt), dtype=np.float64)
start_time = time.time() start_time = time.time()
for i in range(num_gt): for i in range(num_anchors):
for j in range(num_anchors): for j in range(num_gt):
try: try:
iou[i, j] = intersection(gt_bboxes_ploy[i], anchors_ploy[j]) iou[i, j] = intersection(anchors_ploy[i], gt_bboxes_ploy[j])
except Exception as e: except Exception as e:
print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i], print('cur anchors_ploy[i]', anchors_ploy[i],
'anchors_ploy[j]', anchors_ploy[j], e) 'gt_bboxes_ploy[j]', gt_bboxes_ploy[j], e)
iou = iou.T
return iou return iou
......
...@@ -121,9 +121,9 @@ def calc_rbox_iou(pred, gt_rbox): ...@@ -121,9 +121,9 @@ def calc_rbox_iou(pred, gt_rbox):
pred_rbox = pred_rbox.reshape(-1, 5) pred_rbox = pred_rbox.reshape(-1, 5)
pred_rbox = pred_rbox.reshape(-1, 5) pred_rbox = pred_rbox.reshape(-1, 5)
try: try:
from rbox_iou_ops import rbox_iou from ext_op import rbox_iou
except Exception as e: except Exception as e:
print("import custom_ops error, try install rbox_iou_ops " \ print("import custom_ops error, try install ext_op " \
"following ppdet/ext_op/README.md", e) "following ppdet/ext_op/README.md", e)
sys.stdout.flush() sys.stdout.flush()
sys.exit(-1) sys.exit(-1)
......
...@@ -601,9 +601,9 @@ class S2ANetHead(nn.Layer): ...@@ -601,9 +601,9 @@ class S2ANetHead(nn.Layer):
fam_bbox = paddle.sum(fam_bbox, axis=-1) fam_bbox = paddle.sum(fam_bbox, axis=-1)
feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
try: try:
from rbox_iou_ops import rbox_iou from ext_op import rbox_iou
except Exception as e: except Exception as e:
print("import custom_ops error, try install rbox_iou_ops " \ print("import custom_ops error, try install ext_op " \
"following ppdet/ext_op/README.md", e) "following ppdet/ext_op/README.md", e)
sys.stdout.flush() sys.stdout.flush()
sys.exit(-1) sys.exit(-1)
...@@ -716,9 +716,9 @@ class S2ANetHead(nn.Layer): ...@@ -716,9 +716,9 @@ class S2ANetHead(nn.Layer):
odm_bbox = paddle.sum(odm_bbox, axis=-1) odm_bbox = paddle.sum(odm_bbox, axis=-1)
feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1)
try: try:
from rbox_iou_ops import rbox_iou from ext_op import rbox_iou
except Exception as e: except Exception as e:
print("import custom_ops error, try install rbox_iou_ops " \ print("import custom_ops error, try install ext_op " \
"following ppdet/ext_op/README.md", e) "following ppdet/ext_op/README.md", e)
sys.stdout.flush() sys.stdout.flush()
sys.exit(-1) sys.exit(-1)
......
...@@ -392,9 +392,9 @@ class RBoxAssigner(object): ...@@ -392,9 +392,9 @@ class RBoxAssigner(object):
gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc) gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc)
try: try:
from rbox_iou_ops import rbox_iou from ext_op import rbox_iou
except Exception as e: except Exception as e:
print("import custom_ops error, try install rbox_iou_ops " \ print("import custom_ops error, try install ext_op " \
"following ppdet/ext_op/README.md", e) "following ppdet/ext_op/README.md", e)
sys.stdout.flush() sys.stdout.flush()
sys.exit(-1) sys.exit(-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册