From 8f8ed7de6e5599215eaa087e6f03cb758fb92632 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 23 Aug 2022 10:55:53 +0800 Subject: [PATCH] [Phi]Move distribute_fpn_proposals to PHI (#45212) * move distribute_fpn_proposals * fix some code * fix yaml bugs * add set dtype * move proposal_impl to funcs * fix compile bugs --- .../fluid/operators/detection/CMakeLists.txt | 5 +- .../detection/distribute_fpn_proposals_op.cc | 63 +---- .../detection/distribute_fpn_proposals_op.cu | 262 ------------------ .../detection/distribute_fpn_proposals_op.h | 191 ------------- paddle/phi/api/yaml/legacy_api.yaml | 10 + paddle/phi/core/kernel_context.h | 7 + paddle/phi/infermeta/binary.cc | 51 ++++ paddle/phi/infermeta/binary.h | 13 + .../cpu/distribute_fpn_proposals_kernel.cc | 145 ++++++++++ .../kernels/distribute_fpn_proposals_kernel.h | 35 +++ .../funcs/distribute_fpn_proposals_functor.h | 68 +++++ .../gpu/distribute_fpn_proposals_kernel.cu | 259 +++++++++++++++++ .../compat/distribute_fpn_proposals_sig.cc | 31 +++ .../test_distribute_fpn_proposals_op.py | 16 ++ python/paddle/vision/ops.py | 7 + 15 files changed, 656 insertions(+), 507 deletions(-) delete mode 100644 paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu delete mode 100644 paddle/fluid/operators/detection/distribute_fpn_proposals_op.h create mode 100644 paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc create mode 100644 paddle/phi/kernels/distribute_fpn_proposals_kernel.h create mode 100644 paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h create mode 100644 paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu create mode 100644 paddle/phi/ops/compat/distribute_fpn_proposals_sig.cc diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 578827f56cb..0f53224c13b 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -93,9 +93,8 @@ if(WITH_GPU OR WITH_ROCM) generate_proposals_op.cu DEPS ${TMPDEPS}) detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc DEPS ${TMPDEPS}) - detection_library( - distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc - distribute_fpn_proposals_op.cu DEPS ${TMPDEPS}) + detection_library(distribute_fpn_proposals_op SRCS + distribute_fpn_proposals_op.cc DEPS ${TMPDEPS}) detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS ${TMPDEPS}) else() diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index ec8d8a71008..91aeaf3df2f 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -23,52 +26,6 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("FpnRois"), - true, - platform::errors::NotFound("Input(FpnRois) of DistributeFpnProposalsOp" - " is not found")); - PADDLE_ENFORCE_GE(ctx->Outputs("MultiFpnRois").size(), - 1UL, - platform::errors::InvalidArgument( - "Outputs(MultiFpnRois) of " - "DistributeFpnProposalsOp should not be empty")); - size_t min_level = static_cast(ctx->Attrs().Get("min_level")); - size_t max_level = static_cast(ctx->Attrs().Get("max_level")); - PADDLE_ENFORCE_GE( - max_level, - min_level, - platform::errors::InvalidArgument( - "max_level must not lower than " - "min_level. But received max_level = %d, min_level = %d", - max_level, - min_level)); - // Set the output shape - size_t num_out_rois = max_level - min_level + 1; - std::vector outs_dims; - outs_dims.reserve(num_out_rois); - for (size_t i = 0; i < num_out_rois; ++i) { - framework::DDim out_dim = {-1, 4}; - outs_dims.push_back(out_dim); - } - ctx->SetOutputsDim("MultiFpnRois", outs_dims); - ctx->SetOutputDim("RestoreIndex", {-1, 1}); - - if (ctx->HasOutputs("MultiLevelRoIsNum")) { - std::vector outs_num_dims; - for (size_t i = 0; i < num_out_rois; ++i) { - outs_num_dims.push_back({-1}); - } - ctx->SetOutputsDim("MultiLevelRoIsNum", outs_num_dims); - } - if (!ctx->IsRuntime()) { - for (size_t i = 0; i < num_out_rois; ++i) { - ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i); - } - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -125,15 +82,19 @@ we return an array which indicate the original index of rois in } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR( + distribute_fpn_proposals, + DistributeFpnProposalsInferShapeFunctor, + PD_INFER_META(phi::DistributeFpnProposalsInferMeta)); + REGISTER_OPERATOR( distribute_fpn_proposals, ops::DistributeFpnProposalsOp, ops::DistributeFpnProposalsOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals, - ops::DistributeFpnProposalsOpKernel, - ops::DistributeFpnProposalsOpKernel); + paddle::framework::EmptyGradOpMaker, + DistributeFpnProposalsInferShapeFunctor); REGISTER_OP_VERSION(distribute_fpn_proposals) .AddCheckpoint( R"ROC( diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu deleted file mode 100644 index 1063382ef33..00000000000 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu +++ /dev/null @@ -1,262 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif - -#include - -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/detection/bbox_util.h" -#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/gather.cu.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -static constexpr int kNumCUDAThreads = 64; -static constexpr int kNumMaxinumNumBlocks = 4096; - -int const BBoxSize = 4; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaxinumNumBlocks); -} - -template -__global__ void GPUDistFpnProposalsHelper(const int nthreads, - const T* rois, - const int lod_size, - const int refer_level, - const int refer_scale, - const int max_level, - const int min_level, - int* roi_batch_id_data, - int* sub_lod_list, - int* target_lvls, - bool pixel_offset = true) { - CUDA_KERNEL_LOOP(i, nthreads) { - const T* offset_roi = rois + i * BBoxSize; - int roi_batch_ind = roi_batch_id_data[i]; - // get the target level of current rois - T roi_area = RoIArea(offset_roi, pixel_offset); - T roi_scale = sqrt(roi_area); - int tgt_lvl = floor( - log2(roi_scale / static_cast(refer_scale) + (T)1e-8) + refer_level); - tgt_lvl = min(max_level, max(tgt_lvl, min_level)); - target_lvls[i] = tgt_lvl; - // compute number of rois in the same batch and same target level - platform::CudaAtomicAdd( - sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1); - } -} - -template -class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* fpn_rois = ctx.Input("FpnRois"); - - auto multi_fpn_rois = ctx.MultiOutput("MultiFpnRois"); - auto* restore_index = ctx.Output("RestoreIndex"); - - const int min_level = ctx.Attr("min_level"); - const int max_level = ctx.Attr("max_level"); - const int refer_level = ctx.Attr("refer_level"); - const int refer_scale = ctx.Attr("refer_scale"); - const bool pixel_offset = ctx.Attr("pixel_offset"); - int num_level = max_level - min_level + 1; - - // check that the fpn_rois is not empty - if (!ctx.HasInput("RoisNum")) { - PADDLE_ENFORCE_EQ( - fpn_rois->lod().size(), - 1UL, - platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD" - "with one level")); - } - - std::vector fpn_rois_lod; - if (ctx.HasInput("RoisNum")) { - auto* rois_num = ctx.Input("RoisNum"); - fpn_rois_lod = GetLodFromRoisNum(rois_num); - } else { - fpn_rois_lod = fpn_rois->lod().back(); - } - int lod_size = fpn_rois_lod.size() - 1; - int roi_num = fpn_rois_lod[lod_size]; - - auto& dev_ctx = ctx.template device_context(); - - // get batch id by lod in CPU - Tensor roi_batch_id_list; - roi_batch_id_list.Resize({roi_num}); - int* roi_batch_id_data = - roi_batch_id_list.mutable_data(platform::CPUPlace()); - for (int n = 0; n < lod_size; ++n) { - for (size_t i = fpn_rois_lod[n]; i < fpn_rois_lod[n + 1]; ++i) { - roi_batch_id_data[i] = n; - } - } - // copy batch id list to GPU - Tensor roi_batch_id_list_gpu; - framework::TensorCopySync( - roi_batch_id_list, dev_ctx.GetPlace(), &roi_batch_id_list_gpu); - - Tensor sub_lod_list; - sub_lod_list.Resize({num_level, lod_size}); - int* sub_lod_list_data = sub_lod_list.mutable_data(dev_ctx.GetPlace()); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, &sub_lod_list, static_cast(0)); - - Tensor target_lvls; - target_lvls.Resize({roi_num}); - int* target_lvls_data = target_lvls.mutable_data(dev_ctx.GetPlace()); - - int dist_blocks = NumBlocks(roi_num); - int threads = kNumCUDAThreads; - // get target levels and sub_lod list - GPUDistFpnProposalsHelper<<>>( - roi_num, - fpn_rois->data(), - lod_size, - refer_level, - refer_scale, - max_level, - min_level, - roi_batch_id_list_gpu.data(), - sub_lod_list_data, - target_lvls_data, - pixel_offset); - auto place = dev_ctx.GetPlace(); - - Tensor index_in_t; - int* idx_in = index_in_t.mutable_data({roi_num}, dev_ctx.GetPlace()); - platform::ForRange for_range(dev_ctx, roi_num); - for_range(RangeInitFunctor{0, 1, idx_in}); - - Tensor keys_out_t; - int* keys_out = keys_out_t.mutable_data({roi_num}, dev_ctx.GetPlace()); - Tensor index_out_t; - int* idx_out = index_out_t.mutable_data({roi_num}, dev_ctx.GetPlace()); - - // Determine temporary device storage requirements - size_t temp_storage_bytes = 0; - cub::DeviceRadixSort::SortPairs(nullptr, - temp_storage_bytes, - target_lvls_data, - keys_out, - idx_in, - idx_out, - roi_num, - 0, - sizeof(int) * 8, - dev_ctx.stream()); - // Allocate temporary storage - auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - - // Run sorting operation - // sort target level to get corresponding index - cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), - temp_storage_bytes, - target_lvls_data, - keys_out, - idx_in, - idx_out, - roi_num, - 0, - sizeof(int) * 8, - dev_ctx.stream()); - - int* restore_idx_data = - restore_index->mutable_data({roi_num, 1}, dev_ctx.GetPlace()); - // sort current index to get restore index - cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), - temp_storage_bytes, - idx_out, - keys_out, - idx_in, - restore_idx_data, - roi_num, - 0, - sizeof(int) * 8, - dev_ctx.stream()); - - int start = 0; - auto multi_rois_num = ctx.MultiOutput("MultiLevelRoIsNum"); - - std::vector sub_lod_list_cpu(lod_size * num_level); - memory::Copy(platform::CPUPlace(), - sub_lod_list_cpu.data(), - place, - sub_lod_list_data, - sizeof(int) * lod_size * num_level, - dev_ctx.stream()); - dev_ctx.Wait(); - - for (int i = 0; i < num_level; ++i) { - Tensor sub_lod = sub_lod_list.Slice(i, i + 1); - // transfer length-based lod to offset-based lod - std::vector offset(1, 0); - for (int j = 0; j < lod_size; ++j) { - offset.emplace_back(offset.back() + sub_lod_list_cpu[i * lod_size + j]); - } - - int sub_rois_num = offset.back(); - - int end = start + sub_rois_num; - if (end > start) { - Tensor sub_idx = index_out_t.Slice(start, end); - start = end; - multi_fpn_rois[i]->mutable_data({sub_rois_num, kBoxDim}, - dev_ctx.GetPlace()); - phi::funcs::GPUGather( - dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]); - } else { - multi_fpn_rois[i]->mutable_data({sub_rois_num, kBoxDim}, - dev_ctx.GetPlace()); - } - if (multi_rois_num.size() > 0) { - Tensor* rois_num_t = multi_rois_num[i]; - paddle::framework::TensorCopySync( - sub_lod, dev_ctx.GetPlace(), rois_num_t); - rois_num_t->Resize({lod_size}); - } - framework::LoD lod; - lod.emplace_back(offset); - multi_fpn_rois[i]->set_lod(lod); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - distribute_fpn_proposals, - ops::GPUDistributeFpnProposalsOpKernel, - ops::GPUDistributeFpnProposalsOpKernel); diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h deleted file mode 100644 index afe4a54d6d7..00000000000 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -const int kBoxDim = 4; - -inline std::vector GetLodFromRoisNum( - const framework::Tensor* rois_num) { - std::vector rois_lod; - auto* rois_num_data = rois_num->data(); - framework::Tensor cpu_tensor; - if (platform::is_gpu_place(rois_num->place())) { - paddle::framework::TensorCopySync( - *rois_num, platform::CPUPlace(), &cpu_tensor); - rois_num_data = cpu_tensor.data(); - } - rois_lod.push_back(static_cast(0)); - for (int i = 0; i < rois_num->numel(); ++i) { - rois_lod.push_back(rois_lod.back() + static_cast(rois_num_data[i])); - } - return rois_lod; -} - -template -static inline T BBoxArea(const T* box, bool pixel_offset) { - if (box[2] < box[0] || box[3] < box[1]) { - // If coordinate values are is invalid - // (e.g. xmax < xmin or ymax < ymin), return 0. - return static_cast(0.); - } else { - const T w = box[2] - box[0]; - const T h = box[3] - box[1]; - if (pixel_offset) { - // If coordinate values are not within range [0, 1]. - return (w + 1) * (h + 1); - } else { - return w * h; - } - } -} - -template -class DistributeFpnProposalsOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* fpn_rois = context.Input("FpnRois"); - - auto multi_fpn_rois = - context.MultiOutput("MultiFpnRois"); - - auto* restore_index = - context.Output("RestoreIndex"); - - const int min_level = context.Attr("min_level"); - const int max_level = context.Attr("max_level"); - const int refer_level = context.Attr("refer_level"); - const int refer_scale = context.Attr("refer_scale"); - const bool pixel_offset = context.Attr("pixel_offset"); - const int num_level = max_level - min_level + 1; - - // check that the fpn_rois is not empty - if (!context.HasInput("RoisNum")) { - PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), - 1UL, - platform::errors::InvalidArgument( - "DistributeFpnProposalsOp needs LoD " - "with one level. But received level is %d", - fpn_rois->lod().size())); - } - - std::vector fpn_rois_lod; - int fpn_rois_num; - if (context.HasInput("RoisNum")) { - auto* rois_num = context.Input("RoisNum"); - fpn_rois_lod = GetLodFromRoisNum(rois_num); - } else { - fpn_rois_lod = fpn_rois->lod().back(); - } - fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1]; - std::vector target_level; - // std::vector target_level(fpn_rois_num, -1); - // record the number of rois in each level - std::vector num_rois_level(num_level, 0); - std::vector num_rois_level_integral(num_level + 1, 0); - for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { - auto fpn_rois_slice = - fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]); - const T* rois_data = fpn_rois_slice.data(); - for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { - // get the target level of current rois - T roi_scale = std::sqrt(BBoxArea(rois_data, pixel_offset)); - int tgt_lvl = std::floor(std::log2(roi_scale / refer_scale + (T)1e-6) + - refer_level); - tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level)); - target_level.push_back(tgt_lvl); - num_rois_level[tgt_lvl - min_level]++; - rois_data += kBoxDim; - } - } - // define the output rois - // pointer which point to each level fpn rois - std::vector multi_fpn_rois_data(num_level); - // lod0 which will record the offset information of each level rois - std::vector> multi_fpn_rois_lod0; - for (int i = 0; i < num_level; ++i) { - // allocate memory for each level rois - multi_fpn_rois[i]->mutable_data({num_rois_level[i], kBoxDim}, - context.GetPlace()); - multi_fpn_rois_data[i] = multi_fpn_rois[i]->data(); - std::vector lod0(1, 0); - multi_fpn_rois_lod0.push_back(lod0); - // statistic start point for each level rois - num_rois_level_integral[i + 1] = - num_rois_level_integral[i] + num_rois_level[i]; - } - restore_index->mutable_data({fpn_rois_num, 1}, context.GetPlace()); - int* restore_index_data = restore_index->data(); - std::vector restore_index_inter(fpn_rois_num, -1); - // distribute the rois into different fpn level by target level - for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { - auto fpn_rois_slice = - fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]); - const T* rois_data = fpn_rois_slice.data(); - size_t cur_offset = fpn_rois_lod[i]; - // std::vector lod_offset[num_level]; - for (int j = 0; j < num_level; j++) { - multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]); - } - for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { - int lvl = target_level[cur_offset + j]; - memcpy(multi_fpn_rois_data[lvl - min_level], - rois_data, - kBoxDim * sizeof(T)); - multi_fpn_rois_data[lvl - min_level] += kBoxDim; - int index_in_shuffle = num_rois_level_integral[lvl - min_level] + - multi_fpn_rois_lod0[lvl - min_level][i + 1]; - restore_index_inter[index_in_shuffle] = cur_offset + j; - multi_fpn_rois_lod0[lvl - min_level][i + 1]++; - rois_data += kBoxDim; - } - } - for (int i = 0; i < fpn_rois_num; ++i) { - restore_index_data[restore_index_inter[i]] = i; - } - auto multi_rois_num = - context.MultiOutput("MultiLevelRoIsNum"); - if (multi_rois_num.size() > 0) { - int batch_size = fpn_rois_lod.size() - 1; - for (int i = 0; i < num_level; ++i) { - int* rois_num_data = multi_rois_num[i]->mutable_data( - {batch_size}, context.GetPlace()); - for (int j = 0; j < batch_size; ++j) { - rois_num_data[j] = static_cast(multi_fpn_rois_lod0[i][j + 1] - - multi_fpn_rois_lod0[i][j]); - } - } - } - // merge lod information into LoDTensor - for (int i = 0; i < num_level; ++i) { - framework::LoD lod; - lod.emplace_back(multi_fpn_rois_lod0[i]); - multi_fpn_rois[i]->set_lod(lod); - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 17a0d99a3f9..a2bbb28a34c 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -704,6 +704,16 @@ kernel : func : diag_embed +- api : distribute_fpn_proposals + args : (Tensor fpn_rois, Tensor rois_num, int min_level, int max_level, int refer_level, int refer_scale, bool pixel_offset) + output : Tensor[](multi_fpn_rois){max_level - min_level + 1}, Tensor[](multi_level_rois_num){max_level - min_level + 1}, Tensor(restore_index) + infer_meta : + func : DistributeFpnProposalsInferMeta + kernel : + func : distribute_fpn_proposals + data_type : fpn_rois + optional : rois_num + - api : divide args : (Tensor x, Tensor y) output : Tensor diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index 0f155f445ec..830443fca8f 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -122,8 +122,15 @@ class KernelContext { template std::vector MutableOutputBetween(size_t start, size_t end) { std::vector v; + bool is_empty_vector = true; for (size_t i = start; i < end; ++i) { v.emplace_back(static_cast(outputs_.at(i))); + if (outputs_.at(i) != nullptr) { + is_empty_vector = false; + } + } + if (is_empty_vector) { + v.clear(); } return v; } diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 909eee908b5..7631457dc4a 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -951,6 +951,57 @@ void DistInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void DistributeFpnProposalsInferMeta( + const MetaTensor& fpn_rois, + const MetaTensor& rois_num, + int min_level, + int max_level, + int refer_level, + int refer_scale, + bool pixel_offset, + std::vector multi_fpn_rois, + std::vector multi_level_rois_num, + MetaTensor* restore_index, + MetaConfig config) { + PADDLE_ENFORCE_GE( + multi_fpn_rois.size(), + 1UL, + errors::InvalidArgument("Outputs(MultiFpnRois) of " + "DistributeFpnProposalsOp should not be empty")); + PADDLE_ENFORCE_GE( + max_level, + min_level, + errors::InvalidArgument( + "max_level must not lower than " + "min_level. But received max_level = %d, min_level = %d", + max_level, + min_level)); + // Set the output shape + for (size_t i = 0; i < multi_fpn_rois.size(); ++i) { + DDim out_dim = {-1, 4}; + if (multi_fpn_rois[i] == nullptr) { + continue; + } + multi_fpn_rois[i]->set_dims(out_dim); + multi_fpn_rois[i]->set_dtype(fpn_rois.dtype()); + } + restore_index->set_dims({-1, 1}); + restore_index->set_dtype(DataType::INT32); + for (size_t i = 0; i < multi_level_rois_num.size(); ++i) { + if (multi_level_rois_num[i] == nullptr) { + continue; + } + multi_level_rois_num[i]->set_dims({-1}); + multi_level_rois_num[i]->set_dtype(DataType::INT32); + } + + if (!config.is_runtime) { + for (size_t i = 0; i < multi_fpn_rois.size(); ++i) { + multi_fpn_rois[i]->share_lod(fpn_rois); + } + } +} + void DropoutInferMeta(const MetaTensor& x, const MetaTensor& seed_tensor, const Scalar& p, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 7dcbe33e0a9..d28d15f0829 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -139,6 +139,19 @@ void DistInferMeta(const MetaTensor& x, float p, MetaTensor* out); +void DistributeFpnProposalsInferMeta( + const MetaTensor& fpn_rois, + const MetaTensor& rois_num, + int min_level, + int max_level, + int refer_level, + int refer_scale, + bool pixel_offset, + std::vector multi_fpn_rois, + std::vector multi_level_rois_num, + MetaTensor* restore_index, + MetaConfig config = MetaConfig()); + void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); void DropoutInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc b/paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc new file mode 100644 index 00000000000..8caf8b07a7a --- /dev/null +++ b/paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/distribute_fpn_proposals_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h" + +namespace phi { + +template +void DistributeFpnProposalsKernel( + const Context& dev_ctx, + const DenseTensor& fpn_rois, + const paddle::optional& rois_num, + int min_level, + int max_level, + int refer_level, + int refer_scale, + bool pixel_offset, + std::vector multi_fpn_rois, + std::vector multi_level_rois_num, + DenseTensor* restore_index) { + const int num_level = max_level - min_level + 1; + + // check that the fpn_rois is not empty + if (!rois_num.get_ptr()) { + PADDLE_ENFORCE_EQ( + fpn_rois.lod().size(), + 1UL, + errors::InvalidArgument("DistributeFpnProposalsOp needs LoD " + "with one level. But received level is %d", + fpn_rois.lod().size())); + } + + std::vector fpn_rois_lod; + int fpn_rois_num; + if (rois_num.get_ptr()) { + fpn_rois_lod = funcs::GetLodFromRoisNum(dev_ctx, rois_num.get_ptr()); + } else { + fpn_rois_lod = fpn_rois.lod().back(); + } + fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1]; + std::vector target_level; + + // record the number of rois in each level + std::vector num_rois_level(num_level, 0); + std::vector num_rois_level_integral(num_level + 1, 0); + for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { + auto fpn_rois_slice = fpn_rois.Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]); + const T* rois_data = fpn_rois_slice.data(); + for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { + // get the target level of current rois + T roi_scale = std::sqrt(funcs::BBoxArea(rois_data, pixel_offset)); + int tgt_lvl = std::floor(std::log2(roi_scale / refer_scale + (T)1e-6) + + refer_level); + tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level)); + target_level.push_back(tgt_lvl); + num_rois_level[tgt_lvl - min_level]++; + rois_data += funcs::kBoxDim; + } + } + // define the output rois + // pointer which point to each level fpn rois + std::vector multi_fpn_rois_data(num_level); + // lod0 which will record the offset information of each level rois + std::vector> multi_fpn_rois_lod0; + for (int i = 0; i < num_level; ++i) { + // allocate memory for each level rois + multi_fpn_rois[i]->Resize({num_rois_level[i], funcs::kBoxDim}); + multi_fpn_rois_data[i] = dev_ctx.template Alloc(multi_fpn_rois[i]); + std::vector lod0(1, 0); + multi_fpn_rois_lod0.push_back(lod0); + // statistic start point for each level rois + num_rois_level_integral[i + 1] = + num_rois_level_integral[i] + num_rois_level[i]; + } + restore_index->Resize({fpn_rois_num, 1}); + int* restore_index_data = dev_ctx.template Alloc(restore_index); + std::vector restore_index_inter(fpn_rois_num, -1); + // distribute the rois into different fpn level by target level + for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) { + auto fpn_rois_slice = fpn_rois.Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]); + const T* rois_data = fpn_rois_slice.data(); + size_t cur_offset = fpn_rois_lod[i]; + + for (int j = 0; j < num_level; j++) { + multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]); + } + for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { + int lvl = target_level[cur_offset + j]; + memcpy(multi_fpn_rois_data[lvl - min_level], + rois_data, + funcs::kBoxDim * sizeof(T)); + multi_fpn_rois_data[lvl - min_level] += funcs::kBoxDim; + int index_in_shuffle = num_rois_level_integral[lvl - min_level] + + multi_fpn_rois_lod0[lvl - min_level][i + 1]; + restore_index_inter[index_in_shuffle] = cur_offset + j; + multi_fpn_rois_lod0[lvl - min_level][i + 1]++; + rois_data += funcs::kBoxDim; + } + } + for (int i = 0; i < fpn_rois_num; ++i) { + restore_index_data[restore_index_inter[i]] = i; + } + + if (multi_level_rois_num.size() > 0) { + int batch_size = fpn_rois_lod.size() - 1; + for (int i = 0; i < num_level; ++i) { + multi_level_rois_num[i]->Resize({batch_size}); + int* rois_num_data = dev_ctx.template Alloc(multi_level_rois_num[i]); + for (int j = 0; j < batch_size; ++j) { + rois_num_data[j] = static_cast(multi_fpn_rois_lod0[i][j + 1] - + multi_fpn_rois_lod0[i][j]); + } + } + } + // merge lod information into LoDTensor + for (int i = 0; i < num_level; ++i) { + LoD lod; + lod.emplace_back(multi_fpn_rois_lod0[i]); + multi_fpn_rois[i]->set_lod(lod); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(distribute_fpn_proposals, + CPU, + ALL_LAYOUT, + phi::DistributeFpnProposalsKernel, + float, + double) {} diff --git a/paddle/phi/kernels/distribute_fpn_proposals_kernel.h b/paddle/phi/kernels/distribute_fpn_proposals_kernel.h new file mode 100644 index 00000000000..9a7bb2f6e0e --- /dev/null +++ b/paddle/phi/kernels/distribute_fpn_proposals_kernel.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DistributeFpnProposalsKernel( + const Context& ctx, + const DenseTensor& fpn_rois, + const paddle::optional& rois_num, + int min_level, + int max_level, + int refer_level, + int refer_scale, + bool pixel_offset, + std::vector multi_fpn_rois, + std::vector multi_level_rois_num, + DenseTensor* restore_index); + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h b/paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h new file mode 100644 index 00000000000..a1024f04c8e --- /dev/null +++ b/paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h @@ -0,0 +1,68 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +namespace funcs { + +const int kBoxDim = 4; + +template +inline std::vector GetLodFromRoisNum(const Context& dev_ctx, + const DenseTensor* rois_num) { + std::vector rois_lod; + auto* rois_num_data = rois_num->data(); + DenseTensor cpu_tensor; + if (paddle::platform::is_gpu_place(rois_num->place())) { + Copy(dev_ctx, *rois_num, phi::CPUPlace(), true, &cpu_tensor); + rois_num_data = cpu_tensor.data(); + } + rois_lod.push_back(static_cast(0)); + for (int i = 0; i < rois_num->numel(); ++i) { + rois_lod.push_back(rois_lod.back() + static_cast(rois_num_data[i])); + } + return rois_lod; +} + +template +static inline T BBoxArea(const T* box, bool pixel_offset) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (pixel_offset) { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } else { + return w * h; + } + } +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu b/paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu new file mode 100644 index 00000000000..130dc99ab17 --- /dev/null +++ b/paddle/phi/kernels/gpu/distribute_fpn_proposals_kernel.cu @@ -0,0 +1,259 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "paddle/phi/kernels/distribute_fpn_proposals_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/gather.cu.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/detection/bbox_util.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 64; +static constexpr int kNumMaxinumNumBlocks = 4096; + +int const BBoxSize = 4; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void GPUDistFpnProposalsHelper(const int nthreads, + const T* rois, + const int lod_size, + const int refer_level, + const int refer_scale, + const int max_level, + const int min_level, + int* roi_batch_id_data, + int* sub_lod_list, + int* target_lvls, + bool pixel_offset = true) { + CUDA_KERNEL_LOOP(i, nthreads) { + const T* offset_roi = rois + i * BBoxSize; + int roi_batch_ind = roi_batch_id_data[i]; + // get the target level of current rois + T roi_area = paddle::operators::RoIArea(offset_roi, pixel_offset); + T roi_scale = sqrt(roi_area); + int tgt_lvl = floor( + log2(roi_scale / static_cast(refer_scale) + (T)1e-8) + refer_level); + tgt_lvl = min(max_level, max(tgt_lvl, min_level)); + target_lvls[i] = tgt_lvl; + // compute number of rois in the same batch and same target level + paddle::platform::CudaAtomicAdd( + sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1); + } +} + +template +void DistributeFpnProposalsKernel( + const Context& dev_ctx, + const DenseTensor& fpn_rois, + const paddle::optional& rois_num, + int min_level, + int max_level, + int refer_level, + int refer_scale, + bool pixel_offset, + std::vector multi_fpn_rois, + std::vector multi_level_rois_num, + DenseTensor* restore_index) { + int num_level = max_level - min_level + 1; + + // check that the fpn_rois is not empty + if (!rois_num.get_ptr()) { + PADDLE_ENFORCE_EQ( + fpn_rois.lod().size(), + 1UL, + errors::InvalidArgument("DistributeFpnProposalsOp needs LoD" + "with one level")); + } + + std::vector fpn_rois_lod; + if (rois_num.get_ptr()) { + fpn_rois_lod = funcs::GetLodFromRoisNum(dev_ctx, rois_num.get_ptr()); + } else { + fpn_rois_lod = fpn_rois.lod().back(); + } + int lod_size = fpn_rois_lod.size() - 1; + int roi_num = fpn_rois_lod[lod_size]; + + // get batch id by lod in CPU + DenseTensor roi_batch_id_list; + roi_batch_id_list.Resize({roi_num}); + int* roi_batch_id_data = dev_ctx.template HostAlloc(&roi_batch_id_list); + for (int n = 0; n < lod_size; ++n) { + for (size_t i = fpn_rois_lod[n]; i < fpn_rois_lod[n + 1]; ++i) { + roi_batch_id_data[i] = n; + } + } + // copy batch id list to GPU + DenseTensor roi_batch_id_list_gpu; + Copy(dev_ctx, + roi_batch_id_list, + dev_ctx.GetPlace(), + true, + &roi_batch_id_list_gpu); + + DenseTensor sub_lod_list; + sub_lod_list.Resize({num_level, lod_size}); + int* sub_lod_list_data = dev_ctx.template Alloc(&sub_lod_list); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, &sub_lod_list, static_cast(0)); + + DenseTensor target_lvls; + target_lvls.Resize({roi_num}); + int* target_lvls_data = dev_ctx.template Alloc(&target_lvls); + + int dist_blocks = NumBlocks(roi_num); + int threads = kNumCUDAThreads; + // get target levels and sub_lod list + GPUDistFpnProposalsHelper<<>>( + roi_num, + fpn_rois.data(), + lod_size, + refer_level, + refer_scale, + max_level, + min_level, + roi_batch_id_list_gpu.data(), + sub_lod_list_data, + target_lvls_data, + pixel_offset); + auto place = dev_ctx.GetPlace(); + + DenseTensor index_in_t; + index_in_t.Resize({roi_num}); + int* idx_in = dev_ctx.template Alloc(&index_in_t); + funcs::ForRange for_range(dev_ctx, roi_num); + for_range(paddle::operators::RangeInitFunctor{0, 1, idx_in}); + + DenseTensor keys_out_t; + keys_out_t.Resize({roi_num}); + int* keys_out = dev_ctx.template Alloc(&keys_out_t); + DenseTensor index_out_t; + index_out_t.Resize({roi_num}); + int* idx_out = dev_ctx.template Alloc(&index_out_t); + + // Determine temporary device storage requirements + size_t temp_storage_bytes = 0; + cub::DeviceRadixSort::SortPairs(nullptr, + temp_storage_bytes, + target_lvls_data, + keys_out, + idx_in, + idx_out, + roi_num, + 0, + sizeof(int) * 8, + dev_ctx.stream()); + // Allocate temporary storage + auto d_temp_storage = paddle::memory::Alloc(place, temp_storage_bytes); + + // Run sorting operation + // sort target level to get corresponding index + cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), + temp_storage_bytes, + target_lvls_data, + keys_out, + idx_in, + idx_out, + roi_num, + 0, + sizeof(int) * 8, + dev_ctx.stream()); + + restore_index->Resize({roi_num, 1}); + int* restore_idx_data = dev_ctx.template Alloc(restore_index); + // sort current index to get restore index + cub::DeviceRadixSort::SortPairs(d_temp_storage->ptr(), + temp_storage_bytes, + idx_out, + keys_out, + idx_in, + restore_idx_data, + roi_num, + 0, + sizeof(int) * 8, + dev_ctx.stream()); + + int start = 0; + + std::vector sub_lod_list_cpu(lod_size * num_level); + paddle::memory::Copy(phi::CPUPlace(), + sub_lod_list_cpu.data(), + place, + sub_lod_list_data, + sizeof(int) * lod_size * num_level, + dev_ctx.stream()); + dev_ctx.Wait(); + + for (int i = 0; i < num_level; ++i) { + DenseTensor sub_lod = sub_lod_list.Slice(i, i + 1); + // transfer length-based lod to offset-based lod + std::vector offset(1, 0); + for (int j = 0; j < lod_size; ++j) { + offset.emplace_back(offset.back() + sub_lod_list_cpu[i * lod_size + j]); + } + + int sub_rois_num = offset.back(); + + int end = start + sub_rois_num; + if (end > start) { + DenseTensor sub_idx = index_out_t.Slice(start, end); + start = end; + multi_fpn_rois[i]->Resize({sub_rois_num, funcs::kBoxDim}); + dev_ctx.template Alloc(multi_fpn_rois[i]); + phi::funcs::GPUGather(dev_ctx, fpn_rois, sub_idx, multi_fpn_rois[i]); + } else { + multi_fpn_rois[i]->Resize({sub_rois_num, funcs::kBoxDim}); + dev_ctx.template Alloc(multi_fpn_rois[i]); + } + if (multi_level_rois_num.size() > 0) { + DenseTensor* rois_num_t = multi_level_rois_num[i]; + Copy(dev_ctx, sub_lod, dev_ctx.GetPlace(), true, rois_num_t); + rois_num_t->Resize({lod_size}); + } + LoD lod; + lod.emplace_back(offset); + multi_fpn_rois[i]->set_lod(lod); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(distribute_fpn_proposals, + GPU, + ALL_LAYOUT, + phi::DistributeFpnProposalsKernel, + float, + double) {} diff --git a/paddle/phi/ops/compat/distribute_fpn_proposals_sig.cc b/paddle/phi/ops/compat/distribute_fpn_proposals_sig.cc new file mode 100644 index 00000000000..ad02fb1aa09 --- /dev/null +++ b/paddle/phi/ops/compat/distribute_fpn_proposals_sig.cc @@ -0,0 +1,31 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature DistributeFpnProposalsOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "distribute_fpn_proposals", + {"FpnRois", "RoisNum"}, + {"min_level", "max_level", "refer_level", "refer_scale", "pixel_offset"}, + {"MultiFpnRois", "MultiLevelRoIsNum", "RestoreIndex"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(distribute_fpn_proposals, + phi::DistributeFpnProposalsOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py b/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py index 24d15769492..e99d04526a3 100644 --- a/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py +++ b/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py @@ -23,6 +23,16 @@ import paddle from op_test import OpTest +def distribute_fpn_proposals_wrapper(fpn_rois, rois_num, min_level, max_level, + refer_level, refer_scale, pixel_offset): + return paddle.vision.ops.distribute_fpn_proposals(fpn_rois=fpn_rois, + min_level=min_level, + max_level=max_level, + refer_level=refer_level, + refer_scale=refer_scale, + rois_num=rois_num) + + class TestDistributeFPNProposalsOp(OpTest): def set_data(self): @@ -44,6 +54,8 @@ class TestDistributeFPNProposalsOp(OpTest): 'MultiFpnRois': output, 'RestoreIndex': self.rois_idx_restore.reshape(-1, 1), } + self.python_api = distribute_fpn_proposals_wrapper + self.python_out_sig = ['MultiFpnRois', 'RestoreIndex'] def init_test_case(self): self.roi_max_level = 5 @@ -152,6 +164,10 @@ class TestDistributeFPNProposalsOpWithRoisNum(TestDistributeFPNProposalsOp): 'RestoreIndex': self.rois_idx_restore.reshape(-1, 1), 'MultiLevelRoIsNum': rois_num_per_level } + self.python_api = distribute_fpn_proposals_wrapper + self.python_out_sig = [ + 'MultiFpnRois', 'MultiLevelRoIsNum', 'RestoreIndex' + ] class TestDistributeFPNProposalsOpNoOffset( diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index cf038d18ae3..acd896e71e8 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -899,6 +899,13 @@ def distribute_fpn_proposals(fpn_rois, """ num_lvl = max_level - min_level + 1 + if in_dygraph_mode(): + assert rois_num is not None, "rois_num should not be None in dygraph mode." + multi_rois, rois_num_per_level, restore_ind = _C_ops.final_state_distribute_fpn_proposals( + fpn_rois, rois_num, min_level, max_level, refer_level, refer_scale, + pixel_offset) + return multi_rois, restore_ind, rois_num_per_level + if _non_static_mode(): assert rois_num is not None, "rois_num should not be None in dygraph mode." attrs = ('min_level', min_level, 'max_level', max_level, 'refer_level', -- GitLab