diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 303830e928cf0a25dcbbcdcf262275a73a202d84..4c426d6687646832e219fa95ee96aac3a51a0d3a 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -386,6 +386,7 @@ function(op_library TARGET) list(REMOVE_ITEM hip_srcs "eigh_op.cu") list(REMOVE_ITEM hip_srcs "lstsq_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") + list(REMOVE_ITEM hip_srcs "multiclass_nms3_op.cu") hip_library( ${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 159027dacadad9e369949d73b0afae1f2d3240d0..7c02f37fa3ef936e291f10c056125e343453bee7 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -614,6 +614,13 @@ class MultiClassNMS3Op : public MultiClassNMS2Op { const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : MultiClassNMS2Op(type, inputs, outputs, attrs) {} + + protected: + phi::KernelKey GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return phi::KernelKey( + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), ctx.GetPlace()); + } }; class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker { diff --git a/paddle/phi/kernels/gpu/multiclass_nms3_kernel.cu b/paddle/phi/kernels/gpu/multiclass_nms3_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..531e30a880a48ba8821d246cc46cdd04e58973c4 --- /dev/null +++ b/paddle/phi/kernels/gpu/multiclass_nms3_kernel.cu @@ -0,0 +1,1148 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_WITH_HIP + +#include "paddle/phi/kernels/multiclass_nms3_kernel.h" + +#include +#include "cuda.h" // NOLINT + +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/gather.cu.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/nonzero_kernel.h" + +#define CUDA_MEM_ALIGN 256 + +namespace phi { + +template +struct Bbox { + T xmin, ymin, xmax, ymax; + Bbox(T xmin, T ymin, T xmax, T ymax) + : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax) {} + Bbox() = default; +}; + +template +size_t CalcCubSortPairsWorkspaceSize(int num_items, int num_segments) { + size_t temp_storage_bytes = 0; + cub::DeviceSegmentedRadixSort::SortPairsDescending( + reinterpret_cast(NULL), + temp_storage_bytes, + reinterpret_cast(NULL), + reinterpret_cast(NULL), + reinterpret_cast(NULL), + reinterpret_cast(NULL), + num_items, // # items + num_segments, // # segments + reinterpret_cast(NULL), + reinterpret_cast(NULL)); + return temp_storage_bytes; +} + +template +size_t CalcDetectionForwardBBoxDataSize(int N, int C1) { + return N * C1 * sizeof(T); +} + +template +size_t CalcDetectionForwardBBoxPermuteSize(bool share_location, int N, int C1) { + return share_location ? 0 : N * C1 * sizeof(T); +} + +template +size_t CalcDetectionForwardPreNMSSize(int N, int C2) { + return N * C2 * sizeof(T); +} + +template +size_t CalcDetectionForwardPostNMSSize(int N, int num_classes, int top_k) { + return N * num_classes * top_k * sizeof(T); +} + +size_t CalcTotalWorkspaceSize(size_t* workspaces, int count) { + size_t total = 0; + for (int i = 0; i < count; i++) { + total += workspaces[i]; + if (workspaces[i] % CUDA_MEM_ALIGN) { + total += CUDA_MEM_ALIGN - (workspaces[i] % CUDA_MEM_ALIGN); + } + } + return total; +} + +template +size_t CalcSortScoresPerClassWorkspaceSize(const int num, + const int num_classes, + const int num_preds_per_class) { + size_t wss[4]; + const int array_len = num * num_classes * num_preds_per_class; + wss[0] = array_len * sizeof(T); // temp scores + wss[1] = array_len * sizeof(int); // temp indices + wss[2] = (num * num_classes + 1) * sizeof(int); // offsets + wss[3] = CalcCubSortPairsWorkspaceSize( + array_len, num * num_classes); // cub workspace + + return CalcTotalWorkspaceSize(wss, 4); +} + +template +size_t CalcSortScoresPerImageWorkspaceSize(const int num_images, + const int num_items_per_image) { + const int array_len = num_images * num_items_per_image; + size_t wss[2]; + wss[0] = (num_images + 1) * sizeof(int); // offsets + wss[1] = CalcCubSortPairsWorkspaceSize(array_len, + num_images); // cub workspace + + return CalcTotalWorkspaceSize(wss, 2); +} + +template +size_t CalcDetectionInferenceWorkspaceSize(bool share_location, + int N, + int C1, + int C2, + int num_classes, + int num_preds_per_class, + int top_k) { + size_t wss[6]; + wss[0] = CalcDetectionForwardBBoxDataSize(N, C1); + wss[1] = CalcDetectionForwardPreNMSSize(N, C2); + wss[2] = CalcDetectionForwardPreNMSSize(N, C2); + wss[3] = CalcDetectionForwardPostNMSSize(N, num_classes, top_k); + wss[4] = CalcDetectionForwardPostNMSSize(N, num_classes, top_k); + wss[5] = + std::max(CalcSortScoresPerClassWorkspaceSize( + N, num_classes, num_preds_per_class), + CalcSortScoresPerImageWorkspaceSize(N, num_classes * top_k)); + return CalcTotalWorkspaceSize(wss, 6); +} + +// ALIGNPTR +int8_t* AlignPtr(int8_t* ptr, uintptr_t to) { + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) { + addr += to - addr % to; + } + return reinterpret_cast(addr); +} + +// GetNEXTWORKSPACEPTR +int8_t* GetNextWorkspacePtr(int8_t* ptr, uintptr_t previous_workspace_size) { + uintptr_t addr = (uintptr_t)ptr; + addr += previous_workspace_size; + return AlignPtr(reinterpret_cast(addr), CUDA_MEM_ALIGN); +} + +/* ================== + * sortScoresPerClass + * ================== */ +template +__launch_bounds__(nthds_per_cta) __global__ + void PrepareSortData(const int num, + const int num_classes, + const int num_preds_per_class, + const int background_label_id, + const float confidence_threshold, + T_SCORE* conf_scores_gpu, + T_SCORE* temp_scores, + T_SCORE score_shift, + int* temp_idx, + int* d_offsets) { + // Prepare scores data for sort + const int cur_idx = blockIdx.x * nthds_per_cta + threadIdx.x; + const int num_preds_per_batch = num_classes * num_preds_per_class; + T_SCORE clip_val = + T_SCORE(static_cast(score_shift) + 1.f - 1.f / 1024.f); + if (cur_idx < num_preds_per_batch) { + const int class_idx = cur_idx / num_preds_per_class; + for (int i = 0; i < num; i++) { + const int target_idx = i * num_preds_per_batch + cur_idx; + const T_SCORE score = conf_scores_gpu[target_idx]; + + // "Clear" background labeled score and index + // Because we do not care about background + if (class_idx == background_label_id) { + // Set scores to 0 + // Set label = -1 + // add shift of 1.0 to normalize the score values + // to the range [1, 2). + // add a constant shift to scores will not change the sort + // result, but will help reduce the computation because + // we only need to sort the mantissa part of the floating-point + // numbers + temp_scores[target_idx] = score_shift; + temp_idx[target_idx] = -1; + conf_scores_gpu[target_idx] = score_shift; + } else { // "Clear" scores lower than threshold + if (static_cast(score) > confidence_threshold) { + // add shift of 1.0 to normalize the score values + // to the range [1, 2). + // add a constant shift to scores will not change the sort + // result, but will help reduce the computation because + // we only need to sort the mantissa part of the floating-point + // numbers + temp_scores[target_idx] = score + score_shift; + if (static_cast(score_shift) > 0.f && + (temp_scores[target_idx] >= clip_val)) + temp_scores[target_idx] = clip_val; + temp_idx[target_idx] = cur_idx + i * num_preds_per_batch; + } else { + // Set scores to 0 + // Set label = -1 + // add shift of 1.0 to normalize the score values + // to the range [1, 2). + // add a constant shift to scores will not change the sort + // result, but will help reduce the computation because + // we only need to sort the mantissa part of the floating-point + // numbers + temp_scores[target_idx] = score_shift; + temp_idx[target_idx] = -1; + conf_scores_gpu[target_idx] = score_shift; + // TODO(tizheng): HERE writing memory too many times + } + } + + if ((cur_idx % num_preds_per_class) == 0) { + const int offset_ct = i * num_classes + cur_idx / num_preds_per_class; + d_offsets[offset_ct] = offset_ct * num_preds_per_class; + // set the last element in d_offset + if (blockIdx.x == 0 && threadIdx.x == 0) + d_offsets[num * num_classes] = num * num_preds_per_batch; + } + } + } +} + +template +void SortScoresPerClassGPU(cudaStream_t stream, + const int num, + const int num_classes, + const int num_preds_per_class, + const int background_label_id, + const float confidence_threshold, + void* conf_scores_gpu, + void* index_array_gpu, + void* workspace, + const int score_bits, + const float score_shift) { + const int num_segments = num * num_classes; + void* temp_scores = workspace; + const int array_len = num * num_classes * num_preds_per_class; + void* temp_idx = GetNextWorkspacePtr(reinterpret_cast(temp_scores), + array_len * sizeof(T_SCORE)); + void* d_offsets = GetNextWorkspacePtr(reinterpret_cast(temp_idx), + array_len * sizeof(int)); + size_t cubOffsetSize = (num_segments + 1) * sizeof(int); + void* cubWorkspace = + GetNextWorkspacePtr(reinterpret_cast(d_offsets), cubOffsetSize); + + const int BS = 512; + const int GS = (num_classes * num_preds_per_class + BS - 1) / BS; + // prepare the score, index, and offsets for CUB radix sort + // also normalize the scores to the range [1, 2) + // so we only need to sort the mantissa of floating-point numbers + // since their sign bit and exponential bits are identical + // we will subtract the 1.0 shift in gatherTopDetections() + PrepareSortData + <<>>(num, + num_classes, + num_preds_per_class, + background_label_id, + confidence_threshold, + reinterpret_cast(conf_scores_gpu), + reinterpret_cast(temp_scores), + T_SCORE(score_shift), + reinterpret_cast(temp_idx), + reinterpret_cast(d_offsets)); + + size_t temp_storage_bytes = + CalcCubSortPairsWorkspaceSize(array_len, num_segments); + size_t begin_bit = 0; + size_t end_bit = sizeof(T_SCORE) * 8; + if (sizeof(T_SCORE) == 2 && score_bits > 0 && score_bits <= 10) { + // only sort score_bits in 10 mantissa bits. + end_bit = 10; + begin_bit = end_bit - score_bits; + } + cub::DeviceSegmentedRadixSort::SortPairsDescending( + cubWorkspace, + temp_storage_bytes, + reinterpret_cast(temp_scores), + reinterpret_cast(conf_scores_gpu), + reinterpret_cast(temp_idx), + reinterpret_cast(index_array_gpu), + array_len, + num_segments, + reinterpret_cast(d_offsets), + reinterpret_cast(d_offsets) + 1, + begin_bit, + end_bit, + stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError()); +} + +/* =========== + * allClassNMS + * =========== */ +template +__device__ float CalcBboxSize(const Bbox& bbox, const bool normalized) { + if (static_cast(bbox.xmax) < static_cast(bbox.xmin) || + static_cast(bbox.ymax) < static_cast(bbox.ymin)) { + // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. + return 0; + } else { + float width = static_cast(bbox.xmax) - static_cast(bbox.xmin); + float height = + static_cast(bbox.ymax) - static_cast(bbox.ymin); + if (normalized) { + return width * height; + } else { + // If bbox is not within range [0, 1]. + return (width + 1.f) * (height + 1.f); + } + } +} + +template +__device__ void CalcIntersectBbox(const Bbox& bbox1, + const Bbox& bbox2, + Bbox* intersect_bbox) { + if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin || + bbox2.ymin > bbox1.ymax || bbox2.ymax < bbox1.ymin) { + // Return [0, 0, 0, 0] if there is no intersection. + intersect_bbox->xmin = T_BBOX(0); + intersect_bbox->ymin = T_BBOX(0); + intersect_bbox->xmax = T_BBOX(0); + intersect_bbox->ymax = T_BBOX(0); + } else { + intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin); + intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin); + intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax); + intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax); + } +} + +template +__device__ Bbox GetDiagonalMinMaxSortedBox(const Bbox& bbox1) { + Bbox result; + result.xmin = min(bbox1.xmin, bbox1.xmax); + result.xmax = max(bbox1.xmin, bbox1.xmax); + + result.ymin = min(bbox1.ymin, bbox1.ymax); + result.ymax = max(bbox1.ymin, bbox1.ymax); + return result; +} + +template +__device__ void GetFlippedBox(const T_BBOX* bbox1, + bool flip_xy, + Bbox* result) { + result->xmin = flip_xy ? bbox1[1] : bbox1[0]; + result->ymin = flip_xy ? bbox1[0] : bbox1[1]; + result->xmax = flip_xy ? bbox1[3] : bbox1[2]; + result->ymax = flip_xy ? bbox1[2] : bbox1[3]; +} + +template +__device__ float CalcJaccardOverlap(const Bbox& bbox1, + const Bbox& bbox2, + const bool normalized, + const bool caffe_semantics) { + Bbox intersect_bbox; + + Bbox localbbox1 = GetDiagonalMinMaxSortedBox(bbox1); + Bbox localbbox2 = GetDiagonalMinMaxSortedBox(bbox2); + + CalcIntersectBbox(localbbox1, localbbox2, &intersect_bbox); + + float intersect_width, intersect_height; + // Only when using Caffe semantics, IOU calculation adds "1" to width and + // height if bbox is not normalized. + // https://github.com/weiliu89/caffe/blob/ssd/src/caffe/util/bbox_util.cpp#L92-L97 + if (normalized || !caffe_semantics) { + intersect_width = static_cast(intersect_bbox.xmax) - + static_cast(intersect_bbox.xmin); + intersect_height = static_cast(intersect_bbox.ymax) - + static_cast(intersect_bbox.ymin); + } else { + intersect_width = static_cast(intersect_bbox.xmax) - + static_cast(intersect_bbox.xmin) + + static_cast(T_BBOX(1)); + intersect_height = static_cast(intersect_bbox.ymax) - + static_cast(intersect_bbox.ymin) + + static_cast(T_BBOX(1)); + } + if (intersect_width > 0 && intersect_height > 0) { + float intersect_size = intersect_width * intersect_height; + float bbox1_size = CalcBboxSize(localbbox1, normalized); + float bbox2_size = CalcBboxSize(localbbox2, normalized); + return intersect_size / (bbox1_size + bbox2_size - intersect_size); + } else { + return 0.; + } +} + +template +__global__ void AllClassNMSKernel( + const int num, + const int num_classes, + const int num_preds_per_class, + const int top_k, + const float nms_threshold, + const bool share_location, + const bool is_normalized, + T_BBOX* bbox_data, // bbox_data should be float to preserve location + // information + T_SCORE* before_nms_scores, + int* before_nms_index_array, + T_SCORE* after_nms_scores, + int* after_nms_index_array, + bool flip_xy, + const float score_shift, + bool caffe_semantics) { + // __shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; + extern __shared__ bool kept_bboxinfo_flag[]; + + for (int i = 0; i < num; i++) { + int32_t const offset = i * num_classes * num_preds_per_class + + blockIdx.x * num_preds_per_class; + // Should not write data beyond [offset, top_k). + int32_t const max_idx = offset + top_k; + // Should not read beyond [offset, num_preds_per_class). + int32_t const max_read_idx = offset + min(top_k, num_preds_per_class); + int32_t const bbox_idx_offset = + i * num_preds_per_class * (share_location ? 1 : num_classes); + + // local thread data + int loc_bboxIndex[TSIZE]; + Bbox loc_bbox[TSIZE]; + + // initialize Bbox, Bboxinfo, kept_bboxinfo_flag + // Eliminate shared memory RAW hazard + __syncthreads(); +#pragma unroll + for (int t = 0; t < TSIZE; t++) { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + // Init all output data + if (item_idx < max_idx) { + // Do not access data if it exceeds read boundary + if (item_idx < max_read_idx) { + loc_bboxIndex[t] = before_nms_index_array[item_idx]; + } else { + loc_bboxIndex[t] = -1; + } + + if (loc_bboxIndex[t] != -1) { + const int bbox_data_idx = + share_location + ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) + : loc_bboxIndex[t]; + GetFlippedBox(&bbox_data[bbox_data_idx * 4], flip_xy, &loc_bbox[t]); + kept_bboxinfo_flag[cur_idx] = true; + } else { + kept_bboxinfo_flag[cur_idx] = false; + } + } else { + kept_bboxinfo_flag[cur_idx] = false; + } + } + + // filter out overlapped boxes with lower scores + int ref_item_idx = offset; + + int32_t ref_bbox_idx = -1; + if (ref_item_idx < max_read_idx) { + ref_bbox_idx = + share_location + ? (before_nms_index_array[ref_item_idx] % num_preds_per_class + + bbox_idx_offset) + : before_nms_index_array[ref_item_idx]; + } + while ((ref_bbox_idx != -1) && ref_item_idx < max_read_idx) { + Bbox ref_bbox; + GetFlippedBox(&bbox_data[ref_bbox_idx * 4], flip_xy, &ref_bbox); + + // Eliminate shared memory RAW hazard + __syncthreads(); + + for (int t = 0; t < TSIZE; t++) { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) { + if (CalcJaccardOverlap( + ref_bbox, loc_bbox[t], is_normalized, caffe_semantics) > + nms_threshold) { + kept_bboxinfo_flag[cur_idx] = false; + } + } + } + __syncthreads(); + + do { + ref_item_idx++; + } while (ref_item_idx < max_read_idx && + !kept_bboxinfo_flag[ref_item_idx - offset]); + + // Move to next valid point + if (ref_item_idx < max_read_idx) { + ref_bbox_idx = + share_location + ? (before_nms_index_array[ref_item_idx] % num_preds_per_class + + bbox_idx_offset) + : before_nms_index_array[ref_item_idx]; + } + } + + // store data + for (int t = 0; t < TSIZE; t++) { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int read_item_idx = offset + cur_idx; + const int write_item_idx = + (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx; + /* + * If not not keeping the bbox + * Set the score to 0 + * Set the bounding box index to -1 + */ + if (read_item_idx < max_idx) { + after_nms_scores[write_item_idx] = + kept_bboxinfo_flag[cur_idx] + ? T_SCORE(before_nms_scores[read_item_idx]) + : T_SCORE(score_shift); + after_nms_index_array[write_item_idx] = + kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1; + } + } + } +} + +template +void AllClassNMSGPU(cudaStream_t stream, + const int num, + const int num_classes, + const int num_preds_per_class, + const int top_k, + const float nms_threshold, + const bool share_location, + const bool is_normalized, + void* bbox_data, + void* before_nms_scores, + void* before_nms_index_array, + void* after_nms_scores, + void* after_nms_index_array, + bool flip_xy, + const float score_shift, + bool caffe_semantics) { +#define P(tsize) AllClassNMSKernel + + void (*kernel[8])(const int, + const int, + const int, + const int, + const float, + const bool, + const bool, + T_BBOX*, + T_SCORE*, + int*, + T_SCORE*, + int*, + bool, + const float, + bool) = { + P(1), + P(2), + P(3), + P(4), + P(5), + P(6), + P(7), + P(8), + }; + + const int BS = 512; + const int GS = num_classes; + const int t_size = (top_k + BS - 1) / BS; + + kernel[t_size - 1]<<>>( + num, + num_classes, + num_preds_per_class, + top_k, + nms_threshold, + share_location, + is_normalized, + reinterpret_cast(bbox_data), + reinterpret_cast(before_nms_scores), + reinterpret_cast(before_nms_index_array), + reinterpret_cast(after_nms_scores), + reinterpret_cast(after_nms_index_array), + flip_xy, + score_shift, + caffe_semantics); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError()); +} + +/* ================== + * sortScoresPerImage + * ================== */ +template +__launch_bounds__(nthds_per_cta) __global__ + void SetUniformOffsetsKernel(const int num_segments, + const int offset, + int* d_offsets) { + const int idx = blockIdx.x * nthds_per_cta + threadIdx.x; + if (idx <= num_segments) d_offsets[idx] = idx * offset; +} + +void SetUniformOffsets(cudaStream_t stream, + const int num_segments, + const int offset, + int* d_offsets) { + const int BS = 32; + const int GS = (num_segments + 1 + BS - 1) / BS; + SetUniformOffsetsKernel + <<>>(num_segments, offset, d_offsets); +} + +/* ================ + * gatherNMSOutputs + * ================ */ +template +__device__ T_BBOX saturate(T_BBOX v) { + return max(min(v, T_BBOX(1)), T_BBOX(0)); +} + +template +__launch_bounds__(nthds_per_cta) __global__ + void GatherNMSOutputsKernel(const bool share_location, + const int num_images, + const int num_preds_per_class, + const int num_classes, + const int top_k, + const int keep_top_k, + const int* indices, + const T_SCORE* scores, + const T_BBOX* bbox_data, + int* num_detections, + T_BBOX* nmsed_boxes, + T_BBOX* nmsed_scores, + T_BBOX* nmsed_classes, + int* nmsed_indices, + int* nmsed_valid_mask, + bool clip_boxes, + const T_SCORE score_shift) { + if (keep_top_k > top_k) return; + for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; + i < num_images * keep_top_k; + i += gridDim.x * nthds_per_cta) { + const int imgId = i / keep_top_k; + const int detId = i % keep_top_k; + const int offset = imgId * num_classes * top_k; + const int index = indices[offset + detId]; + const T_SCORE score = scores[offset + detId]; + if (index == -1) { + nmsed_classes[i] = -1; + nmsed_scores[i] = 0; + nmsed_boxes[i * 4] = 0; + nmsed_boxes[i * 4 + 1] = 0; + nmsed_boxes[i * 4 + 2] = 0; + nmsed_boxes[i * 4 + 3] = 0; + nmsed_indices[i] = -1; + nmsed_valid_mask[i] = 0; + } else { + const int bbox_offset = + imgId * (share_location ? num_preds_per_class + : (num_classes * num_preds_per_class)); + const int bbox_id = + ((share_location ? (index % num_preds_per_class) + : index % (num_classes * num_preds_per_class)) + + bbox_offset) * + 4; + nmsed_classes[i] = (index % (num_classes * num_preds_per_class)) / + num_preds_per_class; // label + nmsed_scores[i] = score; // confidence score + nmsed_scores[i] = nmsed_scores[i] - score_shift; + const T_BBOX xMin = bbox_data[bbox_id]; + const T_BBOX yMin = bbox_data[bbox_id + 1]; + const T_BBOX xMax = bbox_data[bbox_id + 2]; + const T_BBOX yMax = bbox_data[bbox_id + 3]; + // clipped bbox xmin + nmsed_boxes[i * 4] = clip_boxes ? saturate(xMin) : xMin; + // clipped bbox ymin + nmsed_boxes[i * 4 + 1] = clip_boxes ? saturate(yMin) : yMin; + // clipped bbox xmax + nmsed_boxes[i * 4 + 2] = clip_boxes ? saturate(xMax) : xMax; + // clipped bbox ymax + nmsed_boxes[i * 4 + 3] = clip_boxes ? saturate(yMax) : yMax; + nmsed_indices[i] = bbox_id >> 2; + nmsed_valid_mask[i] = 1; + atomicAdd(&num_detections[i / keep_top_k], 1); + } + } +} + +template +void GatherNMSOutputsGPU(cudaStream_t stream, + const bool share_location, + const int num_images, + const int num_preds_per_class, + const int num_classes, + const int top_k, + const int keep_top_k, + const void* indices, + const void* scores, + const void* bbox_data, + void* num_detections, + void* nmsed_boxes, + void* nmsed_scores, + void* nmsed_classes, + void* nmsed_indices, + void* nmsed_valid_mask, + bool clip_boxes, + const float score_shift) { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(num_detections, 0, num_images * sizeof(int), stream)); + const int BS = 32; + const int GS = 32; + GatherNMSOutputsKernel + <<>>(share_location, + num_images, + num_preds_per_class, + num_classes, + top_k, + keep_top_k, + reinterpret_cast(indices), + reinterpret_cast(scores), + reinterpret_cast(bbox_data), + reinterpret_cast(num_detections), + reinterpret_cast(nmsed_boxes), + reinterpret_cast(nmsed_scores), + reinterpret_cast(nmsed_classes), + reinterpret_cast(nmsed_indices), + reinterpret_cast(nmsed_valid_mask), + clip_boxes, + T_SCORE(score_shift)); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError()); +} + +template +void SortScoresPerImageGPU(cudaStream_t stream, + const int num_images, + const int num_items_per_image, + void* unsorted_scores, + void* unsorted_bbox_indices, + void* sorted_scores, + void* sorted_bbox_indices, + void* workspace, + int score_bits) { + void* d_offsets = workspace; + void* cubWorkspace = GetNextWorkspacePtr(reinterpret_cast(d_offsets), + (num_images + 1) * sizeof(int)); + + SetUniformOffsets(stream, + num_images, + num_items_per_image, + reinterpret_cast(d_offsets)); + + const int array_len = num_images * num_items_per_image; + size_t temp_storage_bytes = + CalcCubSortPairsWorkspaceSize(array_len, num_images); + size_t begin_bit = 0; + size_t end_bit = sizeof(T_SCORE) * 8; + if (sizeof(T_SCORE) == 2 && score_bits > 0 && score_bits <= 10) { + end_bit = 10; + begin_bit = end_bit - score_bits; + } + cub::DeviceSegmentedRadixSort::SortPairsDescending( + cubWorkspace, + temp_storage_bytes, + reinterpret_cast(unsorted_scores), + reinterpret_cast(sorted_scores), + reinterpret_cast(unsorted_bbox_indices), + reinterpret_cast(sorted_bbox_indices), + array_len, + num_images, + reinterpret_cast(d_offsets), + reinterpret_cast(d_offsets) + 1, + begin_bit, + end_bit, + stream); + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError()); +} + +template +void InferNMS(cudaStream_t stream, + const int N, + const int per_batch_boxes_size, + const int per_batch_scores_size, + const bool share_location, + const int background_label_id, + const int num_preds_per_class, + const int num_classes, + const int top_k, + const int keep_top_k, + const float score_threshold, + const float iou_threshold, + const void* loc_data, + const void* conf_data, + void* keep_count, + void* nmsed_boxes, + void* nmsed_scores, + void* nmsed_classes, + void* nmsed_indices, + void* nmsed_valid_mask, + void* workspace, + bool is_normalized, + bool conf_sigmoid, + bool clip_boxes, + int score_bits, + bool caffe_semantics) { + PADDLE_ENFORCE_EQ( + share_location, + true, + phi::errors::Unimplemented("share_location=false is not supported.")); + + // Prepare workspaces + size_t bbox_data_size = + CalcDetectionForwardBBoxDataSize(N, per_batch_boxes_size); + void* bbox_data_raw = workspace; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(bbox_data_raw, + loc_data, + bbox_data_size, + cudaMemcpyDeviceToDevice, + stream)); + void* bbox_data = bbox_data_raw; + + const int num_scores = N * per_batch_scores_size; + size_t total_scores_size = + CalcDetectionForwardPreNMSSize(N, per_batch_scores_size); + void* scores = + GetNextWorkspacePtr(reinterpret_cast(bbox_data), bbox_data_size); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync( + scores, conf_data, total_scores_size, cudaMemcpyDeviceToDevice, stream)); + + size_t indices_size = + CalcDetectionForwardPreNMSSize(N, per_batch_scores_size); + void* indices = + GetNextWorkspacePtr(reinterpret_cast(scores), total_scores_size); + + size_t post_nms_scores_size = + CalcDetectionForwardPostNMSSize(N, num_classes, top_k); + size_t post_nms_indices_size = CalcDetectionForwardPostNMSSize( + N, num_classes, top_k); // indices are full int32 + void* post_nms_scores = + GetNextWorkspacePtr(reinterpret_cast(indices), indices_size); + void* post_nms_indices = GetNextWorkspacePtr( + reinterpret_cast(post_nms_scores), post_nms_scores_size); + + void* sorting_workspace = GetNextWorkspacePtr( + reinterpret_cast(post_nms_indices), post_nms_indices_size); + // Sort the scores so that the following NMS could be applied. + float score_shift = 0.f; + SortScoresPerClassGPU(stream, + N, + num_classes, + num_preds_per_class, + background_label_id, + score_threshold, + scores, + indices, + sorting_workspace, + score_bits, + score_shift); + + // This is set to true as the input bounding boxes are of the format [ymin, + // xmin, ymax, xmax]. The default implementation assumes [xmin, ymin, xmax, + // ymax] + bool flip_xy = true; + // NMS + AllClassNMSGPU(stream, + N, + num_classes, + num_preds_per_class, + top_k, + iou_threshold, + share_location, + is_normalized, + bbox_data, + scores, + indices, + post_nms_scores, + post_nms_indices, + flip_xy, + score_shift, + caffe_semantics); + + // Sort the bounding boxes after NMS using scores + SortScoresPerImageGPU(stream, + N, + num_classes * top_k, + post_nms_scores, + post_nms_indices, + scores, + indices, + sorting_workspace, + score_bits); + + // Gather data from the sorted bounding boxes after NMS + GatherNMSOutputsGPU(stream, + share_location, + N, + num_preds_per_class, + num_classes, + top_k, + keep_top_k, + indices, + scores, + bbox_data, + keep_count, + nmsed_boxes, + nmsed_scores, + nmsed_classes, + nmsed_indices, + nmsed_valid_mask, + clip_boxes, + score_shift); +} + +template +void MultiClassNMSGPUKernel(const Context& ctx, + const DenseTensor& bboxes, + const DenseTensor& scores, + const paddle::optional& rois_num, + float score_threshold, + int nms_top_k, + int keep_top_k, + float nms_threshold, + bool normalized, + float nms_eta, + int background_label, + DenseTensor* out, + DenseTensor* index, + DenseTensor* nms_rois_num) { + bool return_index = index != nullptr; + bool has_roisnum = rois_num.get_ptr() != nullptr; + auto score_dims = scores.dims(); + auto score_size = score_dims.size(); + + bool is_supported = (score_size == 3) && (nms_top_k >= 0) && + (nms_top_k <= 4096) && (keep_top_k >= 0) && + (nms_eta == 1.0) && !has_roisnum; + if (!is_supported) { + VLOG(6) + << "This configuration is not supported by GPU kernel. Falling back to " + "CPU kernel. " + "Expect (score_size == 3) && (nms_top_k >= 0) && (nms_top_k <= 4096)" + "(keep_top_k >= 0) && (nms_eta == 1.0) && !has_roisnum, " + "got score_size=" + << score_size << ", nms_top_k=" << nms_top_k + << ", keep_top_k=" << keep_top_k << ", nms_eta=" << nms_eta + << ", has_roisnum=" << has_roisnum; + + DenseTensor bboxes_cpu, scores_cpu, rois_num_cpu_tenor; + DenseTensor out_cpu, index_cpu, nms_rois_num_cpu; + paddle::optional rois_num_cpu(paddle::none); + auto cpu_place = phi::CPUPlace(); + auto gpu_place = ctx.GetPlace(); + + // copy from GPU to CPU + phi::Copy(ctx, bboxes, cpu_place, false, &bboxes_cpu); + phi::Copy(ctx, scores, cpu_place, false, &scores_cpu); + if (has_roisnum) { + phi::Copy( + ctx, *rois_num.get_ptr(), cpu_place, false, &rois_num_cpu_tenor); + rois_num_cpu = paddle::optional(rois_num_cpu_tenor); + } + ctx.Wait(); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* cpu_ctx = static_cast(pool.Get(cpu_place)); + MultiClassNMSKernel(*cpu_ctx, + bboxes_cpu, + scores_cpu, + rois_num_cpu, + score_threshold, + nms_top_k, + keep_top_k, + nms_threshold, + normalized, + nms_eta, + background_label, + &out_cpu, + &index_cpu, + &nms_rois_num_cpu); + // copy back + phi::Copy(ctx, out_cpu, gpu_place, false, out); + phi::Copy(ctx, index_cpu, gpu_place, false, index); + phi::Copy(ctx, nms_rois_num_cpu, gpu_place, false, nms_rois_num); + return; + } + + // Calculate input shapes + int64_t batch_size = score_dims[0]; + const int64_t per_batch_boxes_size = + bboxes.dims()[1] * bboxes.dims()[2]; // M * 4 + const int64_t per_batch_scores_size = + scores.dims()[1] * scores.dims()[2]; // C * M + const int64_t num_priors = bboxes.dims()[1]; // M + const int64_t num_classes = scores.dims()[1]; // C + const bool share_location = true; + auto stream = reinterpret_cast(ctx).stream(); + // Sanity check + PADDLE_ENFORCE_LE( + nms_top_k, + num_priors, + phi::errors::InvalidArgument("Expect nms_top_k (%d)" + " <= num of boxes per batch (%d).", + nms_top_k, + num_priors)); + PADDLE_ENFORCE_LE(keep_top_k, + nms_top_k, + phi::errors::InvalidArgument("Expect keep_top_k (%d)" + " <= nms_top_k (%d).", + keep_top_k, + nms_top_k)); + + // Transform the layout of bboxes and scores + // bboxes: [N,M,4] -> [N,1,M,4] + DenseTensor transformed_bboxes(bboxes.type()); + transformed_bboxes.ShareDataWith(bboxes).Resize( + {bboxes.dims()[0], 1, bboxes.dims()[1], bboxes.dims()[2]}); + // scores: [N, C, M] => [N, C, M, 1] + DenseTensor transformed_scores(scores.type()); + transformed_scores.ShareDataWith(scores).Resize( + {scores.dims()[0], scores.dims()[1], scores.dims()[2], 1}); + + // Prepare intermediate outputs for NMS kernels + DenseTensor keep_count(DataType::INT32); + keep_count.Resize({batch_size}); + if (nms_rois_num != nullptr) { + nms_rois_num->Resize({batch_size}); + ctx.template Alloc(nms_rois_num); + keep_count.ShareDataWith(*nms_rois_num); + } else { + ctx.template Alloc(&keep_count); + } + + DenseTensor nmsed_indices(DataType::INT32); + nmsed_indices.Resize({batch_size * keep_top_k, 1}); + ctx.template Alloc(&nmsed_indices); + + DenseTensor nmsed_valid_mask(DataType::INT32); + nmsed_valid_mask.Resize({batch_size * keep_top_k}); + ctx.template Alloc(&nmsed_valid_mask); + + DenseTensor nmsed_boxes(bboxes.dtype()); + DenseTensor nmsed_scores(scores.dtype()); + DenseTensor nmsed_classes(scores.dtype()); + nmsed_boxes.Resize({batch_size * keep_top_k, 4}); + nmsed_scores.Resize({batch_size * keep_top_k, 1}); + nmsed_classes.Resize({batch_size * keep_top_k, 1}); + ctx.template Alloc(&nmsed_boxes); + ctx.template Alloc(&nmsed_scores); + ctx.template Alloc(&nmsed_classes); + + auto workspace_size = + CalcDetectionInferenceWorkspaceSize(share_location, + batch_size, + per_batch_boxes_size, + per_batch_scores_size, + num_classes, + num_priors, + nms_top_k); + + DenseTensor workspace = DenseTensor(); + workspace.Resize({static_cast(workspace_size)}); + T* workspace_ptr = ctx.template Alloc(&workspace); + + // Launch the NMS kernel + InferNMS(stream, + batch_size, + per_batch_boxes_size, + per_batch_scores_size, + share_location, + background_label, + num_priors, + num_classes, + nms_top_k, + keep_top_k, + score_threshold, + nms_threshold, + transformed_bboxes.data(), + transformed_scores.data(), + keep_count.data(), + nmsed_boxes.data(), + nmsed_scores.data(), + nmsed_classes.data(), + nmsed_indices.data(), + nmsed_valid_mask.data(), + workspace_ptr, + normalized, + false, + false, + 0, + true); + + // Post-processing to get the final outputs + // Concat the individual class, score and boxes outputs + // into a [N * M, 6] tensor. + DenseTensor raw_out; + raw_out.Resize({batch_size * keep_top_k, 6}); + ctx.template Alloc(&raw_out); + phi::funcs::ConcatFunctor concat; + concat(ctx, {nmsed_classes, nmsed_scores, nmsed_boxes}, 1, &raw_out); + + // Output of NMS kernel may include invalid entries, which is + // marked by nmsed_valid_mask. Eliminate the invalid entries + // by gathering the valid ones. + + // 1. Get valid indices + DenseTensor valid_indices; + NonZeroKernel(ctx, nmsed_valid_mask, &valid_indices); + // 2. Perform gathering + const int64_t valid_samples = valid_indices.dims()[0]; + out->Resize({valid_samples, 6}); + ctx.template Alloc(out); + phi::funcs::GPUGatherNd(ctx, raw_out, valid_indices, out); + index->Resize({valid_samples, 1}); + ctx.template Alloc(index); + phi::funcs::GPUGatherNd( + ctx, nmsed_indices, valid_indices, index); +} + +} // namespace phi + +PD_REGISTER_KERNEL(multiclass_nms3, // cuda_only + GPU, + ALL_LAYOUT, + phi::MultiClassNMSGPUKernel, + float) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} + +#endif diff --git a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py index 79207495b7764c0bbda3bbdea0192d008e590b06..423e7039dd9cbd270e26d0676be6cee8ac0693a3 100644 --- a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py @@ -20,7 +20,7 @@ from eager_op_test import OpTest import paddle from paddle import _C_ops, _legacy_C_ops -from paddle.fluid import _non_static_mode, in_dygraph_mode +from paddle.fluid import _non_static_mode, core, in_dygraph_mode from paddle.fluid.layer_helper import LayerHelper @@ -355,6 +355,7 @@ def batched_multiclass_nms( nms_top_k, keep_top_k, normalized=True, + gpu_logic=False, ): batch_size = scores.shape[0] num_boxes = scores.shape[2] @@ -392,9 +393,14 @@ def batched_multiclass_nms( idx + n * num_boxes, ] ) - sorted_det_out = sorted( - tmp_det_out, key=lambda tup: tup[0], reverse=False - ) + if gpu_logic: + sorted_det_out = sorted( + tmp_det_out, key=lambda tup: tup[1], reverse=True + ) + else: + sorted_det_out = sorted( + tmp_det_out, key=lambda tup: tup[0], reverse=False + ) det_outs.extend(sorted_det_out) return det_outs, lod @@ -747,7 +753,7 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op): background = 0 nms_threshold = 0.3 nms_top_k = 400 - keep_top_k = 200 + keep_top_k = 200 if not hasattr(self, 'keep_top_k') else self.keep_top_k score_threshold = self.score_threshold scores = np.random.random((N * M, C)).astype('float32') @@ -768,6 +774,7 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op): nms_threshold, nms_top_k, keep_top_k, + gpu_logic=self.gpu_logic if hasattr(self, 'gpu_logic') else None, ) det_outs = np.array(det_outs) @@ -797,7 +804,8 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op): } def test_check_output(self): - self.check_output() + place = paddle.CPUPlace() + self.check_output_with_place(place) class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op): @@ -807,6 +815,51 @@ class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op): self.score_threshold = 2.0 +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMulticlassNMS3OpGPU(TestMulticlassNMS2Op): + def test_check_output(self): + place = paddle.CUDAPlace(0) + self.check_output_with_place(place) + + def set_argument(self): + self.score_threshold = 0.01 + self.gpu_logic = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMulticlassNMS3OpGPULessOutput(TestMulticlassNMS3OpGPU): + def set_argument(self): + # Here set 0.08 to make output box size less than keep_top_k + self.score_threshold = 0.08 + self.gpu_logic = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMulticlassNMS3OpGPUNoOutput(TestMulticlassNMS3OpGPU): + def set_argument(self): + # Here set 2.0 to test the case there is no outputs. + # In practical use, 0.0 < score_threshold < 1.0 + self.score_threshold = 2.0 + self.gpu_logic = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) +class TestMulticlassNMS3OpGPUFallback(TestMulticlassNMS3OpGPU): + def set_argument(self): + # Setting keep_top_k < 0 will fall back to CPU kernel + self.score_threshold = 0.01 + self.keep_top_k = -1 + self.gpu_logic = True + + if __name__ == '__main__': paddle.enable_static() unittest.main()