diff --git a/ppdet/ext_op/README.md b/ppdet/ext_op/README.md index 7ada0acf7fd75266fed6c66a9a010debc645bee8..0d67062ade859b0ca025d6ad35d9a630cf4ec523 100644 --- a/ppdet/ext_op/README.md +++ b/ppdet/ext_op/README.md @@ -1,5 +1,5 @@ # 自定义OP编译 -旋转框IOU计算OP是参考[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/new_custom_op.html) 。 +旋转框IOU计算OP是参考[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html) 。 ## 1. 环境依赖 - Paddle >= 2.0.1 @@ -7,13 +7,13 @@ ## 2. 安装 ``` -python3.7 setup.py install +python setup.py install ``` -按照如下方式使用 +编译完成后即可使用,以下为`rbox_iou`的使用示例 ``` # 引入自定义op -from rbox_iou_ops import rbox_iou +from ext_op import rbox_iou paddle.set_device('gpu:0') paddle.disable_static() @@ -29,10 +29,7 @@ print('iou', iou) ``` ## 3. 单元测试 -单元测试`test.py`文件中,通过对比python实现的结果和测试自定义op结果。 - -由于python计算细节与cpp计算细节略有区别,误差区间设置为0.02。 +可以通过执行单元测试来确认自定义算子功能的正确性,执行单元测试的示例如下所示: ``` -python3.7 test.py +python unittest/test_matched_rbox_iou.py ``` -提示`rbox_iou OP compute right!`说明OP测试通过。 diff --git a/ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cc b/ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c3c58b606c22607272d6d37877d11399d7542d9 --- /dev/null +++ b/ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// The code is based on +// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated + +#include "paddle/extension.h" +#include "rbox_iou_op.h" + +template +void matched_rbox_iou_cpu_kernel(const int rbox_num, const T *rbox1_data_ptr, + const T *rbox2_data_ptr, T *output_data_ptr) { + + int i; + for (i = 0; i < rbox_num; i++) { + output_data_ptr[i] = + rbox_iou_single(rbox1_data_ptr + i * 5, rbox2_data_ptr + i * 5); + } +} + +#define CHECK_INPUT_CPU(x) \ + PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.") + +std::vector MatchedRboxIouCPUForward(const paddle::Tensor &rbox1, + const paddle::Tensor &rbox2) { + CHECK_INPUT_CPU(rbox1); + CHECK_INPUT_CPU(rbox2); + PD_CHECK(rbox1.shape()[0] == rbox2.shape()[0], "inputs must be same dim"); + + auto rbox_num = rbox1.shape()[0]; + auto output = paddle::Tensor(paddle::PlaceType::kCPU, {rbox_num}); + + PD_DISPATCH_FLOATING_TYPES(rbox1.type(), "rotated_iou_cpu_kernel", ([&] { + matched_rbox_iou_cpu_kernel( + rbox_num, rbox1.data(), + rbox2.data(), + output.mutable_data()); + })); + + return {output}; +} + +#ifdef PADDLE_WITH_CUDA +std::vector MatchedRboxIouCUDAForward(const paddle::Tensor &rbox1, + const paddle::Tensor &rbox2); +#endif + +#define CHECK_INPUT_SAME(x1, x2) \ + PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.") + +std::vector MatchedRboxIouForward(const paddle::Tensor &rbox1, + const paddle::Tensor &rbox2) { + CHECK_INPUT_SAME(rbox1, rbox2); + if (rbox1.place() == paddle::PlaceType::kCPU) { + return MatchedRboxIouCPUForward(rbox1, rbox2); +#ifdef PADDLE_WITH_CUDA + } else if (rbox1.place() == paddle::PlaceType::kGPU) { + return MatchedRboxIouCUDAForward(rbox1, rbox2); +#endif + } +} + +std::vector> +MatchedRboxIouInferShape(std::vector rbox1_shape, + std::vector rbox2_shape) { + return {{rbox1_shape[0]}}; +} + +std::vector MatchedRboxIouInferDtype(paddle::DataType t1, + paddle::DataType t2) { + return {t1}; +} + +PD_BUILD_OP(matched_rbox_iou) + .Inputs({"RBOX1", "RBOX2"}) + .Outputs({"Output"}) + .SetKernelFn(PD_KERNEL(MatchedRboxIouForward)) + .SetInferShapeFn(PD_INFER_SHAPE(MatchedRboxIouInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MatchedRboxIouInferDtype)); diff --git a/ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cu b/ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8d03ecce6a775162980746adf727738a6beb102b --- /dev/null +++ b/ppdet/ext_op/csrc/rbox_iou/matched_rbox_iou_op.cu @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// The code is based on +// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated + +#include "paddle/extension.h" +#include "rbox_iou_op.h" + +/** + Computes ceil(a / b) +*/ + +static inline int CeilDiv(const int a, const int b) { return (a + b - 1) / b; } + +template +__global__ void +matched_rbox_iou_cuda_kernel(const int rbox_num, const T *rbox1_data_ptr, + const T *rbox2_data_ptr, T *output_data_ptr) { + for (int tid = blockIdx.x * blockDim.x + threadIdx.x; tid < rbox_num; + tid += blockDim.x * gridDim.x) { + output_data_ptr[tid] = + rbox_iou_single(rbox1_data_ptr + tid * 5, rbox2_data_ptr + tid * 5); + } +} + +#define CHECK_INPUT_GPU(x) \ + PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.") + +std::vector MatchedRboxIouCUDAForward(const paddle::Tensor &rbox1, + const paddle::Tensor &rbox2) { + CHECK_INPUT_GPU(rbox1); + CHECK_INPUT_GPU(rbox2); + PD_CHECK(rbox1.shape()[0] == rbox2.shape()[0], "inputs must be same dim"); + + auto rbox_num = rbox1.shape()[0]; + + auto output = paddle::Tensor(paddle::PlaceType::kGPU, {rbox_num}); + + const int thread_per_block = 512; + const int block_per_grid = CeilDiv(rbox_num, thread_per_block); + + PD_DISPATCH_FLOATING_TYPES( + rbox1.type(), "matched_rbox_iou_cuda_kernel", ([&] { + matched_rbox_iou_cuda_kernel< + data_t><<>>( + rbox_num, rbox1.data(), rbox2.data(), + output.mutable_data()); + })); + + return {output}; +} diff --git a/ppdet/ext_op/rbox_iou_op.cc b/ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cc similarity index 100% rename from ppdet/ext_op/rbox_iou_op.cc rename to ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cc diff --git a/ppdet/ext_op/rbox_iou_op.cu b/ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cu similarity index 63% rename from ppdet/ext_op/rbox_iou_op.cu rename to ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cu index 8ec43e54b4a813ef5829ba3120cc4a2be4d5d9b9..16d1d36f1002832d01db826743ce5c57ac557463 100644 --- a/ppdet/ext_op/rbox_iou_op.cu +++ b/ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.cu @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated +// The code is based on +// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated -#include "rbox_iou_op.h" #include "paddle/extension.h" +#include "rbox_iou_op.h" // 2D block with 32 * 16 = 512 threads per block const int BLOCK_DIM_X = 32; @@ -25,17 +26,13 @@ const int BLOCK_DIM_Y = 16; Computes ceil(a / b) */ -static inline int CeilDiv(const int a, const int b) { - return (a + b - 1) / b; -} +static inline int CeilDiv(const int a, const int b) { return (a + b - 1) / b; } template -__global__ void rbox_iou_cuda_kernel( - const int rbox1_num, - const int rbox2_num, - const T* rbox1_data_ptr, - const T* rbox2_data_ptr, - T* output_data_ptr) { +__global__ void rbox_iou_cuda_kernel(const int rbox1_num, const int rbox2_num, + const T *rbox1_data_ptr, + const T *rbox2_data_ptr, + T *output_data_ptr) { // get row_start and col_start const int rbox1_block_idx = blockIdx.x * blockDim.x; @@ -47,7 +44,6 @@ __global__ void rbox_iou_cuda_kernel( __shared__ T block_boxes1[BLOCK_DIM_X * 5]; __shared__ T block_boxes2[BLOCK_DIM_Y * 5]; - // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) { block_boxes1[threadIdx.x * 5 + 0] = @@ -62,7 +58,8 @@ __global__ void rbox_iou_cuda_kernel( rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4]; } - // threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as above: threadIdx.y == 0 + // threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as + // above: threadIdx.y == 0 if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) { block_boxes2[threadIdx.x * 5 + 0] = rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0]; @@ -80,41 +77,38 @@ __global__ void rbox_iou_cuda_kernel( __syncthreads(); if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) { - int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx + threadIdx.y; - output_data_ptr[offset] = rbox_iou_single(block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5); + int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx + + threadIdx.y; + output_data_ptr[offset] = rbox_iou_single( + block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5); } } -#define CHECK_INPUT_GPU(x) PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.") +#define CHECK_INPUT_GPU(x) \ + PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.") -std::vector RboxIouCUDAForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) { - CHECK_INPUT_GPU(rbox1); - CHECK_INPUT_GPU(rbox2); +std::vector RboxIouCUDAForward(const paddle::Tensor &rbox1, + const paddle::Tensor &rbox2) { + CHECK_INPUT_GPU(rbox1); + CHECK_INPUT_GPU(rbox2); - auto rbox1_num = rbox1.shape()[0]; - auto rbox2_num = rbox2.shape()[0]; + auto rbox1_num = rbox1.shape()[0]; + auto rbox2_num = rbox2.shape()[0]; - auto output = paddle::Tensor(paddle::PlaceType::kGPU, {rbox1_num, rbox2_num}); + auto output = paddle::Tensor(paddle::PlaceType::kGPU, {rbox1_num, rbox2_num}); - const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X); - const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y); + const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X); + const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y); - dim3 blocks(blocks_x, blocks_y); - dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y); + dim3 blocks(blocks_x, blocks_y); + dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y); - PD_DISPATCH_FLOATING_TYPES( - rbox1.type(), - "rbox_iou_cuda_kernel", - ([&] { - rbox_iou_cuda_kernel<<>>( - rbox1_num, - rbox2_num, - rbox1.data(), - rbox2.data(), - output.mutable_data()); - })); + PD_DISPATCH_FLOATING_TYPES( + rbox1.type(), "rbox_iou_cuda_kernel", ([&] { + rbox_iou_cuda_kernel<<>>( + rbox1_num, rbox2_num, rbox1.data(), rbox2.data(), + output.mutable_data()); + })); - return {output}; + return {output}; } - - diff --git a/ppdet/ext_op/rbox_iou_op.h b/ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.h similarity index 81% rename from ppdet/ext_op/rbox_iou_op.h rename to ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.h index 77fb62e394a17a2e41379a40b3379c4eacf4e80d..fce66dea00e829215ffdb3a38f8db6182a068609 100644 --- a/ppdet/ext_op/rbox_iou_op.h +++ b/ppdet/ext_op/csrc/rbox_iou/rbox_iou_op.h @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated +// The code is based on +// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated #pragma once @@ -32,24 +33,20 @@ namespace { -template -struct RotatedBox { - T x_ctr, y_ctr, w, h, a; -}; +template struct RotatedBox { T x_ctr, y_ctr, w, h, a; }; -template -struct Point { +template struct Point { T x, y; - HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {} - HOST_DEVICE_INLINE Point operator+(const Point& p) const { + HOST_DEVICE_INLINE Point(const T &px = 0, const T &py = 0) : x(px), y(py) {} + HOST_DEVICE_INLINE Point operator+(const Point &p) const { return Point(x + p.x, y + p.y); } - HOST_DEVICE_INLINE Point& operator+=(const Point& p) { + HOST_DEVICE_INLINE Point &operator+=(const Point &p) { x += p.x; y += p.y; return *this; } - HOST_DEVICE_INLINE Point operator-(const Point& p) const { + HOST_DEVICE_INLINE Point operator-(const Point &p) const { return Point(x - p.x, y - p.y); } HOST_DEVICE_INLINE Point operator*(const T coeff) const { @@ -58,22 +55,21 @@ struct Point { }; template -HOST_DEVICE_INLINE T dot_2d(const Point& A, const Point& B) { +HOST_DEVICE_INLINE T dot_2d(const Point &A, const Point &B) { return A.x * B.x + A.y * B.y; } template -HOST_DEVICE_INLINE T cross_2d(const Point& A, const Point& B) { +HOST_DEVICE_INLINE T cross_2d(const Point &A, const Point &B) { return A.x * B.y - B.x * A.y; } template -HOST_DEVICE_INLINE void get_rotated_vertices( - const RotatedBox& box, - Point (&pts)[4]) { +HOST_DEVICE_INLINE void get_rotated_vertices(const RotatedBox &box, + Point (&pts)[4]) { // M_PI / 180. == 0.01745329251 - //double theta = box.a * 0.01745329251; - //MODIFIED + // double theta = box.a * 0.01745329251; + // MODIFIED double theta = box.a; T cosTheta2 = (T)cos(theta) * 0.5f; T sinTheta2 = (T)sin(theta) * 0.5f; @@ -90,10 +86,9 @@ HOST_DEVICE_INLINE void get_rotated_vertices( } template -HOST_DEVICE_INLINE int get_intersection_points( - const Point (&pts1)[4], - const Point (&pts2)[4], - Point (&intersections)[24]) { +HOST_DEVICE_INLINE int get_intersection_points(const Point (&pts1)[4], + const Point (&pts2)[4], + Point (&intersections)[24]) { // Line vector // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] Point vec1[4], vec2[4]; @@ -127,8 +122,8 @@ HOST_DEVICE_INLINE int get_intersection_points( // Check for vertices of rect1 inside rect2 { - const auto& AB = vec2[0]; - const auto& DA = vec2[3]; + const auto &AB = vec2[0]; + const auto &DA = vec2[3]; auto ABdotAB = dot_2d(AB, AB); auto ADdotAD = dot_2d(DA, DA); for (int i = 0; i < 4; i++) { @@ -150,8 +145,8 @@ HOST_DEVICE_INLINE int get_intersection_points( // Reverse the check - check for vertices of rect2 inside rect1 { - const auto& AB = vec1[0]; - const auto& DA = vec1[3]; + const auto &AB = vec1[0]; + const auto &DA = vec1[3]; auto ABdotAB = dot_2d(AB, AB); auto ADdotAD = dot_2d(DA, DA); for (int i = 0; i < 4; i++) { @@ -171,11 +166,9 @@ HOST_DEVICE_INLINE int get_intersection_points( } template -HOST_DEVICE_INLINE int convex_hull_graham( - const Point (&p)[24], - const int& num_in, - Point (&q)[24], - bool shift_to_zero = false) { +HOST_DEVICE_INLINE int convex_hull_graham(const Point (&p)[24], + const int &num_in, Point (&q)[24], + bool shift_to_zero = false) { assert(num_in >= 2); // Step 1: @@ -188,7 +181,7 @@ HOST_DEVICE_INLINE int convex_hull_graham( t = i; } } - auto& start = p[t]; // starting point + auto &start = p[t]; // starting point // Step 2: // Subtract starting point from every points (for sorting in the next step) @@ -230,15 +223,15 @@ HOST_DEVICE_INLINE int convex_hull_graham( } #else // CPU version - std::sort( - q + 1, q + num_in, [](const Point& A, const Point& B) -> bool { - T temp = cross_2d(A, B); - if (fabs(temp) < 1e-6) { - return dot_2d(A, A) < dot_2d(B, B); - } else { - return temp > 0; - } - }); + std::sort(q + 1, q + num_in, + [](const Point &A, const Point &B) -> bool { + T temp = cross_2d(A, B); + if (fabs(temp) < 1e-6) { + return dot_2d(A, A) < dot_2d(B, B); + } else { + return temp > 0; + } + }); #endif // Step 4: @@ -286,7 +279,7 @@ HOST_DEVICE_INLINE int convex_hull_graham( } template -HOST_DEVICE_INLINE T polygon_area(const Point (&q)[24], const int& m) { +HOST_DEVICE_INLINE T polygon_area(const Point (&q)[24], const int &m) { if (m <= 2) { return 0; } @@ -300,9 +293,8 @@ HOST_DEVICE_INLINE T polygon_area(const Point (&q)[24], const int& m) { } template -HOST_DEVICE_INLINE T rboxes_intersection( - const RotatedBox& box1, - const RotatedBox& box2) { +HOST_DEVICE_INLINE T rboxes_intersection(const RotatedBox &box1, + const RotatedBox &box2) { // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned // from rotated_rect_intersection_pts Point intersectPts[24], orderedPts[24]; @@ -327,8 +319,8 @@ HOST_DEVICE_INLINE T rboxes_intersection( } // namespace template -HOST_DEVICE_INLINE T -rbox_iou_single(T const* const box1_raw, T const* const box2_raw) { +HOST_DEVICE_INLINE T rbox_iou_single(T const *const box1_raw, + T const *const box2_raw) { // shift center to the middle point to achieve higher precision in result RotatedBox box1, box2; auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; diff --git a/ppdet/ext_op/setup.py b/ppdet/ext_op/setup.py index d364db7ed37c68227a5ef7d2f8b2c8d5fcad8123..5892f4625c263b9eac19a434aca10968882bc4bc 100644 --- a/ppdet/ext_op/setup.py +++ b/ppdet/ext_op/setup.py @@ -1,14 +1,33 @@ +import os +import glob import paddle from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup -if __name__ == "__main__": + +def get_extensions(): + root_dir = os.path.dirname(os.path.abspath(__file__)) + ext_root_dir = os.path.join(root_dir, 'csrc') + sources = [] + for ext_name in os.listdir(ext_root_dir): + ext_dir = os.path.join(ext_root_dir, ext_name) + source = glob.glob(os.path.join(ext_dir, '*.cc')) + kwargs = dict() + if paddle.device.is_compiled_with_cuda(): + source += glob.glob(os.path.join(ext_dir, '*.cu')) + + if not source: + continue + + sources += source + if paddle.device.is_compiled_with_cuda(): - setup( - name='rbox_iou_ops', - ext_modules=CUDAExtension( - sources=['rbox_iou_op.cc', 'rbox_iou_op.cu'], - extra_compile_args={'cxx': ['-DPADDLE_WITH_CUDA']})) + extension = CUDAExtension( + sources, extra_compile_args={'cxx': ['-DPADDLE_WITH_CUDA']}) else: - setup( - name='rbox_iou_ops', - ext_modules=CppExtension(sources=['rbox_iou_op.cc'])) + extension = CppExtension(sources) + + return extension + + +if __name__ == "__main__": + setup(name='ext_op', ext_modules=get_extensions()) diff --git a/ppdet/ext_op/unittest/test_matched_rbox_iou.py b/ppdet/ext_op/unittest/test_matched_rbox_iou.py new file mode 100644 index 0000000000000000000000000000000000000000..af7b076da2435a4f025f608430549f2334c22e08 --- /dev/null +++ b/ppdet/ext_op/unittest/test_matched_rbox_iou.py @@ -0,0 +1,149 @@ +import numpy as np +import sys +import time +from shapely.geometry import Polygon +import paddle +import unittest + +from ext_op import matched_rbox_iou + + +def rbox2poly_single(rrect, get_best_begin_point=False): + """ + rrect:[x_ctr,y_ctr,w,h,angle] + to + poly:[x0,y0,x1,y1,x2,y2,x3,y3] + """ + x_ctr, y_ctr, width, height, angle = rrect[:5] + tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2 + # rect 2x4 + rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]]) + R = np.array([[np.cos(angle), -np.sin(angle)], + [np.sin(angle), np.cos(angle)]]) + # poly + poly = R.dot(rect) + x0, x1, x2, x3 = poly[0, :4] + x_ctr + y0, y1, y2, y3 = poly[1, :4] + y_ctr + poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float64) + return poly + + +def intersection(g, p): + """ + Intersection. + """ + + g = g[:8].reshape((4, 2)) + p = p[:8].reshape((4, 2)) + + a = g + b = p + + use_filter = True + if use_filter: + # step1: + inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0])) + inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0])) + inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1])) + inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1])) + if inter_x1 >= inter_x2 or inter_y1 >= inter_y2: + return 0. + x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0])) + x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0])) + y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1])) + y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1])) + if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2: + return 0. + + g = Polygon(g) + p = Polygon(p) + if not g.is_valid or not p.is_valid: + return 0 + + inter = Polygon(g).intersection(Polygon(p)).area + union = g.area + p.area - inter + if union == 0: + return 0 + else: + return inter / union + + +def matched_rbox_overlaps(anchors, gt_bboxes, use_cv2=False): + """ + + Args: + anchors: [M, 5] x1,y1,x2,y2,angle + gt_bboxes: [M, 5] x1,y1,x2,y2,angle + + Returns: + macthed_iou: [M] + """ + assert anchors.shape[1] == 5 + assert gt_bboxes.shape[1] == 5 + + gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes] + anchors_ploy = [rbox2poly_single(e) for e in anchors] + + num = len(anchors_ploy) + iou = np.zeros((num, ), dtype=np.float64) + + start_time = time.time() + for i in range(num): + try: + iou[i] = intersection(gt_bboxes_ploy[i], anchors_ploy[i]) + except Exception as e: + print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i], + 'anchors_ploy[j]', anchors_ploy[i], e) + return iou + + +def gen_sample(n): + rbox = np.random.rand(n, 5) + rbox[:, 0:4] = rbox[:, 0:4] * 0.45 + 0.001 + rbox[:, 4] = rbox[:, 4] - 0.5 + return rbox + + +class MatchedRBoxIoUTest(unittest.TestCase): + def setUp(self): + self.initTestCase() + self.rbox1 = gen_sample(self.n) + self.rbox2 = gen_sample(self.n) + + def initTestCase(self): + self.n = 1000 + + def assertAllClose(self, x, y, msg, atol=5e-1, rtol=1e-2): + self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg) + + def get_places(self): + places = [paddle.CPUPlace()] + if paddle.device.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + + return places + + def check_output(self, place): + paddle.disable_static() + pd_rbox1 = paddle.to_tensor(self.rbox1, place=place) + pd_rbox2 = paddle.to_tensor(self.rbox2, place=place) + actual_t = matched_rbox_iou(pd_rbox1, pd_rbox2).numpy() + poly_rbox1 = self.rbox1 + poly_rbox2 = self.rbox2 + poly_rbox1[:, 0:4] = self.rbox1[:, 0:4] * 1024 + poly_rbox2[:, 0:4] = self.rbox2[:, 0:4] * 1024 + expect_t = matched_rbox_overlaps(poly_rbox1, poly_rbox2, use_cv2=False) + self.assertAllClose( + actual_t, + expect_t, + msg="rbox_iou has diff at {} \nExpect {}\nBut got {}".format( + str(place), str(expect_t), str(actual_t))) + + def test_output(self): + places = self.get_places() + for place in places: + self.check_output(place) + + +if __name__ == "__main__": + unittest.main() diff --git a/ppdet/ext_op/test.py b/ppdet/ext_op/unittest/test_rbox_iou.py similarity index 89% rename from ppdet/ext_op/test.py rename to ppdet/ext_op/unittest/test_rbox_iou.py index 85872e484b8ca6d60a62d311c9fdfc4a9e08b6e2..8ef19ae841d5a73c5b90f1b971ed36d1d7f61a7a 100644 --- a/ppdet/ext_op/test.py +++ b/ppdet/ext_op/unittest/test_rbox_iou.py @@ -5,11 +5,7 @@ from shapely.geometry import Polygon import paddle import unittest -try: - from rbox_iou_ops import rbox_iou -except Exception as e: - print('import rbox_iou_ops error', e) - sys.exit(-1) +from ext_op import rbox_iou def rbox2poly_single(rrect, get_best_begin_point=False): @@ -80,7 +76,7 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False): gt_bboxes: [M, 5] x1,y1,x2,y2,angle Returns: - + iou: [NA, M] """ assert anchors.shape[1] == 5 assert gt_bboxes.shape[1] == 5 @@ -89,17 +85,16 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False): anchors_ploy = [rbox2poly_single(e) for e in anchors] num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy) - iou = np.zeros((num_gt, num_anchors), dtype=np.float64) + iou = np.zeros((num_anchors, num_gt), dtype=np.float64) start_time = time.time() - for i in range(num_gt): - for j in range(num_anchors): + for i in range(num_anchors): + for j in range(num_gt): try: - iou[i, j] = intersection(gt_bboxes_ploy[i], anchors_ploy[j]) + iou[i, j] = intersection(anchors_ploy[i], gt_bboxes_ploy[j]) except Exception as e: - print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i], - 'anchors_ploy[j]', anchors_ploy[j], e) - iou = iou.T + print('cur anchors_ploy[i]', anchors_ploy[i], + 'gt_bboxes_ploy[j]', gt_bboxes_ploy[j], e) return iou diff --git a/ppdet/metrics/map_utils.py b/ppdet/metrics/map_utils.py index 9c96b9235f4205279e47ff84006351a012d7bf2d..12fb9ba51242bdd244eb60da8b364ab26ddbecba 100644 --- a/ppdet/metrics/map_utils.py +++ b/ppdet/metrics/map_utils.py @@ -121,9 +121,9 @@ def calc_rbox_iou(pred, gt_rbox): pred_rbox = pred_rbox.reshape(-1, 5) pred_rbox = pred_rbox.reshape(-1, 5) try: - from rbox_iou_ops import rbox_iou + from ext_op import rbox_iou except Exception as e: - print("import custom_ops error, try install rbox_iou_ops " \ + print("import custom_ops error, try install ext_op " \ "following ppdet/ext_op/README.md", e) sys.stdout.flush() sys.exit(-1) diff --git a/ppdet/modeling/heads/s2anet_head.py b/ppdet/modeling/heads/s2anet_head.py index bb26855a7d163f76de9198515b485d6e5cedf27b..e17023d672532fb7aa786a98f95bdc3315906964 100644 --- a/ppdet/modeling/heads/s2anet_head.py +++ b/ppdet/modeling/heads/s2anet_head.py @@ -601,9 +601,9 @@ class S2ANetHead(nn.Layer): fam_bbox = paddle.sum(fam_bbox, axis=-1) feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) try: - from rbox_iou_ops import rbox_iou + from ext_op import rbox_iou except Exception as e: - print("import custom_ops error, try install rbox_iou_ops " \ + print("import custom_ops error, try install ext_op " \ "following ppdet/ext_op/README.md", e) sys.stdout.flush() sys.exit(-1) @@ -716,9 +716,9 @@ class S2ANetHead(nn.Layer): odm_bbox = paddle.sum(odm_bbox, axis=-1) feat_bbox_weights = paddle.sum(feat_bbox_weights, axis=-1) try: - from rbox_iou_ops import rbox_iou + from ext_op import rbox_iou except Exception as e: - print("import custom_ops error, try install rbox_iou_ops " \ + print("import custom_ops error, try install ext_op " \ "following ppdet/ext_op/README.md", e) sys.stdout.flush() sys.exit(-1) diff --git a/ppdet/modeling/proposal_generator/target_layer.py b/ppdet/modeling/proposal_generator/target_layer.py index 3b5a09601682151afcd47a0ea0db4fd0f03440a9..201c8bf86b14ee19f4398d2451dabdc886e9af98 100644 --- a/ppdet/modeling/proposal_generator/target_layer.py +++ b/ppdet/modeling/proposal_generator/target_layer.py @@ -392,9 +392,9 @@ class RBoxAssigner(object): gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc) try: - from rbox_iou_ops import rbox_iou + from ext_op import rbox_iou except Exception as e: - print("import custom_ops error, try install rbox_iou_ops " \ + print("import custom_ops error, try install ext_op " \ "following ppdet/ext_op/README.md", e) sys.stdout.flush() sys.exit(-1)