diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 5a058ddbc59c6135bacf7c2dc4b5c8b687f9b2b1..aa8ed502fc94bd0970dfe5dbf00ef090e799ad30 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -30,7 +30,13 @@ detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc polygon_box_transform_op.cu) detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) -detection_library(generate_proposals_op SRCS generate_proposals_op.cc) + +if(WITH_GPU) + detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) +else() + detection_library(generate_proposals_op SRCS generate_proposals_op.cc) +endif() + detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu) #Export local libraries to parent set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index c33aa255362bc5234f2813fb93e70c943b03c33f..818d58ea9ee327fd99182ad2f8cbeed07e6aaea2 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/math/math_function.h" @@ -69,7 +70,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("Anchors")->type()), - platform::CPUPlace()); + ctx.device_context()); } }; @@ -162,7 +163,7 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes, const T *im_info_data = im_info.data(); T *boxes_data = boxes->mutable_data(ctx.GetPlace()); T im_scale = im_info_data[2]; - keep->Resize({boxes->dims()[0], 1}); + keep->Resize({boxes->dims()[0]}); min_size = std::max(min_size, 1.0f); int *keep_data = keep->mutable_data(ctx.GetPlace()); @@ -463,7 +464,7 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("post_nms_topN", "post_nms_topN"); AddAttr("nms_thresh", "nms_thres"); AddAttr("min_size", "min size"); - AddAttr("eta", "eta"); + AddAttr("eta", "The parameter for adaptive NMS."); AddComment(R"DOC( Generate Proposals OP diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cu b/paddle/fluid/operators/detection/generate_proposals_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6146ff509d768c0317a5c65ed22af1a3075977a2 --- /dev/null +++ b/paddle/fluid/operators/detection/generate_proposals_op.cu @@ -0,0 +1,449 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "cub/cub.cuh" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +namespace { + +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +int const kThreadsPerBlock = sizeof(uint64_t) * 8; + +template +__global__ void RangeInitKernel(const T start, const T delta, const int size, + T *out) { + CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; } +} + +template +void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value, + Tensor *value_out, Tensor *index_out) { + int num = value.numel(); + Tensor index_in_t; + int *idx_in = index_in_t.mutable_data({num}, ctx.GetPlace()); + int block = 512; + auto stream = ctx.stream(); + RangeInitKernel<<>>(0, 1, num, idx_in); + int *idx_out = index_out->mutable_data({num}, ctx.GetPlace()); + + const T *keys_in = value.data(); + T *keys_out = value_out->mutable_data({num}, ctx.GetPlace()); + + // Determine temporary device storage requirements + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, + num); + + // Allocate temporary storage + auto place = boost::get(ctx.GetPlace()); + d_temp_storage = memory::Alloc(place, temp_storage_bytes); + + // Run sorting operation + cub::DeviceRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, + num); + + memory::Free(place, d_temp_storage); +} + +template +__device__ __forceinline__ T Min(T x, T y) { + return x < y ? x : y; +} + +template +__device__ __forceinline__ T Max(T x, T y) { + return x > y ? x : y; +} + +template +__global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas, + const T *var, const int *index, + const T *im_info, const int num, + T *proposals) { + T kBBoxClipDefault = log(1000.0 / 16.0); + CUDA_1D_KERNEL_LOOP(i, num) { + int k = index[i] * 4; + T axmin = anchor[k]; + T aymin = anchor[k + 1]; + T axmax = anchor[k + 2]; + T aymax = anchor[k + 3]; + + T w = axmax - axmin + 1.0; + T h = aymax - aymin + 1.0; + T cx = axmin + 0.5 * w; + T cy = aymin + 0.5 * h; + + T dxmin = deltas[k]; + T dymin = deltas[k + 1]; + T dxmax = deltas[k + 2]; + T dymax = deltas[k + 3]; + + T d_cx = 0., d_cy = 0., d_w = 0., d_h = 0.; + if (var) { + d_cx = cx + dxmin * w * var[k]; + d_cy = cy + dymin * h * var[k + 1]; + d_w = exp(Min(dxmax * var[k + 2], kBBoxClipDefault)) * w; + d_h = exp(Min(dymax * var[k + 3], kBBoxClipDefault)) * h; + } else { + d_cx = cx + dxmin * w; + d_cy = cy + dymin * h; + d_w = exp(Min(dxmax, kBBoxClipDefault)) * w; + d_h = exp(Min(dymax, kBBoxClipDefault)) * h; + } + + T oxmin = d_cx - d_w * 0.5; + T oymin = d_cy - d_h * 0.5; + T oxmax = d_cx + d_w * 0.5 - 1.; + T oymax = d_cy + d_h * 0.5 - 1.; + + proposals[i * 4] = Max(Min(oxmin, im_info[1] - 1.), 0.); + proposals[i * 4 + 1] = Max(Min(oymin, im_info[0] - 1.), 0.); + proposals[i * 4 + 2] = Max(Min(oxmax, im_info[1] - 1.), 0.); + proposals[i * 4 + 3] = Max(Min(oymax, im_info[0] - 1.), 0.); + } +} + +template +__global__ void FilterBBoxes(const T *bboxes, const T *im_info, + const T min_size, const int num, int *keep_num, + int *keep) { + T im_h = im_info[0]; + T im_w = im_info[1]; + T im_scale = im_info[2]; + + int cnt = 0; + __shared__ int keep_index[BlockSize]; + + CUDA_1D_KERNEL_LOOP(i, num) { + keep_index[threadIdx.x] = -1; + __syncthreads(); + + int k = i * 4; + T xmin = bboxes[k]; + T ymin = bboxes[k + 1]; + T xmax = bboxes[k + 2]; + T ymax = bboxes[k + 3]; + + T w = xmax - xmin + 1.0; + T h = ymax - ymin + 1.0; + T cx = xmin + w / 2.; + T cy = ymin + h / 2.; + + T w_s = (xmax - xmin) / im_scale + 1.; + T h_s = (ymax - ymin) / im_scale + 1.; + + if (w_s >= min_size && h_s >= min_size && cx <= im_w && cy <= im_h) { + keep_index[threadIdx.x] = i; + } + __syncthreads(); + if (threadIdx.x == 0) { + int size = (num - i) < BlockSize ? num - i : BlockSize; + for (int j = 0; j < size; ++j) { + if (keep_index[j] > -1) { + keep[cnt++] = keep_index[j]; + } + } + } + __syncthreads(); + } + if (threadIdx.x == 0) { + keep_num[0] = cnt; + } +} + +__device__ inline float IoU(const float *a, const float *b) { + float left = max(a[0], b[0]), right = min(a[2], b[2]); + float top = max(a[1], b[1]), bottom = min(a[3], b[3]); + float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); + float inter_s = width * height; + float s_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float s_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + return inter_s / (s_a + s_b - inter_s); +} + +__global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, uint64_t *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + const int row_size = + min(n_boxes - row_start * kThreadsPerBlock, kThreadsPerBlock); + const int col_size = + min(n_boxes - col_start * kThreadsPerBlock, kThreadsPerBlock); + + __shared__ float block_boxes[kThreadsPerBlock * 4]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 4 + 0] = + dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 0]; + block_boxes[threadIdx.x * 4 + 1] = + dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 1]; + block_boxes[threadIdx.x * 4 + 2] = + dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 2]; + block_boxes[threadIdx.x * 4 + 3] = + dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 3]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = kThreadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 4; + int i = 0; + uint64_t t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (IoU(cur_box, block_boxes + i * 4) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(n_boxes, kThreadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +template +void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals, + const Tensor &sorted_indices, const T nms_threshold, + Tensor *keep_out) { + int boxes_num = proposals.dims()[0]; + PADDLE_ENFORCE_EQ(boxes_num, sorted_indices.dims()[0]); + + const int col_blocks = DIVUP(boxes_num, kThreadsPerBlock); + dim3 blocks(DIVUP(boxes_num, kThreadsPerBlock), + DIVUP(boxes_num, kThreadsPerBlock)); + dim3 threads(kThreadsPerBlock); + + const T *boxes = proposals.data(); + auto place = boost::get(ctx.GetPlace()); + int size_bytes = boxes_num * col_blocks * sizeof(uint64_t); + uint64_t *d_mask = + reinterpret_cast(memory::Alloc(place, size_bytes)); + NMSKernel<<>>(boxes_num, nms_threshold, boxes, d_mask); + uint64_t *h_mask = reinterpret_cast( + memory::Alloc(platform::CPUPlace(), size_bytes)); + memory::Copy(platform::CPUPlace(), h_mask, place, d_mask, size_bytes, 0); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(uint64_t) * col_blocks); + + std::vector keep_vec; + int num_to_keep = 0; + for (int i = 0; i < boxes_num; i++) { + int nblock = i / kThreadsPerBlock; + int inblock = i % kThreadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + ++num_to_keep; + keep_vec.push_back(i); + uint64_t *p = &h_mask[0] + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + int *keep = keep_out->mutable_data({num_to_keep}, ctx.GetPlace()); + memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(), + sizeof(int) * num_to_keep, 0); + memory::Free(place, d_mask); + memory::Free(platform::CPUPlace(), h_mask); +} + +template +std::pair ProposalForOneImage( + const platform::CUDADeviceContext &ctx, const Tensor &im_info, + const Tensor &anchors, const Tensor &variances, + const Tensor &bbox_deltas, // [M, 4] + const Tensor &scores, // [N, 1] + int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size, + float eta) { + // 1. pre nms + Tensor scores_sort, index_sort; + SortDescending(ctx, scores, &scores_sort, &index_sort); + int num = scores.numel(); + int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel() + : pre_nms_top_n; + scores_sort.Resize({pre_nms_num, 1}); + index_sort.Resize({pre_nms_num, 1}); + + // 2. box decode and clipping + Tensor proposals; + proposals.mutable_data({pre_nms_num, 4}, ctx.GetPlace()); + int block = 512; + auto stream = ctx.stream(); + BoxDecodeAndClipKernel<<>>( + anchors.data(), bbox_deltas.data(), variances.data(), + index_sort.data(), im_info.data(), pre_nms_num, + proposals.data()); + + // 3. filter + Tensor keep_index, keep_num_t; + keep_index.mutable_data({pre_nms_num}, ctx.GetPlace()); + keep_num_t.mutable_data({1}, ctx.GetPlace()); + min_size = std::max(min_size, 1.0f); + FilterBBoxes<<<1, 512, 0, stream>>>( + proposals.data(), im_info.data(), min_size, pre_nms_num, + keep_num_t.data(), keep_index.data()); + int keep_num; + const auto gpu_place = boost::get(ctx.GetPlace()); + memory::Copy(platform::CPUPlace(), &keep_num, gpu_place, + keep_num_t.data(), sizeof(int), 0); + keep_index.Resize({keep_num}); + + Tensor scores_filter, proposals_filter; + proposals_filter.mutable_data({keep_num, 4}, ctx.GetPlace()); + scores_filter.mutable_data({keep_num, 1}, ctx.GetPlace()); + GPUGather(ctx, proposals, keep_index, &proposals_filter); + GPUGather(ctx, scores_sort, keep_index, &scores_filter); + + if (nms_thresh <= 0) { + return std::make_pair(proposals_filter, scores_filter); + } + + // 4. nms + Tensor keep_nms; + NMS(ctx, proposals_filter, keep_index, nms_thresh, &keep_nms); + if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) { + keep_nms.Resize({post_nms_top_n}); + } + + Tensor scores_nms, proposals_nms; + proposals_nms.mutable_data({keep_nms.numel(), 4}, ctx.GetPlace()); + scores_nms.mutable_data({keep_nms.numel(), 1}, ctx.GetPlace()); + GPUGather(ctx, proposals_filter, keep_nms, &proposals_nms); + GPUGather(ctx, scores_filter, keep_nms, &scores_nms); + + return std::make_pair(proposals_nms, scores_nms); +} +} // namespace + +template +class CUDAGenerateProposalsKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *scores = context.Input("Scores"); + auto *bbox_deltas = context.Input("BboxDeltas"); + auto *im_info = context.Input("ImInfo"); + auto *anchors = context.Input("Anchors"); + auto *variances = context.Input("Variances"); + + auto *rpn_rois = context.Output("RpnRois"); + auto *rpn_roi_probs = context.Output("RpnRoiProbs"); + + int pre_nms_top_n = context.Attr("pre_nms_topN"); + int post_nms_top_n = context.Attr("post_nms_topN"); + float nms_thresh = context.Attr("nms_thresh"); + float min_size = context.Attr("min_size"); + float eta = context.Attr("eta"); + PADDLE_ENFORCE_GE(eta, 1., "Not support adaptive NMS."); + + auto &dev_ctx = context.template device_context(); + + auto scores_dim = scores->dims(); + int64_t num = scores_dim[0]; + int64_t c_score = scores_dim[1]; + int64_t h_score = scores_dim[2]; + int64_t w_score = scores_dim[3]; + + auto bbox_dim = bbox_deltas->dims(); + int64_t c_bbox = bbox_dim[1]; + int64_t h_bbox = bbox_dim[2]; + int64_t w_bbox = bbox_dim[3]; + + Tensor bbox_deltas_swap, scores_swap; + bbox_deltas_swap.mutable_data({num, h_bbox, w_bbox, c_bbox}, + dev_ctx.GetPlace()); + scores_swap.mutable_data({num, h_score, w_score, c_score}, + dev_ctx.GetPlace()); + + math::Transpose trans; + std::vector axis = {0, 2, 3, 1}; + trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis); + trans(dev_ctx, *scores, &scores_swap, axis); + + Tensor *anchor = const_cast(anchors); + anchor->Resize({anchors->numel() / 4, 4}); + Tensor *var = const_cast(variances); + var->Resize({var->numel() / 4, 4}); + + rpn_rois->mutable_data({bbox_deltas->numel() / 4, 4}, + context.GetPlace()); + rpn_roi_probs->mutable_data({scores->numel(), 1}, context.GetPlace()); + + T *rpn_rois_data = rpn_rois->data(); + T *rpn_roi_probs_data = rpn_roi_probs->data(); + + auto place = boost::get(dev_ctx.GetPlace()); + + int64_t num_proposals = 0; + std::vector offset(1, 0); + for (int64_t i = 0; i < num; ++i) { + Tensor im_info_slice = im_info->Slice(i, i + 1); + Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1); + Tensor scores_slice = scores_swap.Slice(i, i + 1); + + bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4}); + scores_slice.Resize({h_score * w_score * c_score, 1}); + + std::pair box_score_pair = + ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var, + bbox_deltas_slice, scores_slice, pre_nms_top_n, + post_nms_top_n, nms_thresh, min_size, eta); + + Tensor proposals = box_score_pair.first; + Tensor scores = box_score_pair.second; + + memory::Copy(place, rpn_rois_data + num_proposals * 4, place, + proposals.data(), sizeof(T) * proposals.numel(), 0); + memory::Copy(place, rpn_roi_probs_data + num_proposals, place, + scores.data(), sizeof(T) * scores.numel(), 0); + num_proposals += proposals.dims()[0]; + offset.emplace_back(num_proposals); + } + framework::LoD lod; + lod.emplace_back(offset); + rpn_rois->set_lod(lod); + rpn_roi_probs->set_lod(lod); + rpn_rois->Resize({num_proposals, 4}); + rpn_roi_probs->Resize({num_proposals, 1}); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(generate_proposals, + ops::CUDAGenerateProposalsKernel< + paddle::platform::CUDADeviceContext, float>); diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposals_op.py b/python/paddle/fluid/tests/unittests/test_generate_proposals_op.py index 86e27fe29ed945ec77fbbcdbd1c7cc6ecfba0fd5..9340d558577b4b3141df9317900ee33bbb683a0e 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposals_op.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposals_op.py @@ -277,7 +277,6 @@ class TestGenerateProposalsOp(OpTest): 'eta': self.eta } - print("lod = ", self.lod) self.outputs = { 'RpnRois': (self.rpn_rois[0], [self.lod]), 'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]) @@ -295,7 +294,7 @@ class TestGenerateProposalsOp(OpTest): self.post_nms_topN = 5000 # train 6000, test 1000 self.nms_thresh = 0.7 self.min_size = 3.0 - self.eta = 0.8 + self.eta = 1. def init_test_input(self): batch_size = 1