From cc95a7516cb111f08914e712244258dacbbe8f20 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 6 May 2019 17:51:16 +0800 Subject: [PATCH] fix distribute fpn proposals, test=develop (#16152) * fix distribute fpn proposals, test=develop --- paddle/fluid/operators/detection/bbox_util.h | 26 +++++ .../detection/distribute_fpn_proposals_op.cc | 4 +- .../detection/distribute_fpn_proposals_op.cu | 103 +++++++----------- .../detection/distribute_fpn_proposals_op.h | 6 +- python/paddle/fluid/layers/detection.py | 2 +- .../test_distribute_fpn_proposals_op.py | 10 +- 6 files changed, 81 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/operators/detection/bbox_util.h b/paddle/fluid/operators/detection/bbox_util.h index a7bc3e0272..d4cf9a326c 100644 --- a/paddle/fluid/operators/detection/bbox_util.h +++ b/paddle/fluid/operators/detection/bbox_util.h @@ -15,11 +15,37 @@ limitations under the License. */ #pragma once #include #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor.h" namespace paddle { namespace operators { +struct RangeInitFunctor { + int start_; + int delta_; + int* out_; + HOSTDEVICE void operator()(size_t i) { out_[i] = start_ + i * delta_; } +}; + +template +inline HOSTDEVICE T RoIArea(const T* box, bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + /* * transform that computes target bounding-box regression deltas * given proposal boxes and ground-truth boxes. diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 6d36876efd..4cc989b632 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -40,14 +40,14 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { outs_dims.push_back(out_dim); } ctx->SetOutputsDim("MultiFpnRois", outs_dims); - ctx->SetOutputDim("RestoreIndex", {1, -1}); + ctx->SetOutputDim("RestoreIndex", {-1, 1}); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois")); - return framework::OpKernelType(data_type, platform::CPUPlace()); + return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu index 9cbb969158..598510870a 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu @@ -15,8 +15,10 @@ limitations under the License. */ #include #include "cub/cub.cuh" #include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h" #include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/for_range.h" @@ -26,7 +28,7 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumCUDAThreads = 64; static constexpr int kNumMaxinumNumBlocks = 4096; #define CUDA_1D_KERNEL_LOOP(i, n) \ @@ -35,47 +37,13 @@ static constexpr int kNumMaxinumNumBlocks = 4096; int const BBoxSize = 4; -struct RangeInitFunctor { - int start_; - int delta_; - int* out_; - __device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; } -}; - static inline int NumBlocks(const int N) { return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, kNumMaxinumNumBlocks); } -static inline void TransLoD(const int* length_lod, const int lod_size, - int* offset_lod) { - int offset = 0; - for (int i = 0; i < lod_size; ++i) { - offset_lod[i] = offset; - offset += length_lod[i]; - } -} - -template -static __device__ inline T RoIArea(const T* box, bool normalized) { - if (box[2] < box[0] || box[3] < box[1]) { - // If coordinate values are is invalid - // (e.g. xmax < xmin or ymax < ymin), return 0. - return static_cast(0.); - } else { - const T w = box[2] - box[0]; - const T h = box[3] - box[1]; - if (normalized) { - return w * h; - } else { - // If coordinate values are not within range [0, 1]. - return (w + 1) * (h + 1); - } - } -} - template -static __global__ void GPUDistFpnProposalsHelper( +__global__ void GPUDistFpnProposalsHelper( const int nthreads, const T* rois, const int lod_size, const int refer_level, const int refer_scale, const int max_level, const int min_level, int* roi_batch_id_data, int* sub_lod_list, @@ -86,12 +54,13 @@ static __global__ void GPUDistFpnProposalsHelper( // get the target level of current rois T roi_area = RoIArea(offset_roi, false); T roi_scale = sqrt(roi_area); - int tgt_lvl = floor(log2(roi_scale / refer_scale) + refer_level); + int tgt_lvl = floor( + log2(roi_scale / static_cast(refer_scale) + (T)1e-6) + refer_level); tgt_lvl = min(max_level, max(tgt_lvl, min_level)); target_lvls[i] = tgt_lvl; // compute number of rois in the same batch and same target level - platform::CudaAtomicAdd(sub_lod_list + tgt_lvl * lod_size + roi_batch_ind, - 1); + platform::CudaAtomicAdd( + sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1); } } @@ -138,18 +107,22 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel { Tensor sub_lod_list; sub_lod_list.Resize({num_level, lod_size}); int* sub_lod_list_data = sub_lod_list.mutable_data(dev_ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(dev_ctx, &sub_lod_list, static_cast(0)); + Tensor target_lvls; target_lvls.Resize({roi_num}); int* target_lvls_data = target_lvls.mutable_data(dev_ctx.GetPlace()); - int blocks = NumBlocks(roi_num); + int dist_blocks = NumBlocks(roi_num); int threads = kNumCUDAThreads; - // get target levels and sub_lod list - GPUDistFpnProposalsHelper<<>>( + GPUDistFpnProposalsHelper<<>>( roi_num, fpn_rois->data(), lod_size, refer_level, refer_scale, max_level, min_level, roi_batch_id_list_gpu.data(), sub_lod_list_data, target_lvls_data); + dev_ctx.Wait(); + auto place = boost::get(dev_ctx.GetPlace()); Tensor index_in_t; int* idx_in = index_in_t.mutable_data({roi_num}, dev_ctx.GetPlace()); @@ -163,46 +136,54 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel { // Determine temporary device storage requirements size_t temp_storage_bytes = 0; - cub::DeviceRadixSort::SortPairsDescending( - nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in, - idx_out, roi_num); + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, + target_lvls_data, keys_out, + idx_in, idx_out, roi_num); // Allocate temporary storage - auto place = boost::get(dev_ctx.GetPlace()); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes, memory::Allocator::kScratchpad); // Run sorting operation // sort target level to get corresponding index - cub::DeviceRadixSort::SortPairsDescending( + cub::DeviceRadixSort::SortPairs( d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out, idx_in, idx_out, roi_num); int* restore_idx_data = restore_index->mutable_data({roi_num, 1}, dev_ctx.GetPlace()); // sort current index to get restore index - cub::DeviceRadixSort::SortPairsDescending( + cub::DeviceRadixSort::SortPairs( d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in, restore_idx_data, roi_num); - Tensor offset_lod; - int* offset_lod_data = - offset_lod.mutable_data({lod_size + 1}, dev_ctx.GetPlace()); + int start = 0; for (int i = 0; i < num_level; ++i) { Tensor sub_lod = sub_lod_list.Slice(i, i + 1); int* sub_lod_data = sub_lod.data(); // transfer length-based lod to offset-based lod - TransLoD(sub_lod_data, lod_size + 1, offset_lod_data); - int sub_rois_num = offset_lod_data[lod_size]; - Tensor sub_idx = index_out_t.Slice(0, sub_rois_num); - - multi_fpn_rois[i]->mutable_data({sub_rois_num, kBoxDim}, - dev_ctx.GetPlace()); + std::vector offset(1, 0); + std::vector sub_lod_cpu(lod_size); + memory::Copy(platform::CPUPlace(), sub_lod_cpu.data(), place, + sub_lod_data, sizeof(int) * lod_size, dev_ctx.stream()); + dev_ctx.Wait(); + for (int j = 0; j < lod_size; ++j) { + offset.emplace_back(offset.back() + sub_lod_cpu[j]); + } - GPUGather(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]); + int sub_rois_num = offset.back(); + + int end = start + sub_rois_num; + if (end > start) { + Tensor sub_idx = index_out_t.Slice(start, end); + start = end; + multi_fpn_rois[i]->mutable_data({sub_rois_num, kBoxDim}, + dev_ctx.GetPlace()); + GPUGather(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]); + } else { + multi_fpn_rois[i]->mutable_data({sub_rois_num, kBoxDim}, + dev_ctx.GetPlace()); + } framework::LoD lod; - std::vector offset; - memory::Copy(platform::CPUPlace(), offset.data(), place, offset_lod_data, - sizeof(int) * (lod_size + 1), 0); lod.emplace_back(offset); multi_fpn_rois[i]->set_lod(lod); } diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h index f63e856626..a3196ea5f6 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h @@ -83,8 +83,8 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel { for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) { // get the target level of current rois T roi_scale = std::sqrt(BBoxArea(rois_data, false)); - int tgt_lvl = - std::floor(std::log2(roi_scale / refer_scale) + refer_level); + int tgt_lvl = std::floor(std::log2(roi_scale / refer_scale + (T)1e-6) + + refer_level); tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level)); target_level.push_back(tgt_lvl); num_rois_level[tgt_lvl - min_level]++; @@ -107,7 +107,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel { num_rois_level_integral[i + 1] = num_rois_level_integral[i] + num_rois_level[i]; } - restore_index->mutable_data({1, fpn_rois_num}, context.GetPlace()); + restore_index->mutable_data({fpn_rois_num, 1}, context.GetPlace()); int* restore_index_data = restore_index->data(); std::vector restore_index_inter(fpn_rois_num, -1); // distribute the rois into different fpn level by target level diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 8aeaf4e92a..ca0952ca1f 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -2383,7 +2383,7 @@ def distribute_fpn_proposals(fpn_rois, """ helper = LayerHelper('distribute_fpn_proposals', **locals()) - dtype = helper.input_dtype() + dtype = helper.input_dtype('fpn_rois') num_lvl = max_level - min_level + 1 multi_rois = [ helper.create_variable_for_type_inference(dtype) for i in range(num_lvl) diff --git a/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py b/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py index 1464060f59..55b21f1a72 100644 --- a/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py +++ b/python/paddle/fluid/tests/unittests/test_distribute_fpn_proposals_op.py @@ -37,7 +37,7 @@ class TestDistributeFPNProposalsOp(OpTest): for i in range(len(self.rois_fpn))] self.outputs = { 'MultiFpnRois': output, - 'RestoreIndex': self.rois_idx_restore + 'RestoreIndex': self.rois_idx_restore.reshape(-1, 1) } def init_test_case(self): @@ -63,10 +63,10 @@ class TestDistributeFPNProposalsOp(OpTest): return target_lvls def get_sub_lod(self, sub_lvl): - sub_lod = [] + sub_lod = [0, 0] max_batch_id = sub_lvl[-1] for i in range(max_batch_id.astype(np.int32) + 1): - sub_lod.append(np.where(sub_lvl == i)[0].size) + sub_lod[i] = np.where(sub_lvl == i)[0].size return sub_lod def add_multilevel_roi(self, rois, target_lvls, lvl_min, lvl_max): @@ -115,3 +115,7 @@ class TestDistributeFPNProposalsOp(OpTest): def test_check_output(self): self.check_output() + + +if __name__ == '__main__': + unittest.main() -- GitLab