未验证 提交 cc95a751 编写于 作者: J jerrywgz 提交者: GitHub

fix distribute fpn proposals, test=develop (#16152)

* fix distribute fpn proposals, test=develop
上级 9ec4615d
......@@ -15,11 +15,37 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
struct RangeInitFunctor {
int start_;
int delta_;
int* out_;
HOSTDEVICE void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
template <typename T>
inline HOSTDEVICE T RoIArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
/*
* transform that computes target bounding-box regression deltas
* given proposal boxes and ground-truth boxes.
......
......@@ -40,14 +40,14 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
outs_dims.push_back(out_dim);
}
ctx->SetOutputsDim("MultiFpnRois", outs_dims);
ctx->SetOutputDim("RestoreIndex", {1, -1});
ctx->SetOutputDim("RestoreIndex", {-1, 1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois"));
return framework::OpKernelType(data_type, platform::CPUPlace());
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -15,8 +15,10 @@ limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -26,7 +28,7 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumCUDAThreads = 64;
static constexpr int kNumMaxinumNumBlocks = 4096;
#define CUDA_1D_KERNEL_LOOP(i, n) \
......@@ -35,47 +37,13 @@ static constexpr int kNumMaxinumNumBlocks = 4096;
int const BBoxSize = 4;
struct RangeInitFunctor {
int start_;
int delta_;
int* out_;
__device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
static inline void TransLoD(const int* length_lod, const int lod_size,
int* offset_lod) {
int offset = 0;
for (int i = 0; i < lod_size; ++i) {
offset_lod[i] = offset;
offset += length_lod[i];
}
}
template <typename T>
static __device__ inline T RoIArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static __global__ void GPUDistFpnProposalsHelper(
__global__ void GPUDistFpnProposalsHelper(
const int nthreads, const T* rois, const int lod_size,
const int refer_level, const int refer_scale, const int max_level,
const int min_level, int* roi_batch_id_data, int* sub_lod_list,
......@@ -86,12 +54,13 @@ static __global__ void GPUDistFpnProposalsHelper(
// get the target level of current rois
T roi_area = RoIArea(offset_roi, false);
T roi_scale = sqrt(roi_area);
int tgt_lvl = floor(log2(roi_scale / refer_scale) + refer_level);
int tgt_lvl = floor(
log2(roi_scale / static_cast<T>(refer_scale) + (T)1e-6) + refer_level);
tgt_lvl = min(max_level, max(tgt_lvl, min_level));
target_lvls[i] = tgt_lvl;
// compute number of rois in the same batch and same target level
platform::CudaAtomicAdd(sub_lod_list + tgt_lvl * lod_size + roi_batch_ind,
1);
platform::CudaAtomicAdd(
sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1);
}
}
......@@ -138,18 +107,22 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
Tensor sub_lod_list;
sub_lod_list.Resize({num_level, lod_size});
int* sub_lod_list_data = sub_lod_list.mutable_data<int>(dev_ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, int> set_zero;
set_zero(dev_ctx, &sub_lod_list, static_cast<int>(0));
Tensor target_lvls;
target_lvls.Resize({roi_num});
int* target_lvls_data = target_lvls.mutable_data<int>(dev_ctx.GetPlace());
int blocks = NumBlocks(roi_num);
int dist_blocks = NumBlocks(roi_num);
int threads = kNumCUDAThreads;
// get target levels and sub_lod list
GPUDistFpnProposalsHelper<T><<<blocks, threads>>>(
GPUDistFpnProposalsHelper<T><<<dist_blocks, threads>>>(
roi_num, fpn_rois->data<T>(), lod_size, refer_level, refer_scale,
max_level, min_level, roi_batch_id_list_gpu.data<int>(),
sub_lod_list_data, target_lvls_data);
dev_ctx.Wait();
auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
Tensor index_in_t;
int* idx_in = index_in_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
......@@ -163,46 +136,54 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<int, int>(
nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in,
idx_out, roi_num);
cub::DeviceRadixSort::SortPairs<int, int>(nullptr, temp_storage_bytes,
target_lvls_data, keys_out,
idx_in, idx_out, roi_num);
// Allocate temporary storage
auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
memory::Allocator::kScratchpad);
// Run sorting operation
// sort target level to get corresponding index
cub::DeviceRadixSort::SortPairsDescending<int, int>(
cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
idx_in, idx_out, roi_num);
int* restore_idx_data =
restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace());
// sort current index to get restore index
cub::DeviceRadixSort::SortPairsDescending<int, int>(
cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
restore_idx_data, roi_num);
Tensor offset_lod;
int* offset_lod_data =
offset_lod.mutable_data<int>({lod_size + 1}, dev_ctx.GetPlace());
int start = 0;
for (int i = 0; i < num_level; ++i) {
Tensor sub_lod = sub_lod_list.Slice(i, i + 1);
int* sub_lod_data = sub_lod.data<int>();
// transfer length-based lod to offset-based lod
TransLoD(sub_lod_data, lod_size + 1, offset_lod_data);
int sub_rois_num = offset_lod_data[lod_size];
Tensor sub_idx = index_out_t.Slice(0, sub_rois_num);
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace());
std::vector<size_t> offset(1, 0);
std::vector<int> sub_lod_cpu(lod_size);
memory::Copy(platform::CPUPlace(), sub_lod_cpu.data(), place,
sub_lod_data, sizeof(int) * lod_size, dev_ctx.stream());
dev_ctx.Wait();
for (int j = 0; j < lod_size; ++j) {
offset.emplace_back(offset.back() + sub_lod_cpu[j]);
}
GPUGather<T>(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]);
int sub_rois_num = offset.back();
int end = start + sub_rois_num;
if (end > start) {
Tensor sub_idx = index_out_t.Slice(start, end);
start = end;
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace());
GPUGather<T>(dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]);
} else {
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace());
}
framework::LoD lod;
std::vector<size_t> offset;
memory::Copy(platform::CPUPlace(), offset.data(), place, offset_lod_data,
sizeof(int) * (lod_size + 1), 0);
lod.emplace_back(offset);
multi_fpn_rois[i]->set_lod(lod);
}
......
......@@ -83,8 +83,8 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
// get the target level of current rois
T roi_scale = std::sqrt(BBoxArea(rois_data, false));
int tgt_lvl =
std::floor(std::log2(roi_scale / refer_scale) + refer_level);
int tgt_lvl = std::floor(std::log2(roi_scale / refer_scale + (T)1e-6) +
refer_level);
tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
target_level.push_back(tgt_lvl);
num_rois_level[tgt_lvl - min_level]++;
......@@ -107,7 +107,7 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
num_rois_level_integral[i + 1] =
num_rois_level_integral[i] + num_rois_level[i];
}
restore_index->mutable_data<int>({1, fpn_rois_num}, context.GetPlace());
restore_index->mutable_data<int>({fpn_rois_num, 1}, context.GetPlace());
int* restore_index_data = restore_index->data<int>();
std::vector<int> restore_index_inter(fpn_rois_num, -1);
// distribute the rois into different fpn level by target level
......
......@@ -2383,7 +2383,7 @@ def distribute_fpn_proposals(fpn_rois,
"""
helper = LayerHelper('distribute_fpn_proposals', **locals())
dtype = helper.input_dtype()
dtype = helper.input_dtype('fpn_rois')
num_lvl = max_level - min_level + 1
multi_rois = [
helper.create_variable_for_type_inference(dtype) for i in range(num_lvl)
......
......@@ -37,7 +37,7 @@ class TestDistributeFPNProposalsOp(OpTest):
for i in range(len(self.rois_fpn))]
self.outputs = {
'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1)
}
def init_test_case(self):
......@@ -63,10 +63,10 @@ class TestDistributeFPNProposalsOp(OpTest):
return target_lvls
def get_sub_lod(self, sub_lvl):
sub_lod = []
sub_lod = [0, 0]
max_batch_id = sub_lvl[-1]
for i in range(max_batch_id.astype(np.int32) + 1):
sub_lod.append(np.where(sub_lvl == i)[0].size)
sub_lod[i] = np.where(sub_lvl == i)[0].size
return sub_lod
def add_multilevel_roi(self, rois, target_lvls, lvl_min, lvl_max):
......@@ -115,3 +115,7 @@ class TestDistributeFPNProposalsOp(OpTest):
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册