Unverified commit 8f8ed7de authored by Y YuanRisheng, committed by GitHub

[Phi]Move distribute_fpn_proposals to PHI (#45212)

* move distribute_fpn_proposals

* fix some code

* fix yaml bugs

* add set dtype

* move proposal_impl to funcs

* fix compile bugs
Parent da51baf2
......@@ -93,9 +93,8 @@ if(WITH_GPU OR WITH_ROCM)
generate_proposals_op.cu DEPS ${TMPDEPS})
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc
DEPS ${TMPDEPS})
detection_library(
distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc
distribute_fpn_proposals_op.cu DEPS ${TMPDEPS})
detection_library(distribute_fpn_proposals_op SRCS
distribute_fpn_proposals_op.cc DEPS ${TMPDEPS})
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc
collect_fpn_proposals_op.cu DEPS ${TMPDEPS})
else()
......
......@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/fluid/framework/op_version_registry.h"
......@@ -23,52 +26,6 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("FpnRois"),
true,
platform::errors::NotFound("Input(FpnRois) of DistributeFpnProposalsOp"
" is not found"));
PADDLE_ENFORCE_GE(ctx->Outputs("MultiFpnRois").size(),
1UL,
platform::errors::InvalidArgument(
"Outputs(MultiFpnRois) of "
"DistributeFpnProposalsOp should not be empty"));
size_t min_level = static_cast<size_t>(ctx->Attrs().Get<int>("min_level"));
size_t max_level = static_cast<size_t>(ctx->Attrs().Get<int>("max_level"));
PADDLE_ENFORCE_GE(
max_level,
min_level,
platform::errors::InvalidArgument(
"max_level must not lower than "
"min_level. But received max_level = %d, min_level = %d",
max_level,
min_level));
// Set the output shape
size_t num_out_rois = max_level - min_level + 1;
std::vector<framework::DDim> outs_dims;
outs_dims.reserve(num_out_rois);
for (size_t i = 0; i < num_out_rois; ++i) {
framework::DDim out_dim = {-1, 4};
outs_dims.push_back(out_dim);
}
ctx->SetOutputsDim("MultiFpnRois", outs_dims);
ctx->SetOutputDim("RestoreIndex", {-1, 1});
if (ctx->HasOutputs("MultiLevelRoIsNum")) {
std::vector<framework::DDim> outs_num_dims;
for (size_t i = 0; i < num_out_rois; ++i) {
outs_num_dims.push_back({-1});
}
ctx->SetOutputsDim("MultiLevelRoIsNum", outs_num_dims);
}
if (!ctx->IsRuntime()) {
for (size_t i = 0; i < num_out_rois; ++i) {
ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i);
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -125,15 +82,19 @@ we return an array which indicate the original index of rois in
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(
distribute_fpn_proposals,
DistributeFpnProposalsInferShapeFunctor,
PD_INFER_META(phi::DistributeFpnProposalsInferMeta));
REGISTER_OPERATOR(
distribute_fpn_proposals,
ops::DistributeFpnProposalsOp,
ops::DistributeFpnProposalsOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
ops::DistributeFpnProposalsOpKernel<float>,
ops::DistributeFpnProposalsOpKernel<double>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
DistributeFpnProposalsInferShapeFunctor);
REGISTER_OP_VERSION(distribute_fpn_proposals)
.AddCheckpoint(
R"ROC(
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <paddle/fluid/memory/allocation/allocator.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static constexpr int kNumCUDAThreads = 64;
static constexpr int kNumMaxinumNumBlocks = 4096;
int const BBoxSize = 4;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <class T>
__global__ void GPUDistFpnProposalsHelper(const int nthreads,
const T* rois,
const int lod_size,
const int refer_level,
const int refer_scale,
const int max_level,
const int min_level,
int* roi_batch_id_data,
int* sub_lod_list,
int* target_lvls,
bool pixel_offset = true) {
CUDA_KERNEL_LOOP(i, nthreads) {
const T* offset_roi = rois + i * BBoxSize;
int roi_batch_ind = roi_batch_id_data[i];
// get the target level of current rois
T roi_area = RoIArea(offset_roi, pixel_offset);
T roi_scale = sqrt(roi_area);
int tgt_lvl = floor(
log2(roi_scale / static_cast<T>(refer_scale) + (T)1e-8) + refer_level);
tgt_lvl = min(max_level, max(tgt_lvl, min_level));
target_lvls[i] = tgt_lvl;
// compute number of rois in the same batch and same target level
platform::CudaAtomicAdd(
sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1);
}
}
template <typename DeviceContext, typename T>
class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* fpn_rois = ctx.Input<paddle::framework::LoDTensor>("FpnRois");
auto multi_fpn_rois = ctx.MultiOutput<LoDTensor>("MultiFpnRois");
auto* restore_index = ctx.Output<Tensor>("RestoreIndex");
const int min_level = ctx.Attr<int>("min_level");
const int max_level = ctx.Attr<int>("max_level");
const int refer_level = ctx.Attr<int>("refer_level");
const int refer_scale = ctx.Attr<int>("refer_scale");
const bool pixel_offset = ctx.Attr<bool>("pixel_offset");
int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
if (!ctx.HasInput("RoisNum")) {
PADDLE_ENFORCE_EQ(
fpn_rois->lod().size(),
1UL,
platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD "
"with one level"));
}
std::vector<size_t> fpn_rois_lod;
if (ctx.HasInput("RoisNum")) {
auto* rois_num = ctx.Input<Tensor>("RoisNum");
fpn_rois_lod = GetLodFromRoisNum(rois_num);
} else {
fpn_rois_lod = fpn_rois->lod().back();
}
int lod_size = fpn_rois_lod.size() - 1;
int roi_num = fpn_rois_lod[lod_size];
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// get batch id by lod in CPU
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({roi_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
for (int n = 0; n < lod_size; ++n) {
for (size_t i = fpn_rois_lod[n]; i < fpn_rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
// copy batch id list to GPU
Tensor roi_batch_id_list_gpu;
framework::TensorCopySync(
roi_batch_id_list, dev_ctx.GetPlace(), &roi_batch_id_list_gpu);
Tensor sub_lod_list;
sub_lod_list.Resize({num_level, lod_size});
int* sub_lod_list_data = sub_lod_list.mutable_data<int>(dev_ctx.GetPlace());
phi::funcs::SetConstant<phi::GPUContext, int> set_zero;
set_zero(dev_ctx, &sub_lod_list, static_cast<int>(0));
Tensor target_lvls;
target_lvls.Resize({roi_num});
int* target_lvls_data = target_lvls.mutable_data<int>(dev_ctx.GetPlace());
int dist_blocks = NumBlocks(roi_num);
int threads = kNumCUDAThreads;
// get target levels and sub_lod list
GPUDistFpnProposalsHelper<T><<<dist_blocks, threads, 0, dev_ctx.stream()>>>(
roi_num,
fpn_rois->data<T>(),
lod_size,
refer_level,
refer_scale,
max_level,
min_level,
roi_batch_id_list_gpu.data<int>(),
sub_lod_list_data,
target_lvls_data,
pixel_offset);
auto place = dev_ctx.GetPlace();
Tensor index_in_t;
int* idx_in = index_in_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
platform::ForRange<phi::GPUContext> for_range(dev_ctx, roi_num);
for_range(RangeInitFunctor{0, 1, idx_in});
Tensor keys_out_t;
int* keys_out = keys_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
Tensor index_out_t;
int* idx_out = index_out_t.mutable_data<int>({roi_num}, dev_ctx.GetPlace());
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs<int, int>(nullptr,
temp_storage_bytes,
target_lvls_data,
keys_out,
idx_in,
idx_out,
roi_num,
0,
sizeof(int) * 8,
dev_ctx.stream());
// Allocate temporary storage
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
// Run sorting operation
// sort target level to get corresponding index
cub::DeviceRadixSort::SortPairs<int, int>(d_temp_storage->ptr(),
temp_storage_bytes,
target_lvls_data,
keys_out,
idx_in,
idx_out,
roi_num,
0,
sizeof(int) * 8,
dev_ctx.stream());
int* restore_idx_data =
restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace());
// sort current index to get restore index
cub::DeviceRadixSort::SortPairs<int, int>(d_temp_storage->ptr(),
temp_storage_bytes,
idx_out,
keys_out,
idx_in,
restore_idx_data,
roi_num,
0,
sizeof(int) * 8,
dev_ctx.stream());
int start = 0;
auto multi_rois_num = ctx.MultiOutput<Tensor>("MultiLevelRoIsNum");
std::vector<int> sub_lod_list_cpu(lod_size * num_level);
memory::Copy(platform::CPUPlace(),
sub_lod_list_cpu.data(),
place,
sub_lod_list_data,
sizeof(int) * lod_size * num_level,
dev_ctx.stream());
dev_ctx.Wait();
for (int i = 0; i < num_level; ++i) {
Tensor sub_lod = sub_lod_list.Slice(i, i + 1);
// transfer length-based lod to offset-based lod
std::vector<size_t> offset(1, 0);
for (int j = 0; j < lod_size; ++j) {
offset.emplace_back(offset.back() + sub_lod_list_cpu[i * lod_size + j]);
}
int sub_rois_num = offset.back();
int end = start + sub_rois_num;
if (end > start) {
Tensor sub_idx = index_out_t.Slice(start, end);
start = end;
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace());
phi::funcs::GPUGather<T>(
dev_ctx, *fpn_rois, sub_idx, multi_fpn_rois[i]);
} else {
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace());
}
if (multi_rois_num.size() > 0) {
Tensor* rois_num_t = multi_rois_num[i];
paddle::framework::TensorCopySync(
sub_lod, dev_ctx.GetPlace(), rois_num_t);
rois_num_t->Resize({lod_size});
}
framework::LoD lod;
lod.emplace_back(offset);
multi_fpn_rois[i]->set_lod(lod);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
distribute_fpn_proposals,
ops::GPUDistributeFpnProposalsOpKernel<phi::GPUContext, float>,
ops::GPUDistributeFpnProposalsOpKernel<phi::GPUContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
const int kBoxDim = 4;
inline std::vector<size_t> GetLodFromRoisNum(
const framework::Tensor* rois_num) {
std::vector<size_t> rois_lod;
auto* rois_num_data = rois_num->data<int>();
framework::Tensor cpu_tensor;
if (platform::is_gpu_place(rois_num->place())) {
paddle::framework::TensorCopySync(
*rois_num, platform::CPUPlace(), &cpu_tensor);
rois_num_data = cpu_tensor.data<int>();
}
rois_lod.push_back(static_cast<size_t>(0));
for (int i = 0; i < rois_num->numel(); ++i) {
rois_lod.push_back(rois_lod.back() + static_cast<size_t>(rois_num_data[i]));
}
return rois_lod;
}
template <typename T>
static inline T BBoxArea(const T* box, bool pixel_offset) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (pixel_offset) {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
} else {
return w * h;
}
}
}
template <typename T>
class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* fpn_rois = context.Input<paddle::framework::LoDTensor>("FpnRois");
auto multi_fpn_rois =
context.MultiOutput<paddle::framework::LoDTensor>("MultiFpnRois");
auto* restore_index =
context.Output<paddle::framework::Tensor>("RestoreIndex");
const int min_level = context.Attr<int>("min_level");
const int max_level = context.Attr<int>("max_level");
const int refer_level = context.Attr<int>("refer_level");
const int refer_scale = context.Attr<int>("refer_scale");
const bool pixel_offset = context.Attr<bool>("pixel_offset");
const int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
if (!context.HasInput("RoisNum")) {
PADDLE_ENFORCE_EQ(fpn_rois->lod().size(),
1UL,
platform::errors::InvalidArgument(
"DistributeFpnProposalsOp needs LoD "
"with one level. But received level is %d",
fpn_rois->lod().size()));
}
std::vector<size_t> fpn_rois_lod;
int fpn_rois_num;
if (context.HasInput("RoisNum")) {
auto* rois_num = context.Input<framework::Tensor>("RoisNum");
fpn_rois_lod = GetLodFromRoisNum(rois_num);
} else {
fpn_rois_lod = fpn_rois->lod().back();
}
fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<int> target_level;
// std::vector<int> target_level(fpn_rois_num, -1);
// record the number of rois in each level
std::vector<int> num_rois_level(num_level, 0);
std::vector<int> num_rois_level_integral(num_level + 1, 0);
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
auto fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
// get the target level of current rois
T roi_scale = std::sqrt(BBoxArea(rois_data, pixel_offset));
int tgt_lvl = std::floor(std::log2(roi_scale / refer_scale + (T)1e-6) +
refer_level);
tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
target_level.push_back(tgt_lvl);
num_rois_level[tgt_lvl - min_level]++;
rois_data += kBoxDim;
}
}
// define the output rois
// pointer which point to each level fpn rois
std::vector<T*> multi_fpn_rois_data(num_level);
// lod0 which will record the offset information of each level rois
std::vector<std::vector<size_t>> multi_fpn_rois_lod0;
for (int i = 0; i < num_level; ++i) {
// allocate memory for each level rois
multi_fpn_rois[i]->mutable_data<T>({num_rois_level[i], kBoxDim},
context.GetPlace());
multi_fpn_rois_data[i] = multi_fpn_rois[i]->data<T>();
std::vector<size_t> lod0(1, 0);
multi_fpn_rois_lod0.push_back(lod0);
// statistic start point for each level rois
num_rois_level_integral[i + 1] =
num_rois_level_integral[i] + num_rois_level[i];
}
restore_index->mutable_data<int>({fpn_rois_num, 1}, context.GetPlace());
int* restore_index_data = restore_index->data<int>();
std::vector<int> restore_index_inter(fpn_rois_num, -1);
// distribute the rois into different fpn level by target level
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
auto fpn_rois_slice =
fpn_rois->Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
size_t cur_offset = fpn_rois_lod[i];
// std::vector<size_t > lod_offset[num_level];
for (int j = 0; j < num_level; j++) {
multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]);
}
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
int lvl = target_level[cur_offset + j];
memcpy(multi_fpn_rois_data[lvl - min_level],
rois_data,
kBoxDim * sizeof(T));
multi_fpn_rois_data[lvl - min_level] += kBoxDim;
int index_in_shuffle = num_rois_level_integral[lvl - min_level] +
multi_fpn_rois_lod0[lvl - min_level][i + 1];
restore_index_inter[index_in_shuffle] = cur_offset + j;
multi_fpn_rois_lod0[lvl - min_level][i + 1]++;
rois_data += kBoxDim;
}
}
for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i;
}
auto multi_rois_num =
context.MultiOutput<framework::Tensor>("MultiLevelRoIsNum");
if (multi_rois_num.size() > 0) {
int batch_size = fpn_rois_lod.size() - 1;
for (int i = 0; i < num_level; ++i) {
int* rois_num_data = multi_rois_num[i]->mutable_data<int>(
{batch_size}, context.GetPlace());
for (int j = 0; j < batch_size; ++j) {
rois_num_data[j] = static_cast<int>(multi_fpn_rois_lod0[i][j + 1] -
multi_fpn_rois_lod0[i][j]);
}
}
}
// merge lod information into LoDTensor
for (int i = 0; i < num_level; ++i) {
framework::LoD lod;
lod.emplace_back(multi_fpn_rois_lod0[i]);
multi_fpn_rois[i]->set_lod(lod);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -704,6 +704,16 @@
kernel :
func : diag_embed
- api : distribute_fpn_proposals
args : (Tensor fpn_rois, Tensor rois_num, int min_level, int max_level, int refer_level, int refer_scale, bool pixel_offset)
output : Tensor[](multi_fpn_rois){max_level - min_level + 1}, Tensor[](multi_level_rois_num){max_level - min_level + 1}, Tensor(restore_index)
infer_meta :
func : DistributeFpnProposalsInferMeta
kernel :
func : distribute_fpn_proposals
data_type : fpn_rois
optional : rois_num
- api : divide
args : (Tensor x, Tensor y)
output : Tensor
......
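For context, here is a minimal dygraph sketch (an editorial illustration, not part of the commit) of how the new yaml entry above is reached through the existing public API paddle.vision.ops.distribute_fpn_proposals; the box coordinates and attribute values are made up for the example:

import paddle

# One image with four RoIs in (x1, y1, x2, y2) pixel coordinates.
fpn_rois = paddle.to_tensor([[0., 0., 31., 31.],
                             [0., 0., 111., 111.],
                             [0., 0., 223., 223.],
                             [0., 0., 447., 447.]])
rois_num = paddle.to_tensor([4], dtype='int32')  # RoIs per image

# min_level/max_level/refer_level/refer_scale mirror the attributes in the yaml entry.
multi_rois, restore_ind, rois_num_per_level = \
    paddle.vision.ops.distribute_fpn_proposals(fpn_rois=fpn_rois,
                                               min_level=2,
                                               max_level=5,
                                               refer_level=4,
                                               refer_scale=224,
                                               rois_num=rois_num)
# multi_rois holds max_level - min_level + 1 tensors, one per FPN level;
# restore_ind maps the concatenated per-level rows back to the input order.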
......@@ -122,8 +122,15 @@ class KernelContext {
template <typename TensorType>
std::vector<TensorType*> MutableOutputBetween(size_t start, size_t end) {
std::vector<TensorType*> v;
// If every output slot in [start, end) is a nullptr (the optional vector
// output was not provided), return an empty vector instead of a vector of
// nullptrs.
bool is_empty_vector = true;
for (size_t i = start; i < end; ++i) {
v.emplace_back(static_cast<TensorType*>(outputs_.at(i)));
if (outputs_.at(i) != nullptr) {
is_empty_vector = false;
}
}
if (is_empty_vector) {
v.clear();
}
return v;
}
......
......@@ -951,6 +951,57 @@ void DistInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void DistributeFpnProposalsInferMeta(
const MetaTensor& fpn_rois,
const MetaTensor& rois_num,
int min_level,
int max_level,
int refer_level,
int refer_scale,
bool pixel_offset,
std::vector<MetaTensor*> multi_fpn_rois,
std::vector<MetaTensor*> multi_level_rois_num,
MetaTensor* restore_index,
MetaConfig config) {
PADDLE_ENFORCE_GE(
multi_fpn_rois.size(),
1UL,
errors::InvalidArgument("Outputs(MultiFpnRois) of "
"DistributeFpnProposalsOp should not be empty"));
PADDLE_ENFORCE_GE(
max_level,
min_level,
errors::InvalidArgument(
"max_level must not lower than "
"min_level. But received max_level = %d, min_level = %d",
max_level,
min_level));
// Set the output shape
for (size_t i = 0; i < multi_fpn_rois.size(); ++i) {
DDim out_dim = {-1, 4};
if (multi_fpn_rois[i] == nullptr) {
continue;
}
multi_fpn_rois[i]->set_dims(out_dim);
multi_fpn_rois[i]->set_dtype(fpn_rois.dtype());
}
restore_index->set_dims({-1, 1});
restore_index->set_dtype(DataType::INT32);
for (size_t i = 0; i < multi_level_rois_num.size(); ++i) {
if (multi_level_rois_num[i] == nullptr) {
continue;
}
multi_level_rois_num[i]->set_dims({-1});
multi_level_rois_num[i]->set_dtype(DataType::INT32);
}
if (!config.is_runtime) {
for (size_t i = 0; i < multi_fpn_rois.size(); ++i) {
multi_fpn_rois[i]->share_lod(fpn_rois);
}
}
}
void DropoutInferMeta(const MetaTensor& x,
const MetaTensor& seed_tensor,
const Scalar& p,
......
......@@ -139,6 +139,19 @@ void DistInferMeta(const MetaTensor& x,
float p,
MetaTensor* out);
void DistributeFpnProposalsInferMeta(
const MetaTensor& fpn_rois,
const MetaTensor& rois_num,
int min_level,
int max_level,
int refer_level,
int refer_scale,
bool pixel_offset,
std::vector<MetaTensor*> multi_fpn_rois,
std::vector<MetaTensor*> multi_level_rois_num,
MetaTensor* restore_index,
MetaConfig config = MetaConfig());
void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void DropoutInferMeta(const MetaTensor& x,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/distribute_fpn_proposals_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h"
namespace phi {
template <typename T, typename Context>
void DistributeFpnProposalsKernel(
const Context& dev_ctx,
const DenseTensor& fpn_rois,
const paddle::optional<DenseTensor>& rois_num,
int min_level,
int max_level,
int refer_level,
int refer_scale,
bool pixel_offset,
std::vector<DenseTensor*> multi_fpn_rois,
std::vector<DenseTensor*> multi_level_rois_num,
DenseTensor* restore_index) {
const int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
if (!rois_num.get_ptr()) {
PADDLE_ENFORCE_EQ(
fpn_rois.lod().size(),
1UL,
errors::InvalidArgument("DistributeFpnProposalsOp needs LoD "
"with one level. But received level is %d",
fpn_rois.lod().size()));
}
std::vector<size_t> fpn_rois_lod;
int fpn_rois_num;
if (rois_num.get_ptr()) {
fpn_rois_lod = funcs::GetLodFromRoisNum(dev_ctx, rois_num.get_ptr());
} else {
fpn_rois_lod = fpn_rois.lod().back();
}
fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<int> target_level;
// record the number of rois in each level
std::vector<int> num_rois_level(num_level, 0);
std::vector<int> num_rois_level_integral(num_level + 1, 0);
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
auto fpn_rois_slice = fpn_rois.Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
// get the target level of current rois
T roi_scale = std::sqrt(funcs::BBoxArea(rois_data, pixel_offset));
int tgt_lvl = std::floor(std::log2(roi_scale / refer_scale + (T)1e-6) +
refer_level);
tgt_lvl = std::min(max_level, std::max(tgt_lvl, min_level));
target_level.push_back(tgt_lvl);
num_rois_level[tgt_lvl - min_level]++;
rois_data += funcs::kBoxDim;
}
}
// define the output rois
// pointer which point to each level fpn rois
std::vector<T*> multi_fpn_rois_data(num_level);
// lod0 which will record the offset information of each level rois
std::vector<std::vector<size_t>> multi_fpn_rois_lod0;
for (int i = 0; i < num_level; ++i) {
// allocate memory for each level rois
multi_fpn_rois[i]->Resize({num_rois_level[i], funcs::kBoxDim});
multi_fpn_rois_data[i] = dev_ctx.template Alloc<T>(multi_fpn_rois[i]);
std::vector<size_t> lod0(1, 0);
multi_fpn_rois_lod0.push_back(lod0);
// statistic start point for each level rois
num_rois_level_integral[i + 1] =
num_rois_level_integral[i] + num_rois_level[i];
}
restore_index->Resize({fpn_rois_num, 1});
int* restore_index_data = dev_ctx.template Alloc<int>(restore_index);
std::vector<int> restore_index_inter(fpn_rois_num, -1);
// distribute the rois into different fpn level by target level
for (size_t i = 0; i < fpn_rois_lod.size() - 1; ++i) {
auto fpn_rois_slice = fpn_rois.Slice(fpn_rois_lod[i], fpn_rois_lod[i + 1]);
const T* rois_data = fpn_rois_slice.data<T>();
size_t cur_offset = fpn_rois_lod[i];
for (int j = 0; j < num_level; j++) {
multi_fpn_rois_lod0[j].push_back(multi_fpn_rois_lod0[j][i]);
}
for (int j = 0; j < fpn_rois_slice.dims()[0]; ++j) {
int lvl = target_level[cur_offset + j];
memcpy(multi_fpn_rois_data[lvl - min_level],
rois_data,
funcs::kBoxDim * sizeof(T));
multi_fpn_rois_data[lvl - min_level] += funcs::kBoxDim;
int index_in_shuffle = num_rois_level_integral[lvl - min_level] +
multi_fpn_rois_lod0[lvl - min_level][i + 1];
restore_index_inter[index_in_shuffle] = cur_offset + j;
multi_fpn_rois_lod0[lvl - min_level][i + 1]++;
rois_data += funcs::kBoxDim;
}
}
for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i;
}
if (multi_level_rois_num.size() > 0) {
int batch_size = fpn_rois_lod.size() - 1;
for (int i = 0; i < num_level; ++i) {
multi_level_rois_num[i]->Resize({batch_size});
int* rois_num_data = dev_ctx.template Alloc<int>(multi_level_rois_num[i]);
for (int j = 0; j < batch_size; ++j) {
rois_num_data[j] = static_cast<int>(multi_fpn_rois_lod0[i][j + 1] -
multi_fpn_rois_lod0[i][j]);
}
}
}
// merge lod information into LoDTensor
for (int i = 0; i < num_level; ++i) {
LoD lod;
lod.emplace_back(multi_fpn_rois_lod0[i]);
multi_fpn_rois[i]->set_lod(lod);
}
}
} // namespace phi
PD_REGISTER_KERNEL(distribute_fpn_proposals,
CPU,
ALL_LAYOUT,
phi::DistributeFpnProposalsKernel,
float,
double) {}
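As a sanity check on the level-assignment rule in the kernel above (floor(log2(sqrt(area) / refer_scale + 1e-6) + refer_level), clamped to [min_level, max_level]), here is a small stand-alone Python sketch; the boxes and attribute defaults are illustrative only:

import math

def target_level(box, min_level=2, max_level=5, refer_level=4,
                 refer_scale=224, pixel_offset=True, eps=1e-6):
    # Mirrors funcs::BBoxArea plus the floor/log2 rule used by the CPU kernel.
    x1, y1, x2, y2 = box
    w, h = x2 - x1, y2 - y1
    area = (w + 1) * (h + 1) if pixel_offset else w * h
    lvl = math.floor(math.log2(math.sqrt(area) / refer_scale + eps) + refer_level)
    return min(max_level, max(lvl, min_level))

print(target_level([0, 0, 223, 223]))  # 224x224 box -> refer_level = 4
print(target_level([0, 0, 111, 111]))  # half the reference scale -> 3
print(target_level([0, 0, 447, 447]))  # double the reference scale -> 5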
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DistributeFpnProposalsKernel(
const Context& ctx,
const DenseTensor& fpn_rois,
const paddle::optional<DenseTensor>& rois_num,
int min_level,
int max_level,
int refer_level,
int refer_scale,
bool pixel_offset,
std::vector<DenseTensor*> multi_fpn_rois,
std::vector<DenseTensor*> multi_level_rois_num,
DenseTensor* restore_index);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace funcs {
const int kBoxDim = 4;
template <typename Context>
inline std::vector<size_t> GetLodFromRoisNum(const Context& dev_ctx,
const DenseTensor* rois_num) {
std::vector<size_t> rois_lod;
auto* rois_num_data = rois_num->data<int>();
DenseTensor cpu_tensor;
if (paddle::platform::is_gpu_place(rois_num->place())) {
Copy<Context>(dev_ctx, *rois_num, phi::CPUPlace(), true, &cpu_tensor);
rois_num_data = cpu_tensor.data<int>();
}
rois_lod.push_back(static_cast<size_t>(0));
for (int i = 0; i < rois_num->numel(); ++i) {
rois_lod.push_back(rois_lod.back() + static_cast<size_t>(rois_num_data[i]));
}
return rois_lod;
}
template <typename T>
static inline T BBoxArea(const T* box, bool pixel_offset) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (pixel_offset) {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
} else {
return w * h;
}
}
}
} // namespace funcs
} // namespace phi
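GetLodFromRoisNum above converts per-image RoI counts into the cumulative, offset-based LoD that both kernels iterate over; a small illustrative example of that conversion (the counts are made up):

# Hypothetical RoisNum: three images with 3, 2 and 4 RoIs respectively.
rois_num = [3, 2, 4]
rois_lod = [0]
for n in rois_num:
    rois_lod.append(rois_lod[-1] + n)
print(rois_lod)  # [0, 3, 5, 9]: image i owns rows rois_lod[i]..rois_lod[i+1]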
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/kernels/distribute_fpn_proposals_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribute_fpn_proposals_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace phi {
static constexpr int kNumCUDAThreads = 64;
static constexpr int kNumMaxinumNumBlocks = 4096;
int const BBoxSize = 4;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <class T>
__global__ void GPUDistFpnProposalsHelper(const int nthreads,
const T* rois,
const int lod_size,
const int refer_level,
const int refer_scale,
const int max_level,
const int min_level,
int* roi_batch_id_data,
int* sub_lod_list,
int* target_lvls,
bool pixel_offset = true) {
CUDA_KERNEL_LOOP(i, nthreads) {
const T* offset_roi = rois + i * BBoxSize;
int roi_batch_ind = roi_batch_id_data[i];
// get the target level of current rois
T roi_area = paddle::operators::RoIArea(offset_roi, pixel_offset);
T roi_scale = sqrt(roi_area);
int tgt_lvl = floor(
log2(roi_scale / static_cast<T>(refer_scale) + (T)1e-8) + refer_level);
tgt_lvl = min(max_level, max(tgt_lvl, min_level));
target_lvls[i] = tgt_lvl;
// compute number of rois in the same batch and same target level
paddle::platform::CudaAtomicAdd(
sub_lod_list + (tgt_lvl - min_level) * lod_size + roi_batch_ind, 1);
}
}
template <typename T, typename Context>
void DistributeFpnProposalsKernel(
const Context& dev_ctx,
const DenseTensor& fpn_rois,
const paddle::optional<DenseTensor>& rois_num,
int min_level,
int max_level,
int refer_level,
int refer_scale,
bool pixel_offset,
std::vector<DenseTensor*> multi_fpn_rois,
std::vector<DenseTensor*> multi_level_rois_num,
DenseTensor* restore_index) {
int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
if (!rois_num.get_ptr()) {
PADDLE_ENFORCE_EQ(
fpn_rois.lod().size(),
1UL,
errors::InvalidArgument("DistributeFpnProposalsOp needs LoD "
"with one level"));
}
std::vector<size_t> fpn_rois_lod;
if (rois_num.get_ptr()) {
fpn_rois_lod = funcs::GetLodFromRoisNum(dev_ctx, rois_num.get_ptr());
} else {
fpn_rois_lod = fpn_rois.lod().back();
}
int lod_size = fpn_rois_lod.size() - 1;
int roi_num = fpn_rois_lod[lod_size];
// get batch id by lod in CPU
DenseTensor roi_batch_id_list;
roi_batch_id_list.Resize({roi_num});
int* roi_batch_id_data = dev_ctx.template HostAlloc<int>(&roi_batch_id_list);
for (int n = 0; n < lod_size; ++n) {
for (size_t i = fpn_rois_lod[n]; i < fpn_rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
// copy batch id list to GPU
DenseTensor roi_batch_id_list_gpu;
Copy(dev_ctx,
roi_batch_id_list,
dev_ctx.GetPlace(),
true,
&roi_batch_id_list_gpu);
DenseTensor sub_lod_list;
sub_lod_list.Resize({num_level, lod_size});
int* sub_lod_list_data = dev_ctx.template Alloc<int>(&sub_lod_list);
phi::funcs::SetConstant<phi::GPUContext, int> set_zero;
set_zero(dev_ctx, &sub_lod_list, static_cast<int>(0));
DenseTensor target_lvls;
target_lvls.Resize({roi_num});
int* target_lvls_data = dev_ctx.template Alloc<int>(&target_lvls);
int dist_blocks = NumBlocks(roi_num);
int threads = kNumCUDAThreads;
// get target levels and sub_lod list
GPUDistFpnProposalsHelper<T><<<dist_blocks, threads, 0, dev_ctx.stream()>>>(
roi_num,
fpn_rois.data<T>(),
lod_size,
refer_level,
refer_scale,
max_level,
min_level,
roi_batch_id_list_gpu.data<int>(),
sub_lod_list_data,
target_lvls_data,
pixel_offset);
auto place = dev_ctx.GetPlace();
DenseTensor index_in_t;
index_in_t.Resize({roi_num});
int* idx_in = dev_ctx.template Alloc<int>(&index_in_t);
funcs::ForRange<phi::GPUContext> for_range(dev_ctx, roi_num);
for_range(paddle::operators::RangeInitFunctor{0, 1, idx_in});
DenseTensor keys_out_t;
keys_out_t.Resize({roi_num});
int* keys_out = dev_ctx.template Alloc<int>(&keys_out_t);
DenseTensor index_out_t;
index_out_t.Resize({roi_num});
int* idx_out = dev_ctx.template Alloc<int>(&index_out_t);
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs<int, int>(nullptr,
temp_storage_bytes,
target_lvls_data,
keys_out,
idx_in,
idx_out,
roi_num,
0,
sizeof(int) * 8,
dev_ctx.stream());
// Allocate temporary storage
auto d_temp_storage = paddle::memory::Alloc(place, temp_storage_bytes);
// Run sorting operation
// sort target level to get corresponding index
cub::DeviceRadixSort::SortPairs<int, int>(d_temp_storage->ptr(),
temp_storage_bytes,
target_lvls_data,
keys_out,
idx_in,
idx_out,
roi_num,
0,
sizeof(int) * 8,
dev_ctx.stream());
restore_index->Resize({roi_num, 1});
int* restore_idx_data = dev_ctx.template Alloc<int>(restore_index);
// sort current index to get restore index
cub::DeviceRadixSort::SortPairs<int, int>(d_temp_storage->ptr(),
temp_storage_bytes,
idx_out,
keys_out,
idx_in,
restore_idx_data,
roi_num,
0,
sizeof(int) * 8,
dev_ctx.stream());
int start = 0;
std::vector<int> sub_lod_list_cpu(lod_size * num_level);
paddle::memory::Copy(phi::CPUPlace(),
sub_lod_list_cpu.data(),
place,
sub_lod_list_data,
sizeof(int) * lod_size * num_level,
dev_ctx.stream());
dev_ctx.Wait();
for (int i = 0; i < num_level; ++i) {
DenseTensor sub_lod = sub_lod_list.Slice(i, i + 1);
// transfer length-based lod to offset-based lod
std::vector<size_t> offset(1, 0);
for (int j = 0; j < lod_size; ++j) {
offset.emplace_back(offset.back() + sub_lod_list_cpu[i * lod_size + j]);
}
int sub_rois_num = offset.back();
int end = start + sub_rois_num;
if (end > start) {
DenseTensor sub_idx = index_out_t.Slice(start, end);
start = end;
multi_fpn_rois[i]->Resize({sub_rois_num, funcs::kBoxDim});
dev_ctx.template Alloc<T>(multi_fpn_rois[i]);
phi::funcs::GPUGather<T>(dev_ctx, fpn_rois, sub_idx, multi_fpn_rois[i]);
} else {
multi_fpn_rois[i]->Resize({sub_rois_num, funcs::kBoxDim});
dev_ctx.template Alloc<T>(multi_fpn_rois[i]);
}
if (multi_level_rois_num.size() > 0) {
DenseTensor* rois_num_t = multi_level_rois_num[i];
Copy(dev_ctx, sub_lod, dev_ctx.GetPlace(), true, rois_num_t);
rois_num_t->Resize({lod_size});
}
LoD lod;
lod.emplace_back(offset);
multi_fpn_rois[i]->set_lod(lod);
}
}
} // namespace phi
PD_REGISTER_KERNEL(distribute_fpn_proposals,
GPU,
ALL_LAYOUT,
phi::DistributeFpnProposalsKernel,
float,
double) {}
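The two cub::DeviceRadixSort::SortPairs calls above amount to an argsort followed by its inverse: the first sort yields the order in which RoIs are regrouped by target level, and the second recovers, for every original RoI, its row in that regrouped output. A short host-side Python sketch of the same idea (the levels are made up):

target_lvls = [4, 2, 5, 2, 4]            # hypothetical per-RoI target levels
idx_in = list(range(len(target_lvls)))

# First sort: keys = target level, values = original index.
idx_out = sorted(idx_in, key=lambda i: target_lvls[i])   # [1, 3, 0, 4, 2]

# Second sort: keys = idx_out, values = 0..n-1, which yields the inverse
# permutation, i.e. the RestoreIndex written by the kernel.
restore_idx = [0] * len(idx_out)
for pos, original in enumerate(idx_out):
    restore_idx[original] = pos                           # [2, 0, 4, 1, 3]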
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DistributeFpnProposalsOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"distribute_fpn_proposals",
{"FpnRois", "RoisNum"},
{"min_level", "max_level", "refer_level", "refer_scale", "pixel_offset"},
{"MultiFpnRois", "MultiLevelRoIsNum", "RestoreIndex"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(distribute_fpn_proposals,
phi::DistributeFpnProposalsOpArgumentMapping);
......@@ -23,6 +23,16 @@ import paddle
from op_test import OpTest
def distribute_fpn_proposals_wrapper(fpn_rois, rois_num, min_level, max_level,
refer_level, refer_scale, pixel_offset):
return paddle.vision.ops.distribute_fpn_proposals(fpn_rois=fpn_rois,
min_level=min_level,
max_level=max_level,
refer_level=refer_level,
refer_scale=refer_scale,
rois_num=rois_num)
class TestDistributeFPNProposalsOp(OpTest):
def set_data(self):
......@@ -44,6 +54,8 @@ class TestDistributeFPNProposalsOp(OpTest):
'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
}
self.python_api = distribute_fpn_proposals_wrapper
self.python_out_sig = ['MultiFpnRois', 'RestoreIndex']
def init_test_case(self):
self.roi_max_level = 5
......@@ -152,6 +164,10 @@ class TestDistributeFPNProposalsOpWithRoisNum(TestDistributeFPNProposalsOp):
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
'MultiLevelRoIsNum': rois_num_per_level
}
self.python_api = distribute_fpn_proposals_wrapper
self.python_out_sig = [
'MultiFpnRois', 'MultiLevelRoIsNum', 'RestoreIndex'
]
class TestDistributeFPNProposalsOpNoOffset(
......
......@@ -899,6 +899,13 @@ def distribute_fpn_proposals(fpn_rois,
"""
num_lvl = max_level - min_level + 1
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
multi_rois, rois_num_per_level, restore_ind = _C_ops.final_state_distribute_fpn_proposals(
fpn_rois, rois_num, min_level, max_level, refer_level, refer_scale,
pixel_offset)
return multi_rois, restore_ind, rois_num_per_level
if _non_static_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
attrs = ('min_level', min_level, 'max_level', max_level, 'refer_level',
......