From de2c5fd60cc029d5d9c3d3340d21d4d2e296332e Mon Sep 17 00:00:00 2001 From: Netpunk <69072522+Patrick-Star125@users.noreply.github.com> Date: Thu, 8 Dec 2022 11:07:47 +0800 Subject: [PATCH] [PHI decoupling] remove bbox_util.h from phi dependencies (#48761) * remove bbox_util.h from phi * add file bbox_util.h * reframe bbox_util.h --- .../phi/kernels/funcs/detection/bbox_util.h | 28 +++++++++++++++++++ .../gpu/distribute_fpn_proposals_kernel.cu | 17 +++++++++-- .../kernels/gpu/generate_proposals_kernel.cu | 10 ++----- 3 files changed, 44 insertions(+), 11 deletions(-) create mode 100644 paddle/phi/kernels/funcs/detection/bbox_util.h diff --git a/paddle/phi/kernels/funcs/detection/bbox_util.h b/paddle/phi/kernels/funcs/detection/bbox_util.h new file mode 100644 index 0000000000..4acaa4406b --- /dev/null +++ b/paddle/phi/kernels/funcs/detection/bbox_util.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace phi { +namespace funcs { + +struct RangeInitFunctor { + int start_; + int delta_; + int *out_; + __device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; } +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu b/paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu index bcce09649a..3d50a75ae2 100644 --- a/paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu +++ b/paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu @@ -24,6 +24,7 @@ namespace cub = hipcub; #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/detection/bbox_util.h" #include "paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/gather.cu.h" @@ -31,7 +32,6 @@ namespace cub = hipcub; #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" namespace phi { @@ -62,7 +62,18 @@ __global__ void GPUDistFpnProposalsHelper(const int nthreads, const T* offset_roi = rois + i * BBoxSize; int roi_batch_ind = roi_batch_id_data[i]; // get the target level of current rois - T roi_area = paddle::operators::RoIArea(offset_roi, pixel_offset); + T roi_area; + if (offset_roi[2] < offset_roi[0] || offset_roi[3] < offset_roi[1]) { + roi_area = static_cast(0.); + } else { + const T w = offset_roi[2] - offset_roi[0]; + const T h = offset_roi[3] - offset_roi[1]; + if (pixel_offset) { + roi_area = (w + 1) * (h + 1); + } else { + roi_area = w * h; + } + } T roi_scale = sqrt(roi_area); int tgt_lvl = floor( log2(roi_scale / static_cast(refer_scale) + (T)1e-8) + refer_level); @@ -155,7 +166,7 @@ void DistributeFpnProposalsKernel( index_in_t.Resize({roi_num}); int* idx_in = dev_ctx.template Alloc(&index_in_t); funcs::ForRange for_range(dev_ctx, roi_num); - for_range(paddle::operators::RangeInitFunctor{0, 1, idx_in}); + for_range(funcs::RangeInitFunctor{0, 1, idx_in}); DenseTensor keys_out_t; keys_out_t.Resize({roi_num}); diff --git a/paddle/phi/kernels/gpu/generate_proposals_kernel.cu b/paddle/phi/kernels/gpu/generate_proposals_kernel.cu index f750bd5fe7..622ef8100a 100644 --- a/paddle/phi/kernels/gpu/generate_proposals_kernel.cu +++ b/paddle/phi/kernels/gpu/generate_proposals_kernel.cu @@ -26,6 +26,7 @@ namespace cub = hipcub; #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/detection/bbox_util.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" @@ -38,13 +39,6 @@ int const kThreadsPerBlock = sizeof(uint64_t) * 8; static const double kBBoxClipDefault = std::log(1000.0 / 16.0); -struct RangeInitFunctor { - int start_; - int delta_; - int *out_; - __device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; } -}; - template static void SortDescending(const phi::GPUContext &ctx, const DenseTensor &value, @@ -55,7 +49,7 @@ static void SortDescending(const phi::GPUContext &ctx, index_in_t.Resize(phi::make_ddim({num})); int *idx_in = ctx.template Alloc(&index_in_t); phi::funcs::ForRange for_range(ctx, num); - for_range(RangeInitFunctor{0, 1, idx_in}); + for_range(funcs::RangeInitFunctor{0, 1, idx_in}); index_out->Resize(phi::make_ddim({num})); int *idx_out = ctx.template Alloc(index_out); -- GitLab