未验证 提交 4630e56b 编写于 作者: W wangxinxin08 提交者: GitHub

polish the code of rbox iou (#4123)

上级 ff96c78d
......@@ -11,22 +11,71 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "rbox_iou_op.h"
#include "paddle/extension.h"
#include <vector>
std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2);
template <typename T>
void rbox_iou_cpu_kernel(
const int rbox1_num,
const int rbox2_num,
const T* rbox1_data_ptr,
const T* rbox2_data_ptr,
T* output_data_ptr) {
int i, j;
for (i = 0; i < rbox1_num; i++) {
for (j = 0; j < rbox2_num; j++) {
int offset = i * rbox2_num + j;
output_data_ptr[offset] = rbox_iou_single<T>(rbox1_data_ptr + i * 5, rbox2_data_ptr + j * 5);
}
}
}
#define CHECK_INPUT_CPU(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
CHECK_INPUT_CPU(rbox1);
CHECK_INPUT_CPU(rbox2);
auto rbox1_num = rbox1.shape()[0];
auto rbox2_num = rbox2.shape()[0];
auto output = paddle::Tensor(paddle::PlaceType::kCPU);
output.reshape({rbox1_num, rbox2_num});
PD_DISPATCH_FLOATING_TYPES(
rbox1.type(),
"rbox_iou_cpu_kernel",
([&] {
rbox_iou_cpu_kernel<data_t>(
rbox1_num,
rbox2_num,
rbox1.data<data_t>(),
rbox2.data<data_t>(),
output.mutable_data<data_t>());
}));
return {output};
}
#ifdef PADDLE_WITH_CUDA
std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2);
#endif
#define CHECK_INPUT_SAME(x1, x2) PD_CHECK(x1.place() == x2.place(), "input must be smae pacle.")
std::vector<paddle::Tensor> RboxIouForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
CHECK_INPUT_SAME(rbox1, rbox2);
if (rbox1.place() == paddle::PlaceType::kCPU) {
return RboxIouCPUForward(rbox1, rbox2);
}
else if (rbox1.place() == paddle::PlaceType::kGPU) {
#ifdef PADDLE_WITH_CUDA
} else if (rbox1.place() == paddle::PlaceType::kGPU) {
return RboxIouCUDAForward(rbox1, rbox2);
#endif
}
}
......
......@@ -11,350 +11,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cassert>
#include <cmath>
#ifdef __CUDACC__
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
#include "rbox_iou_op.h"
#include "paddle/extension.h"
#include <vector>
namespace {
template <typename T>
struct RotatedBox {
T x_ctr, y_ctr, w, h, a;
};
template <typename T>
struct Point {
T x, y;
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
HOST_DEVICE_INLINE Point operator+(const Point& p) const {
return Point(x + p.x, y + p.y);
}
HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
x += p.x;
y += p.y;
return *this;
}
HOST_DEVICE_INLINE Point operator-(const Point& p) const {
return Point(x - p.x, y - p.y);
}
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
return Point(x * coeff, y * coeff);
}
};
template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.x + A.y * B.y;
}
template <typename T>
HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.y - B.x * A.y;
}
template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
const RotatedBox<T>& box,
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251
//double theta = box.a * 0.01745329251;
//MODIFIED
double theta = box.a;
T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f;
// y: top --> down; x: left --> right
pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
pts[2].x = 2 * box.x_ctr - pts[0].x;
pts[2].y = 2 * box.y_ctr - pts[0].y;
pts[3].x = 2 * box.x_ctr - pts[1].x;
pts[3].y = 2 * box.y_ctr - pts[1].y;
}
template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4],
Point<T> (&intersections)[24]) {
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4];
for (int i = 0; i < 4; i++) {
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
int num = 0; // number of intersections
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
T det = cross_2d<T>(vec2[j], vec1[i]);
// This takes care of parallel lines
if (fabs(det) <= 1e-14) {
continue;
}
auto vec12 = pts2[j] - pts1[i];
T t1 = cross_2d<T>(vec2[j], vec12) / det;
T t2 = cross_2d<T>(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
intersections[num++] = pts1[i] + vec1[i] * t1;
}
}
}
// Check for vertices of rect1 inside rect2
{
const auto& AB = vec2[0];
const auto& DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto AP = pts1[i] - pts2[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts1[i];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const auto& AB = vec1[0];
const auto& DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts2[i];
}
}
}
return num;
}
template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
const Point<T> (&p)[24],
const int& num_in,
Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int t = 0;
for (int i = 1; i < num_in; i++) {
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
t = i;
}
}
auto& start = p[t]; // starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for (int i = 0; i < num_in; i++) {
q[i] = p[i] - start;
}
// Swap the starting point to position 0
auto tmp = q[0];
q[0] = q[t];
q[t] = tmp;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
#ifdef __CUDACC__
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for (int i = 1; i < num_in - 1; i++) {
for (int j = i + 1; j < num_in; j++) {
T crossProduct = cross_2d<T>(q[i], q[j]);
if ((crossProduct < -1e-6) ||
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
auto q_tmp = q[i];
q[i] = q[j];
q[j] = q_tmp;
auto dist_tmp = dist[i];
dist[i] = dist[j];
dist[j] = dist_tmp;
}
}
}
#else
// CPU version
std::sort(
q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
T temp = cross_2d<T>(A, B);
if (fabs(temp) < 1e-6) {
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
} else {
return temp > 0;
}
});
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int k; // index of the non-overlapped second point
for (k = 1; k < num_in; k++) {
if (dist[k] > 1e-8) {
break;
}
}
if (k == num_in) {
// We reach the end, which means the convex hull is just one point
q[0] = p[t];
return 1;
}
q[1] = q[k];
int m = 2; // 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
m--;
}
q[m++] = q[i];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if (!shift_to_zero) {
for (int i = 0; i < m; i++) {
q[i] += start;
}
}
return m;
}
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
return 0;
}
T area = 0;
for (int i = 1; i < m - 1; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
HOST_DEVICE_INLINE T rboxes_intersection(
const RotatedBox<T>& box1,
const RotatedBox<T>& box2) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24];
Point<T> pts1[4];
Point<T> pts2[4];
get_rotated_vertices<T>(box1, pts1);
get_rotated_vertices<T>(box2, pts2);
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
} // namespace
template <typename T>
HOST_DEVICE_INLINE T
rbox_iou_single(T const* const box1_raw, T const* const box2_raw) {
// shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
box1.x_ctr = box1_raw[0] - center_shift_x;
box1.y_ctr = box1_raw[1] - center_shift_y;
box1.w = box1_raw[2];
box1.h = box1_raw[3];
box1.a = box1_raw[4];
box2.x_ctr = box2_raw[0] - center_shift_x;
box2.y_ctr = box2_raw[1] - center_shift_y;
box2.w = box2_raw[2];
box2.h = box2_raw[3];
box2.a = box2_raw[4];
const T area1 = box1.w * box1.h;
const T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
const T intersection = rboxes_intersection<T>(box1, box2);
const T iou = intersection / (area1 + area2 - intersection);
return iou;
}
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
......@@ -362,13 +21,9 @@ const int BLOCK_DIM_Y = 16;
/**
Computes ceil(a / b)
*/
template <typename T>
__host__ __device__ __forceinline__ T CeilDiv0(T a, T b) {
return (a + b - 1) / b;
}
static inline int CeilDiv(const int a, const int b) {
return (a + b -1) / b;
return (a + b - 1) / b;
}
template <typename T>
......@@ -461,47 +116,3 @@ std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, cons
}
template <typename T>
void rbox_iou_cpu_kernel(
const int rbox1_num,
const int rbox2_num,
const T* rbox1_data_ptr,
const T* rbox2_data_ptr,
T* output_data_ptr) {
int i, j;
for (i = 0; i < rbox1_num; i++) {
for (j = 0; j < rbox2_num; j++) {
int offset = i * rbox2_num + j;
output_data_ptr[offset] = rbox_iou_single<T>(rbox1_data_ptr + i * 5, rbox2_data_ptr + j * 5);
}
}
}
#define CHECK_INPUT_CPU(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
CHECK_INPUT_CPU(rbox1);
CHECK_INPUT_CPU(rbox2);
auto rbox1_num = rbox1.shape()[0];
auto rbox2_num = rbox2.shape()[0];
auto output = paddle::Tensor(paddle::PlaceType::kCPU);
output.reshape({rbox1_num, rbox2_num});
PD_DISPATCH_FLOATING_TYPES(
rbox1.type(),
"rbox_iou_cpu_kernel",
([&] {
rbox_iou_cpu_kernel<data_t>(
rbox1_num,
rbox2_num,
rbox1.data<data_t>(),
rbox2.data<data_t>(),
output.mutable_data<data_t>());
}));
return {output};
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cassert>
#include <cmath>
#include <vector>
#ifdef __CUDACC__
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
namespace {
template <typename T>
struct RotatedBox {
T x_ctr, y_ctr, w, h, a;
};
template <typename T>
struct Point {
T x, y;
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
HOST_DEVICE_INLINE Point operator+(const Point& p) const {
return Point(x + p.x, y + p.y);
}
HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
x += p.x;
y += p.y;
return *this;
}
HOST_DEVICE_INLINE Point operator-(const Point& p) const {
return Point(x - p.x, y - p.y);
}
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
return Point(x * coeff, y * coeff);
}
};
template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.x + A.y * B.y;
}
template <typename T>
HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.y - B.x * A.y;
}
template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
const RotatedBox<T>& box,
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251
//double theta = box.a * 0.01745329251;
//MODIFIED
double theta = box.a;
T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f;
// y: top --> down; x: left --> right
pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
pts[2].x = 2 * box.x_ctr - pts[0].x;
pts[2].y = 2 * box.y_ctr - pts[0].y;
pts[3].x = 2 * box.x_ctr - pts[1].x;
pts[3].y = 2 * box.y_ctr - pts[1].y;
}
template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4],
Point<T> (&intersections)[24]) {
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4];
for (int i = 0; i < 4; i++) {
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
int num = 0; // number of intersections
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
T det = cross_2d<T>(vec2[j], vec1[i]);
// This takes care of parallel lines
if (fabs(det) <= 1e-14) {
continue;
}
auto vec12 = pts2[j] - pts1[i];
T t1 = cross_2d<T>(vec2[j], vec12) / det;
T t2 = cross_2d<T>(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
intersections[num++] = pts1[i] + vec1[i] * t1;
}
}
}
// Check for vertices of rect1 inside rect2
{
const auto& AB = vec2[0];
const auto& DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto AP = pts1[i] - pts2[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts1[i];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const auto& AB = vec1[0];
const auto& DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts2[i];
}
}
}
return num;
}
template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
const Point<T> (&p)[24],
const int& num_in,
Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int t = 0;
for (int i = 1; i < num_in; i++) {
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
t = i;
}
}
auto& start = p[t]; // starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for (int i = 0; i < num_in; i++) {
q[i] = p[i] - start;
}
// Swap the starting point to position 0
auto tmp = q[0];
q[0] = q[t];
q[t] = tmp;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
#ifdef __CUDACC__
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for (int i = 1; i < num_in - 1; i++) {
for (int j = i + 1; j < num_in; j++) {
T crossProduct = cross_2d<T>(q[i], q[j]);
if ((crossProduct < -1e-6) ||
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
auto q_tmp = q[i];
q[i] = q[j];
q[j] = q_tmp;
auto dist_tmp = dist[i];
dist[i] = dist[j];
dist[j] = dist_tmp;
}
}
}
#else
// CPU version
std::sort(
q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
T temp = cross_2d<T>(A, B);
if (fabs(temp) < 1e-6) {
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
} else {
return temp > 0;
}
});
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int k; // index of the non-overlapped second point
for (k = 1; k < num_in; k++) {
if (dist[k] > 1e-8) {
break;
}
}
if (k == num_in) {
// We reach the end, which means the convex hull is just one point
q[0] = p[t];
return 1;
}
q[1] = q[k];
int m = 2; // 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
m--;
}
q[m++] = q[i];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if (!shift_to_zero) {
for (int i = 0; i < m; i++) {
q[i] += start;
}
}
return m;
}
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
return 0;
}
T area = 0;
for (int i = 1; i < m - 1; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
HOST_DEVICE_INLINE T rboxes_intersection(
const RotatedBox<T>& box1,
const RotatedBox<T>& box2) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24];
Point<T> pts1[4];
Point<T> pts2[4];
get_rotated_vertices<T>(box1, pts1);
get_rotated_vertices<T>(box2, pts2);
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
} // namespace
template <typename T>
HOST_DEVICE_INLINE T
rbox_iou_single(T const* const box1_raw, T const* const box2_raw) {
// shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
box1.x_ctr = box1_raw[0] - center_shift_x;
box1.y_ctr = box1_raw[1] - center_shift_y;
box1.w = box1_raw[2];
box1.h = box1_raw[3];
box1.a = box1_raw[4];
box2.x_ctr = box2_raw[0] - center_shift_x;
box2.y_ctr = box2_raw[1] - center_shift_y;
box2.w = box2_raw[2];
box2.h = box2_raw[3];
box2.a = box2_raw[4];
const T area1 = box1.w * box1.h;
const T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
const T intersection = rboxes_intersection<T>(box1, box2);
const T iou = intersection / (area1 + area2 - intersection);
return iou;
}
from paddle.utils.cpp_extension import CUDAExtension, setup
import paddle
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
if __name__ == "__main__":
setup(
name='rbox_iou_ops',
ext_modules=CUDAExtension(sources=['rbox_iou_op.cc', 'rbox_iou_op.cu']))
if paddle.device.is_compiled_with_cuda():
setup(
name='rbox_iou_ops',
ext_modules=CUDAExtension(
sources=['rbox_iou_op.cc', 'rbox_iou_op.cu'],
extra_compile_args={'cxx': ['-DPADDLE_WITH_CUDA']}))
else:
setup(
name='rbox_iou_ops',
ext_modules=CppExtension(sources=['rbox_iou_op.cc']))
......@@ -3,41 +3,15 @@ import sys
import time
from shapely.geometry import Polygon
import paddle
paddle.set_device('gpu:0')
paddle.disable_static()
import unittest
try:
from rbox_iou_ops import rbox_iou
except Exception as e:
print('import custom_ops error', e)
print('import rbox_iou_ops error', e)
sys.exit(-1)
# generate random data
rbox1 = np.random.rand(13000, 5)
rbox2 = np.random.rand(7, 5)
# x1 y1 w h [0, 0.5]
rbox1[:, 0:4] = rbox1[:, 0:4] * 0.45 + 0.001
rbox2[:, 0:4] = rbox2[:, 0:4] * 0.45 + 0.001
# generate rbox
rbox1[:, 4] = rbox1[:, 4] - 0.5
rbox2[:, 4] = rbox2[:, 4] - 0.5
print('rbox1', rbox1.shape, 'rbox2', rbox2.shape)
# to paddle tensor
pd_rbox1 = paddle.to_tensor(rbox1)
pd_rbox2 = paddle.to_tensor(rbox2)
iou = rbox_iou(pd_rbox1, pd_rbox2)
start_time = time.time()
print('paddle time:', time.time() - start_time)
print('iou is', iou.cpu().shape)
# get gt
def rbox2poly_single(rrect, get_best_begin_point=False):
"""
rrect:[x_ctr,y_ctr,w,h,angle]
......@@ -54,7 +28,7 @@ def rbox2poly_single(rrect, get_best_begin_point=False):
poly = R.dot(rect)
x0, x1, x2, x3 = poly[0, :4] + x_ctr
y0, y1, y2, y3 = poly[1, :4] + y_ctr
poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32)
poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float64)
return poly
......@@ -87,8 +61,6 @@ def intersection(g, p):
g = Polygon(g)
p = Polygon(p)
#g = g.buffer(0)
#p = p.buffer(0)
if not g.is_valid or not p.is_valid:
return 0
......@@ -100,7 +72,6 @@ def intersection(g, p):
return inter / union
# rbox_iou by python
def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
"""
......@@ -118,7 +89,7 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
anchors_ploy = [rbox2poly_single(e) for e in anchors]
num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy)
iou = np.zeros((num_gt, num_anchors), dtype=np.float32)
iou = np.zeros((num_gt, num_anchors), dtype=np.float64)
start_time = time.time()
for i in range(num_gt):
......@@ -129,23 +100,57 @@ def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i],
'anchors_ploy[j]', anchors_ploy[j], e)
iou = iou.T
print('intersection all sp_time', time.time() - start_time)
return iou
# make coor as int
ploy_rbox1 = rbox1
ploy_rbox2 = rbox2
ploy_rbox1[:, 0:4] = rbox1[:, 0:4] * 1024
ploy_rbox2[:, 0:4] = rbox2[:, 0:4] * 1024
start_time = time.time()
iou_py = rbox_overlaps(ploy_rbox1, ploy_rbox2, use_cv2=False)
print('rbox time', time.time() - start_time)
print(iou_py.shape)
iou_pd = iou.cpu().numpy()
sum_abs_diff = np.sum(np.abs(iou_pd - iou_py))
print('sum of abs diff', sum_abs_diff)
if sum_abs_diff < 0.02:
print("rbox_iou OP compute right!")
def gen_sample(n):
rbox = np.random.rand(n, 5)
rbox[:, 0:4] = rbox[:, 0:4] * 0.45 + 0.001
rbox[:, 4] = rbox[:, 4] - 0.5
return rbox
class RBoxIoUTest(unittest.TestCase):
def setUp(self):
self.initTestCase()
self.rbox1 = gen_sample(self.n)
self.rbox2 = gen_sample(self.m)
def initTestCase(self):
self.n = 13000
self.m = 7
def assertAllClose(self, x, y, msg, atol=5e-1, rtol=1e-2):
self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg)
def get_places(self):
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
return places
def check_output(self, place):
paddle.disable_static()
pd_rbox1 = paddle.to_tensor(self.rbox1, place=place)
pd_rbox2 = paddle.to_tensor(self.rbox2, place=place)
actual_t = rbox_iou(pd_rbox1, pd_rbox2).numpy()
poly_rbox1 = self.rbox1
poly_rbox2 = self.rbox2
poly_rbox1[:, 0:4] = self.rbox1[:, 0:4] * 1024
poly_rbox2[:, 0:4] = self.rbox2[:, 0:4] * 1024
expect_t = rbox_overlaps(poly_rbox1, poly_rbox2, use_cv2=False)
self.assertAllClose(
actual_t,
expect_t,
msg="rbox_iou has diff at {} \nExpect {}\nBut got {}".format(
str(place), str(expect_t), str(actual_t)))
def test_output(self):
places = self.get_places()
for place in places:
self.check_output(place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册