// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// The code is based on
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated

#include "paddle/extension.h"
#include "rbox_iou_op.h"

// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;

/**
   Computes ceil(a / b)
*/
static inline int CeilDiv(const int a, const int b) { return (a + b - 1) / b; }

template <typename T>
__global__ void rbox_iou_cuda_kernel(const int rbox1_num, const int rbox2_num,
                                     const T *rbox1_data_ptr,
                                     const T *rbox2_data_ptr,
                                     T *output_data_ptr) {
  // get row_start and col_start
  const int rbox1_block_idx = blockIdx.x * blockDim.x;
  const int rbox2_block_idx = blockIdx.y * blockDim.y;

  const int rbox1_thread_num = min(rbox1_num - rbox1_block_idx, blockDim.x);
  const int rbox2_thread_num = min(rbox2_num - rbox2_block_idx, blockDim.y);

  // each rotated box takes 5 values: (x_ctr, y_ctr, w, h, angle)
  __shared__ T block_boxes1[BLOCK_DIM_X * 5];
  __shared__ T block_boxes2[BLOCK_DIM_Y * 5];

  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
  if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) {
    block_boxes1[threadIdx.x * 5 + 0] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 0];
    block_boxes1[threadIdx.x * 5 + 1] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 1];
    block_boxes1[threadIdx.x * 5 + 2] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 2];
    block_boxes1[threadIdx.x * 5 + 3] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 3];
    block_boxes1[threadIdx.x * 5 + 4] =
        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4];
  }

  // threadIdx.x < BLOCK_DIM_Y = rbox2_thread_num, so reuse the same condition
  // as above: threadIdx.y == 0
  if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) {
    block_boxes2[threadIdx.x * 5 + 0] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0];
    block_boxes2[threadIdx.x * 5 + 1] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 1];
    block_boxes2[threadIdx.x * 5 + 2] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 2];
    block_boxes2[threadIdx.x * 5 + 3] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 3];
    block_boxes2[threadIdx.x * 5 + 4] =
        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 4];
  }

  // sync so the shared-memory boxes are visible to all threads in the block
  __syncthreads();

  if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) {
    int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num +
                 rbox2_block_idx + threadIdx.y;
    output_data_ptr[offset] = rbox_iou_single<T>(
        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
  }
}

#define CHECK_INPUT_GPU(x)                                                     \
  PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")

std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor &rbox1,
                                               const paddle::Tensor &rbox2) {
  CHECK_INPUT_GPU(rbox1);
  CHECK_INPUT_GPU(rbox2);

  auto rbox1_num = rbox1.shape()[0];
  auto rbox2_num = rbox2.shape()[0];

  auto output =
      paddle::Tensor(paddle::PlaceType::kGPU, {rbox1_num, rbox2_num});

  const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X);
  const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y);
  dim3 blocks(blocks_x, blocks_y);
  dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);

  PD_DISPATCH_FLOATING_TYPES(
      rbox1.type(), "rbox_iou_cuda_kernel", ([&] {
        rbox_iou_cuda_kernel<data_t><<<blocks, threads, 0, rbox1.stream()>>>(
            rbox1_num, rbox2_num, rbox1.data<data_t>(), rbox2.data<data_t>(),
            output.mutable_data<data_t>());
      }));

  return {output};
}
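
// The sketch below is illustrative only and is not part of this file: it shows
// how RboxIouCUDAForward would typically be exposed as a Paddle custom operator.
// The op name "rbox_iou", the input/output names, and the infer helpers are
// assumptions; in the repo the registration lives in a companion .cc file, so
// this is left commented out to avoid a duplicate registration.
//
// std::vector<std::vector<int64_t>> RboxIouInferShape(
//     std::vector<int64_t> rbox1_shape, std::vector<int64_t> rbox2_shape) {
//   // IoU matrix has one row per rbox1 box and one column per rbox2 box.
//   return {{rbox1_shape[0], rbox2_shape[0]}};
// }
//
// std::vector<paddle::DataType> RboxIouInferDtype(paddle::DataType rbox1_dtype,
//                                                 paddle::DataType rbox2_dtype) {
//   return {rbox1_dtype};
// }
//
// PD_BUILD_OP(rbox_iou)
//     .Inputs({"RBOX1", "RBOX2"})
//     .Outputs({"Output"})
//     .SetKernelFn(PD_KERNEL(RboxIouCUDAForward))
//     .SetInferShapeFn(PD_INFER_SHAPE(RboxIouInferShape))
//     .SetInferDtypeFn(PD_INFER_DTYPE(RboxIouInferDtype));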