未验证 提交 5262b025 编写于 作者: W wangguanzhong 提交者: GitHub

add generate_proposals_v2 op (#28214)

* add generate_proposals_v2 op
上级 b96869bc
......@@ -46,10 +46,12 @@ if(WITH_GPU)
set(TMPDEPS memory cub)
endif()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS ${TMPDEPS})
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc generate_proposals_v2_op.cu DEPS ${TMPDEPS})
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS ${TMPDEPS})
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS ${TMPDEPS})
else()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
endif()
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int const kThreadsPerBlock = sizeof(uint64_t) * 8;
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
struct RangeInitFunctor {
int start_;
int delta_;
int *out_;
__device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
template <typename T>
static void SortDescending(const platform::CUDADeviceContext &ctx,
const Tensor &value, Tensor *value_out,
Tensor *index_out) {
int num = static_cast<int>(value.numel());
Tensor index_in_t;
int *idx_in = index_in_t.mutable_data<int>({num}, ctx.GetPlace());
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, num);
for_range(RangeInitFunctor{0, 1, idx_in});
int *idx_out = index_out->mutable_data<int>({num}, ctx.GetPlace());
const T *keys_in = value.data<T>();
T *keys_out = value_out->mutable_data<T>({num}, ctx.GetPlace());
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>(
nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
// Allocate temporary storage
auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
// Run sorting operation
cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in,
idx_out, num);
}
template <typename T>
struct BoxDecodeAndClipFunctor {
const T *anchor;
const T *deltas;
const T *var;
const int *index;
const T *im_info;
T *proposals;
BoxDecodeAndClipFunctor(const T *anchor, const T *deltas, const T *var,
const int *index, const T *im_info, T *proposals)
: anchor(anchor),
deltas(deltas),
var(var),
index(index),
im_info(im_info),
proposals(proposals) {}
T bbox_clip_default{static_cast<T>(kBBoxClipDefault)};
__device__ void operator()(size_t i) {
int k = index[i] * 4;
T axmin = anchor[k];
T aymin = anchor[k + 1];
T axmax = anchor[k + 2];
T aymax = anchor[k + 3];
T w = axmax - axmin + 1.0;
T h = aymax - aymin + 1.0;
T cx = axmin + 0.5 * w;
T cy = aymin + 0.5 * h;
T dxmin = deltas[k];
T dymin = deltas[k + 1];
T dxmax = deltas[k + 2];
T dymax = deltas[k + 3];
T d_cx, d_cy, d_w, d_h;
if (var) {
d_cx = cx + dxmin * w * var[k];
d_cy = cy + dymin * h * var[k + 1];
d_w = exp(Min(dxmax * var[k + 2], bbox_clip_default)) * w;
d_h = exp(Min(dymax * var[k + 3], bbox_clip_default)) * h;
} else {
d_cx = cx + dxmin * w;
d_cy = cy + dymin * h;
d_w = exp(Min(dxmax, bbox_clip_default)) * w;
d_h = exp(Min(dymax, bbox_clip_default)) * h;
}
T oxmin = d_cx - d_w * 0.5;
T oymin = d_cy - d_h * 0.5;
T oxmax = d_cx + d_w * 0.5 - 1.;
T oymax = d_cy + d_h * 0.5 - 1.;
proposals[i * 4] = Max(Min(oxmin, im_info[1] - 1.), 0.);
proposals[i * 4 + 1] = Max(Min(oymin, im_info[0] - 1.), 0.);
proposals[i * 4 + 2] = Max(Min(oxmax, im_info[1] - 1.), 0.);
proposals[i * 4 + 3] = Max(Min(oymax, im_info[0] - 1.), 0.);
}
__device__ __forceinline__ T Min(T a, T b) const { return a > b ? b : a; }
__device__ __forceinline__ T Max(T a, T b) const { return a > b ? a : b; }
};
template <typename T, int BlockSize>
static __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
const T min_size, const int num,
int *keep_num, int *keep,
bool is_scale = true) {
T im_h = im_info[0];
T im_w = im_info[1];
int cnt = 0;
__shared__ int keep_index[BlockSize];
CUDA_KERNEL_LOOP(i, num) {
keep_index[threadIdx.x] = -1;
__syncthreads();
int k = i * 4;
T xmin = bboxes[k];
T ymin = bboxes[k + 1];
T xmax = bboxes[k + 2];
T ymax = bboxes[k + 3];
T w = xmax - xmin + 1.0;
T h = ymax - ymin + 1.0;
T cx = xmin + w / 2.;
T cy = ymin + h / 2.;
if (is_scale) {
w = (xmax - xmin) / im_info[2] + 1.;
h = (ymax - ymin) / im_info[2] + 1.;
}
if (w >= min_size && h >= min_size && cx <= im_w && cy <= im_h) {
keep_index[threadIdx.x] = i;
}
__syncthreads();
if (threadIdx.x == 0) {
int size = (num - i) < BlockSize ? num - i : BlockSize;
for (int j = 0; j < size; ++j) {
if (keep_index[j] > -1) {
keep[cnt++] = keep_index[j];
}
}
}
__syncthreads();
}
if (threadIdx.x == 0) {
keep_num[0] = cnt;
}
}
static __device__ float IoU(const float *a, const float *b) {
float left = max(a[0], b[0]), right = min(a[2], b[2]);
float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
float inter_s = width * height;
float s_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
float s_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
return inter_s / (s_a + s_b - inter_s);
}
static __global__ void NMSKernel(const int n_boxes,
const float nms_overlap_thresh,
const float *dev_boxes, uint64_t *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
const int row_size =
min(n_boxes - row_start * kThreadsPerBlock, kThreadsPerBlock);
const int col_size =
min(n_boxes - col_start * kThreadsPerBlock, kThreadsPerBlock);
__shared__ float block_boxes[kThreadsPerBlock * 4];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 4 + 0] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 0];
block_boxes[threadIdx.x * 4 + 1] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 1];
block_boxes[threadIdx.x * 4 + 2] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 2];
block_boxes[threadIdx.x * 4 + 3] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 3];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = kThreadsPerBlock * row_start + threadIdx.x;
const float *cur_box = dev_boxes + cur_box_idx * 4;
int i = 0;
uint64_t t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (IoU(cur_box, block_boxes + i * 4) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(n_boxes, kThreadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
template <typename T>
static void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const Tensor &sorted_indices, const T nms_threshold,
Tensor *keep_out) {
int boxes_num = proposals.dims()[0];
const int col_blocks = DIVUP(boxes_num, kThreadsPerBlock);
dim3 blocks(DIVUP(boxes_num, kThreadsPerBlock),
DIVUP(boxes_num, kThreadsPerBlock));
dim3 threads(kThreadsPerBlock);
const T *boxes = proposals.data<T>();
auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
framework::Vector<uint64_t> mask(boxes_num * col_blocks);
NMSKernel<<<blocks, threads>>>(boxes_num, nms_threshold, boxes,
mask.CUDAMutableData(BOOST_GET_CONST(
platform::CUDAPlace, ctx.GetPlace())));
std::vector<uint64_t> remv(col_blocks);
memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);
std::vector<int> keep_vec;
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / kThreadsPerBlock;
int inblock = i % kThreadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
++num_to_keep;
keep_vec.push_back(i);
uint64_t *p = &mask[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace());
memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(),
sizeof(int) * num_to_keep, ctx.stream());
ctx.Wait();
}
} // namespace operators
} // namespace paddle
......@@ -21,6 +21,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
struct RangeInitFunctor {
int start;
int delta;
......@@ -125,17 +127,45 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
}
}
// Calculate max IoU between each box and ground-truth and
// each row represents one box
template <typename T>
void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) {
const T* iou_data = iou.data<T>();
int row = iou.dims()[0];
int col = iou.dims()[1];
T* max_iou_data = max_iou->data<T>();
for (int i = 0; i < row; ++i) {
const T* v = iou_data + i * col;
T max_v = *std::max_element(v, v + col);
max_iou_data[i] = max_v;
}
}
static void AppendProposals(framework::Tensor* dst, int64_t offset,
const framework::Tensor& src) {
auto* out_data = dst->data<void>();
auto* to_add_data = src.data<void>();
size_t size_of_t = framework::SizeOfType(src.type());
offset *= size_of_t;
std::memcpy(
reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(out_data) + offset),
to_add_data, src.numel() * size_of_t);
}
template <class T>
void ClipTiledBoxes(const platform::DeviceContext& ctx,
const framework::Tensor& im_info,
const framework::Tensor& input_boxes,
framework::Tensor* out) {
framework::Tensor* out, bool is_scale = true) {
T* out_data = out->mutable_data<T>(ctx.GetPlace());
const T* im_info_data = im_info.data<T>();
const T* input_boxes_data = input_boxes.data<T>();
T zero(0);
T im_w = round(im_info_data[1] / im_info_data[2]);
T im_h = round(im_info_data[0] / im_info_data[2]);
T im_w =
is_scale ? round(im_info_data[1] / im_info_data[2]) : im_info_data[1];
T im_h =
is_scale ? round(im_info_data[0] / im_info_data[2]) : im_info_data[0];
for (int64_t i = 0; i < input_boxes.numel(); ++i) {
if (i % 4 == 0) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
......@@ -149,19 +179,101 @@ void ClipTiledBoxes(const platform::DeviceContext& ctx,
}
}
// Calculate max IoU between each box and ground-truth and
// each row represents one box
template <typename T>
void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) {
const T* iou_data = iou.data<T>();
int row = iou.dims()[0];
int col = iou.dims()[1];
T* max_iou_data = max_iou->data<T>();
for (int i = 0; i < row; ++i) {
const T* v = iou_data + i * col;
T max_v = *std::max_element(v, v + col);
max_iou_data[i] = max_v;
// Filter the box with small area
template <class T>
void FilterBoxes(const platform::DeviceContext& ctx,
const framework::Tensor* boxes, float min_size,
const framework::Tensor& im_info, bool is_scale,
framework::Tensor* keep) {
const T* im_info_data = im_info.data<T>();
const T* boxes_data = boxes->data<T>();
keep->Resize({boxes->dims()[0]});
min_size = std::max(min_size, 1.0f);
int* keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (is_scale) {
ws = (boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_info_data[2] + 1;
hs =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_info_data[2] + 1;
}
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] &&
y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
keep->Resize({keep_len});
}
template <class T>
static void BoxCoder(const platform::DeviceContext& ctx,
framework::Tensor* all_anchors,
framework::Tensor* bbox_deltas,
framework::Tensor* variances,
framework::Tensor* proposals) {
T* proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0];
int64_t len = all_anchors->dims()[1];
auto* bbox_deltas_data = bbox_deltas->data<T>();
auto* anchor_data = all_anchors->data<T>();
const T* variances_data = nullptr;
if (variances) {
variances_data = variances->data<T>();
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
if (variances) {
bbox_center_x =
variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
anchor_center_x;
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
}
// return proposals;
}
} // namespace operators
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -27,18 +29,6 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
static void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) {
auto *out_data = dst->data<void>();
auto *to_add_data = src.data<void>();
size_t size_of_t = framework::SizeOfType(src.type());
offset *= size_of_t;
std::memcpy(
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(out_data) + offset),
to_add_data, src.numel() * size_of_t);
}
class GenerateProposalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -77,225 +67,6 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
}
};
template <class T>
static inline void BoxCoder(const platform::DeviceContext &ctx,
Tensor *all_anchors, Tensor *bbox_deltas,
Tensor *variances, Tensor *proposals) {
T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0];
int64_t len = all_anchors->dims()[1];
auto *bbox_deltas_data = bbox_deltas->data<T>();
auto *anchor_data = all_anchors->data<T>();
const T *variances_data = nullptr;
if (variances) {
variances_data = variances->data<T>();
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
if (variances) {
bbox_center_x =
variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
anchor_center_x;
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
}
// return proposals;
}
template <class T>
static inline void ClipTiledBoxes(const platform::DeviceContext &ctx,
const Tensor &im_info, Tensor *boxes) {
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
const T *im_info_data = im_info.data<T>();
T zero(0);
for (int64_t i = 0; i < boxes->numel(); ++i) {
if (i % 4 == 0) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else if (i % 4 == 1) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
} else if (i % 4 == 2) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
}
}
}
template <class T>
static inline void FilterBoxes(const platform::DeviceContext &ctx,
Tensor *boxes, float min_size,
const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
T im_scale = im_info_data[2];
keep->Resize({boxes->dims()[0]});
min_size = std::max(min_size, 1.0f);
int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T ws_origin_scale =
(boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1;
T hs_origin_scale =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws_origin_scale >= min_size && hs_origin_scale >= min_size &&
x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
keep->Resize({keep_len});
}
template <class T>
static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
const std::vector<T> &scores) {
std::vector<std::pair<T, int>> sorted_indices;
sorted_indices.reserve(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices.emplace_back(scores[i], i);
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
[](const std::pair<T, int> &a, const std::pair<T, int> &b) {
return a.first < b.first;
});
return sorted_indices;
}
template <class T>
static inline T BBoxArea(const T *box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = std::max(T(0), inter_xmax - inter_xmin + 1);
const T inter_h = std::max(T(0), inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T>
static inline Tensor VectorToTensor(const std::vector<T> &selected_indices,
int selected_num) {
Tensor keep_nms;
keep_nms.Resize({selected_num});
auto *keep_data = keep_nms.mutable_data<T>(platform::CPUPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
template <class T>
static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
Tensor *scores, T nms_threshold, float eta) {
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox->dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices =
GetSortedScoreIndex<T>(scores_data);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>();
while (sorted_indices.size() != 0) {
int idx = sorted_indices.back().second;
bool flag = true;
for (int kept_idx : selected_indices) {
if (flag) {
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
} else {
break;
}
}
if (flag) {
selected_indices.push_back(idx);
++selected_num;
}
sorted_indices.erase(sorted_indices.end() - 1);
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
return VectorToTensor(selected_indices, selected_num);
}
template <typename T>
class GenerateProposalsKernel : public framework::OpKernel<T> {
public:
......@@ -434,10 +205,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
BoxCoder<T>(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);
ClipTiledBoxes<T>(ctx, im_info_slice, &proposals);
ClipTiledBoxes<T>(ctx, im_info_slice, proposals, &proposals, false);
Tensor keep;
FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep);
FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, true, &keep);
// Handle the case when there is no keep index left
if (keep.numel() == 0) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
......
......@@ -16,13 +16,11 @@ limitations under the License. */
#include <stdio.h>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/detection/bbox_util.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
......@@ -31,258 +29,6 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
namespace {
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int const kThreadsPerBlock = sizeof(uint64_t) * 8;
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
struct RangeInitFunctor {
int start_;
int delta_;
int *out_;
__device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
};
template <typename T>
static void SortDescending(const platform::CUDADeviceContext &ctx,
const Tensor &value, Tensor *value_out,
Tensor *index_out) {
int num = static_cast<int>(value.numel());
Tensor index_in_t;
int *idx_in = index_in_t.mutable_data<int>({num}, ctx.GetPlace());
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, num);
for_range(RangeInitFunctor{0, 1, idx_in});
int *idx_out = index_out->mutable_data<int>({num}, ctx.GetPlace());
const T *keys_in = value.data<T>();
T *keys_out = value_out->mutable_data<T>({num}, ctx.GetPlace());
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>(
nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
// Allocate temporary storage
auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
// Run sorting operation
cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in,
idx_out, num);
}
template <typename T>
struct BoxDecodeAndClipFunctor {
const T *anchor;
const T *deltas;
const T *var;
const int *index;
const T *im_info;
T *proposals;
BoxDecodeAndClipFunctor(const T *anchor, const T *deltas, const T *var,
const int *index, const T *im_info, T *proposals)
: anchor(anchor),
deltas(deltas),
var(var),
index(index),
im_info(im_info),
proposals(proposals) {}
T bbox_clip_default{static_cast<T>(kBBoxClipDefault)};
__device__ void operator()(size_t i) {
int k = index[i] * 4;
T axmin = anchor[k];
T aymin = anchor[k + 1];
T axmax = anchor[k + 2];
T aymax = anchor[k + 3];
T w = axmax - axmin + 1.0;
T h = aymax - aymin + 1.0;
T cx = axmin + 0.5 * w;
T cy = aymin + 0.5 * h;
T dxmin = deltas[k];
T dymin = deltas[k + 1];
T dxmax = deltas[k + 2];
T dymax = deltas[k + 3];
T d_cx, d_cy, d_w, d_h;
if (var) {
d_cx = cx + dxmin * w * var[k];
d_cy = cy + dymin * h * var[k + 1];
d_w = exp(Min(dxmax * var[k + 2], bbox_clip_default)) * w;
d_h = exp(Min(dymax * var[k + 3], bbox_clip_default)) * h;
} else {
d_cx = cx + dxmin * w;
d_cy = cy + dymin * h;
d_w = exp(Min(dxmax, bbox_clip_default)) * w;
d_h = exp(Min(dymax, bbox_clip_default)) * h;
}
T oxmin = d_cx - d_w * 0.5;
T oymin = d_cy - d_h * 0.5;
T oxmax = d_cx + d_w * 0.5 - 1.;
T oymax = d_cy + d_h * 0.5 - 1.;
proposals[i * 4] = Max(Min(oxmin, im_info[1] - 1.), 0.);
proposals[i * 4 + 1] = Max(Min(oymin, im_info[0] - 1.), 0.);
proposals[i * 4 + 2] = Max(Min(oxmax, im_info[1] - 1.), 0.);
proposals[i * 4 + 3] = Max(Min(oymax, im_info[0] - 1.), 0.);
}
__device__ __forceinline__ T Min(T a, T b) const { return a > b ? b : a; }
__device__ __forceinline__ T Max(T a, T b) const { return a > b ? a : b; }
};
template <typename T, int BlockSize>
static __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
const T min_size, const int num,
int *keep_num, int *keep) {
T im_h = im_info[0];
T im_w = im_info[1];
T im_scale = im_info[2];
int cnt = 0;
__shared__ int keep_index[BlockSize];
CUDA_KERNEL_LOOP(i, num) {
keep_index[threadIdx.x] = -1;
__syncthreads();
int k = i * 4;
T xmin = bboxes[k];
T ymin = bboxes[k + 1];
T xmax = bboxes[k + 2];
T ymax = bboxes[k + 3];
T w = xmax - xmin + 1.0;
T h = ymax - ymin + 1.0;
T cx = xmin + w / 2.;
T cy = ymin + h / 2.;
T w_s = (xmax - xmin) / im_scale + 1.;
T h_s = (ymax - ymin) / im_scale + 1.;
if (w_s >= min_size && h_s >= min_size && cx <= im_w && cy <= im_h) {
keep_index[threadIdx.x] = i;
}
__syncthreads();
if (threadIdx.x == 0) {
int size = (num - i) < BlockSize ? num - i : BlockSize;
for (int j = 0; j < size; ++j) {
if (keep_index[j] > -1) {
keep[cnt++] = keep_index[j];
}
}
}
__syncthreads();
}
if (threadIdx.x == 0) {
keep_num[0] = cnt;
}
}
static __device__ inline float IoU(const float *a, const float *b) {
float left = max(a[0], b[0]), right = min(a[2], b[2]);
float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
float inter_s = width * height;
float s_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
float s_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
return inter_s / (s_a + s_b - inter_s);
}
static __global__ void NMSKernel(const int n_boxes,
const float nms_overlap_thresh,
const float *dev_boxes, uint64_t *dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
const int row_size =
min(n_boxes - row_start * kThreadsPerBlock, kThreadsPerBlock);
const int col_size =
min(n_boxes - col_start * kThreadsPerBlock, kThreadsPerBlock);
__shared__ float block_boxes[kThreadsPerBlock * 4];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 4 + 0] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 0];
block_boxes[threadIdx.x * 4 + 1] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 1];
block_boxes[threadIdx.x * 4 + 2] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 2];
block_boxes[threadIdx.x * 4 + 3] =
dev_boxes[(kThreadsPerBlock * col_start + threadIdx.x) * 4 + 3];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = kThreadsPerBlock * row_start + threadIdx.x;
const float *cur_box = dev_boxes + cur_box_idx * 4;
int i = 0;
uint64_t t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (IoU(cur_box, block_boxes + i * 4) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(n_boxes, kThreadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
template <typename T>
static void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
const Tensor &sorted_indices, const T nms_threshold,
Tensor *keep_out) {
int boxes_num = proposals.dims()[0];
const int col_blocks = DIVUP(boxes_num, kThreadsPerBlock);
dim3 blocks(DIVUP(boxes_num, kThreadsPerBlock),
DIVUP(boxes_num, kThreadsPerBlock));
dim3 threads(kThreadsPerBlock);
const T *boxes = proposals.data<T>();
auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
framework::Vector<uint64_t> mask(boxes_num * col_blocks);
NMSKernel<<<blocks, threads>>>(boxes_num, nms_threshold, boxes,
mask.CUDAMutableData(BOOST_GET_CONST(
platform::CUDAPlace, ctx.GetPlace())));
std::vector<uint64_t> remv(col_blocks);
memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);
std::vector<int> keep_vec;
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / kThreadsPerBlock;
int inblock = i % kThreadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
++num_to_keep;
keep_vec.push_back(i);
uint64_t *p = &mask[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace());
memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(),
sizeof(int) * num_to_keep, ctx.stream());
ctx.Wait();
}
template <typename T>
static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::CUDADeviceContext &ctx, const Tensor &im_info,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class GenerateProposalsV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Scores"), true,
platform::errors::NotFound("Input(Scores) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("BboxDeltas"), true,
platform::errors::NotFound("Input(BboxDeltas) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("ImShape"), true,
platform::errors::NotFound("Input(ImShape) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Anchors"), true,
platform::errors::NotFound("Input(Anchors) shouldn't be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Variances"), true,
platform::errors::NotFound("Input(Variances) shouldn't be null."));
ctx->SetOutputDim("RpnRois", {-1, 4});
ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("RpnRois", std::max(ctx->GetLoDLevel("Scores"), 1));
ctx->SetLoDLevel("RpnRoiProbs", std::max(ctx->GetLoDLevel("Scores"), 1));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"),
ctx.device_context());
}
};
template <typename T>
class GenerateProposalsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_shape = context.Input<Tensor>("ImShape");
auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Anchors", "GenerateProposals");
auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
"Input", "Variances", "GenerateProposals");
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
auto &dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
auto &scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto &bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
math::Transpose<platform::CPUDeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
framework::LoD lod;
lod.resize(1);
auto &lod0 = lod[0];
lod0.push_back(0);
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
std::vector<int> tmp_num;
int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
Tensor im_shape_slice = im_shape->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);
bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(dev_ctx, im_shape_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor &proposals = tensor_pair.first;
Tensor &scores = tensor_pair.second;
AppendProposals(rpn_rois, 4 * num_proposals, proposals);
AppendProposals(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0];
lod0.push_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
for (int i = 0; i < num; i++) {
num_data[i] = tmp_num[i];
}
rpn_rois_num->Resize({num});
}
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::CPUDeviceContext &ctx, const Tensor &im_shape_slice,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas_slice, // [M, 4]
const Tensor &scores_slice, // [N, 1]
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta) const {
auto *scores_data = scores_slice.data<T>();
// Sort index
Tensor index_t;
index_t.Resize({scores_slice.numel()});
int *index = index_t.mutable_data<int>(ctx.GetPlace());
for (int i = 0; i < scores_slice.numel(); ++i) {
index[i] = i;
}
auto compare = [scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
std::sort(index, index + scores_slice.numel(), compare);
} else {
std::nth_element(index, index + pre_nms_top_n,
index + scores_slice.numel(), compare);
index_t.Resize({pre_nms_top_n});
}
Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.mutable_data<T>({index_t.numel(), 1}, ctx.GetPlace());
bbox_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
CPUGather<T>(ctx, variances, index_t, &var_sel);
Tensor proposals;
proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
BoxCoder<T>(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);
ClipTiledBoxes<T>(ctx, im_shape_slice, proposals, &proposals, false);
Tensor keep;
FilterBoxes<T>(ctx, &proposals, min_size, im_shape_slice, false, &keep);
// Handle the case when there is no keep index left
if (keep.numel() == 0) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
bbox_sel.mutable_data<T>({1, 4}, ctx.GetPlace());
set_zero(ctx, &bbox_sel, static_cast<T>(0));
Tensor scores_filter;
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(bbox_sel, scores_filter);
}
Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, proposals, keep, &bbox_sel);
CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_filter);
}
Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
}
proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);
return std::make_pair(proposals, scores_sel);
}
};
class GenerateProposalsV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Scores",
"(Tensor) The scores from conv is in shape (N, A, H, W), "
"N is batch size, A is number of anchors, "
"H and W are height and width of the feature map");
AddInput("BboxDeltas",
"(Tensor) Bounding box deltas from conv is in "
"shape (N, 4*A, H, W).");
AddInput("ImShape",
"(Tensor) Image shape in shape (N, 2), "
"in format (height, width)");
AddInput("Anchors",
"(Tensor) Bounding box anchors from anchor_generator_op "
"is in shape (A, H, W, 4).");
AddInput("Variances",
"(Tensor) Bounding box variances with same shape as `Anchors`.");
AddOutput("RpnRois",
"(LoDTensor), Output proposals with shape (rois_num, 4).");
AddOutput("RpnRoiProbs",
"(LoDTensor) Scores of proposals with shape (rois_num, 1).");
AddOutput("RpnRoisNum", "(Tensor), The number of Rpn RoIs in each image")
.AsDispensable();
AddAttr<int>("pre_nms_topN",
"Number of top scoring RPN proposals to keep before "
"applying NMS.");
AddAttr<int>("post_nms_topN",
"Number of top scoring RPN proposals to keep after "
"applying NMS");
AddAttr<float>("nms_thresh", "NMS threshold used on RPN proposals.");
AddAttr<float>("min_size",
"Proposal height and width both need to be greater "
"than this min_size.");
AddAttr<float>("eta", "The parameter for adaptive NMS.");
AddComment(R"DOC(
This operator is the second version of generate_proposals op to generate
bounding box proposals for Faster RCNN.
The proposals are generated for a list of images based on image
score 'Scores', bounding box regression result 'BboxDeltas' as
well as predefined bounding box shapes 'anchors'. Greedy
non-maximum suppression is applied to generate the final bounding
boxes.
The difference between this version and the first version is that the image
scale is no long needed now, so the input requires im_shape instead of im_info.
The change aims to unify the input for all kinds of objective detection
such as YOLO-v3 and Faster R-CNN. As a result, the min_size represents the
size on input image instead of original image which is slightly different
to before and will not effect the result.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
generate_proposals_v2, ops::GenerateProposalsV2Op,
ops::GenerateProposalsV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(generate_proposals_v2,
ops::GenerateProposalsV2Kernel<float>,
ops::GenerateProposalsV2Kernel<double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/bbox_util.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
namespace {
template <typename T>
static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::CUDADeviceContext &ctx, const Tensor &im_shape,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas, // [M, 4]
const Tensor &scores, // [N, 1]
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta) {
// 1. pre nms
Tensor scores_sort, index_sort;
SortDescending<T>(ctx, scores, &scores_sort, &index_sort);
int num = scores.numel();
int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel()
: pre_nms_top_n;
scores_sort.Resize({pre_nms_num, 1});
index_sort.Resize({pre_nms_num, 1});
// 2. box decode and clipping
Tensor proposals;
proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace());
{
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, pre_nms_num);
for_range(BoxDecodeAndClipFunctor<T>{
anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
index_sort.data<int>(), im_shape.data<T>(), proposals.data<T>()});
}
// 3. filter
Tensor keep_index, keep_num_t;
keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace());
keep_num_t.mutable_data<int>({1}, ctx.GetPlace());
min_size = std::max(min_size, 1.0f);
auto stream = ctx.stream();
FilterBBoxes<T, 512><<<1, 512, 0, stream>>>(
proposals.data<T>(), im_shape.data<T>(), min_size, pre_nms_num,
keep_num_t.data<int>(), keep_index.data<int>(), false);
int keep_num;
const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
memory::Copy(platform::CPUPlace(), &keep_num, gpu_place,
keep_num_t.data<int>(), sizeof(int), ctx.stream());
ctx.Wait();
keep_index.Resize({keep_num});
Tensor scores_filter, proposals_filter;
// Handle the case when there is no keep index left
if (keep_num == 0) {
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
proposals_filter.mutable_data<T>({1, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &proposals_filter, static_cast<T>(0));
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(proposals_filter, scores_filter);
}
proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
if (nms_thresh <= 0) {
return std::make_pair(proposals_filter, scores_filter);
}
// 4. nms
Tensor keep_nms;
NMS<T>(ctx, proposals_filter, keep_index, nms_thresh, &keep_nms);
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
}
Tensor scores_nms, proposals_nms;
proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
return std::make_pair(proposals_nms, scores_nms);
}
} // namespace
template <typename DeviceContext, typename T>
class CUDAGenerateProposalsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_shape = context.Input<Tensor>("ImShape");
auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Anchors", "GenerateProposals");
auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
"Input", "Variances", "GenerateProposals");
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
PADDLE_ENFORCE_GE(eta, 1.,
platform::errors::InvalidArgument(
"Not support adaptive NMS. The attribute 'eta' "
"should not less than 1. But received eta=[%d]",
eta));
auto &dev_ctx = context.template device_context<DeviceContext>();
auto scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
math::Transpose<DeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());
T *rpn_rois_data = rpn_rois->data<T>();
T *rpn_roi_probs_data = rpn_roi_probs->data<T>();
auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
auto cpu_place = platform::CPUPlace();
int64_t num_proposals = 0;
std::vector<size_t> offset(1, 0);
std::vector<int> tmp_num;
for (int64_t i = 0; i < num; ++i) {
Tensor im_shape_slice = im_shape->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);
bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> box_score_pair =
ProposalForOneImage<T>(dev_ctx, im_shape_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor &proposals = box_score_pair.first;
Tensor &scores = box_score_pair.second;
memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
proposals.data<T>(), sizeof(T) * proposals.numel(),
dev_ctx.stream());
memory::Copy(place, rpn_roi_probs_data + num_proposals, place,
scores.data<T>(), sizeof(T) * scores.numel(),
dev_ctx.stream());
dev_ctx.Wait();
num_proposals += proposals.dims()[0];
offset.emplace_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num,
dev_ctx.stream());
rpn_rois_num->Resize({num});
}
framework::LoD lod;
lod.emplace_back(offset);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(generate_proposals_v2,
ops::CUDAGenerateProposalsV2Kernel<
paddle::platform::CUDADeviceContext, float>);
......@@ -99,5 +99,74 @@ T PolyIoU(const T* box1, const T* box2, const size_t box_size,
}
}
template <class T>
static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
const std::vector<T>& scores) {
std::vector<std::pair<T, int>> sorted_indices;
sorted_indices.reserve(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices.emplace_back(scores[i], i);
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
[](const std::pair<T, int>& a, const std::pair<T, int>& b) {
return a.first < b.first;
});
return sorted_indices;
}
template <typename T>
static inline framework::Tensor VectorToTensor(
const std::vector<T>& selected_indices, int selected_num) {
framework::Tensor keep_nms;
keep_nms.Resize({selected_num});
auto* keep_data = keep_nms.mutable_data<T>(platform::CPUPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
template <class T>
framework::Tensor NMS(const platform::DeviceContext& ctx,
framework::Tensor* bbox, framework::Tensor* scores,
T nms_threshold, float eta) {
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox->dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices =
GetSortedScoreIndex<T>(scores_data);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox->data<T>();
while (sorted_indices.size() != 0) {
int idx = sorted_indices.back().second;
bool flag = true;
for (int kept_idx : selected_indices) {
if (flag) {
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
} else {
break;
}
}
if (flag) {
selected_indices.push_back(idx);
++selected_num;
}
sorted_indices.erase(sorted_indices.end() - 1);
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
return VectorToTensor(selected_indices, selected_num);
}
} // namespace operators
} // namespace paddle
......@@ -81,6 +81,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
import math
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from test_multiclass_nms_op import nms
from test_anchor_generator_op import anchor_generator_in_python
import copy
from test_generate_proposals_op import clip_tiled_boxes, box_coder, nms
def generate_proposals_v2_in_python(scores, bbox_deltas, im_shape, anchors,
variances, pre_nms_topN, post_nms_topN,
nms_thresh, min_size, eta):
all_anchors = anchors.reshape(-1, 4)
rois = np.empty((0, 5), dtype=np.float32)
roi_probs = np.empty((0, 1), dtype=np.float32)
rpn_rois = []
rpn_roi_probs = []
rois_num = []
num_images = scores.shape[0]
for img_idx in range(num_images):
img_i_boxes, img_i_probs = proposal_for_one_image(
im_shape[img_idx, :], all_anchors, variances,
bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta)
rois_num.append(img_i_probs.shape[0])
rpn_rois.append(img_i_boxes)
rpn_roi_probs.append(img_i_probs)
return rpn_rois, rpn_roi_probs, rois_num
def proposal_for_one_image(im_shape, all_anchors, variances, bbox_deltas,
scores, pre_nms_topN, post_nms_topN, nms_thresh,
min_size, eta):
# Transpose and reshape predicted bbox transformations to get them
# into the same order as the anchors:
# - bbox deltas will be (4 * A, H, W) format from conv output
# - transpose to (H, W, 4 * A)
# - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
# in slowest to fastest order to match the enumerated anchors
bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4)
all_anchors = all_anchors.reshape(-1, 4)
variances = variances.reshape(-1, 4)
# Same story for the scores:
# - scores are (A, H, W) format from conv output
# - transpose to (H, W, A)
# - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
# to match the order of anchors and bbox_deltas
scores = scores.transpose((1, 2, 0)).reshape(-1, 1)
# sort all (proposal, score) pairs by score from highest to lowest
# take top pre_nms_topN (e.g. 6000)
if pre_nms_topN <= 0 or pre_nms_topN >= len(scores):
order = np.argsort(-scores.squeeze())
else:
# Avoid sorting possibly large arrays;
# First partition to get top K unsorted
# and then sort just those
inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN]
order = np.argsort(-scores[inds].squeeze())
order = inds[order]
scores = scores[order, :]
bbox_deltas = bbox_deltas[order, :]
all_anchors = all_anchors[order, :]
proposals = box_coder(all_anchors, bbox_deltas, variances)
# clip proposals to image (may result in proposals with zero area
# that will be removed in the next step)
proposals = clip_tiled_boxes(proposals, im_shape)
# remove predicted boxes with height or width < min_size
keep = filter_boxes(proposals, min_size, im_shape)
if len(keep) == 0:
proposals = np.zeros((1, 4)).astype('float32')
scores = np.zeros((1, 1)).astype('float32')
return proposals, scores
proposals = proposals[keep, :]
scores = scores[keep, :]
# apply loose nms (e.g. threshold = 0.7)
# take post_nms_topN (e.g. 1000)
# return the top proposals
if nms_thresh > 0:
keep = nms(boxes=proposals,
scores=scores,
nms_threshold=nms_thresh,
eta=eta)
if post_nms_topN > 0 and post_nms_topN < len(keep):
keep = keep[:post_nms_topN]
proposals = proposals[keep, :]
scores = scores[keep, :]
return proposals, scores
def filter_boxes(boxes, min_size, im_shape):
"""Only keep boxes with both sides >= min_size and center within the image.
"""
# Scale min_size to match image scale
min_size = max(min_size, 1.0)
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_shape[1])
& (y_ctr < im_shape[0]))[0]
return keep
class TestGenerateProposalsV2Op(OpTest):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImShape': self.im_shape.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
}
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "generate_proposals_v2"
self.set_data()
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 3.0
self.eta = 1.
def init_test_input(self):
batch_size = 1
input_channels = 20
layer_h = 16
layer_w = 16
input_feat = np.random.random(
(batch_size, input_channels, layer_h, layer_w)).astype('float32')
self.anchors, self.variances = anchor_generator_in_python(
input_feat=input_feat,
anchor_sizes=[16., 32.],
aspect_ratios=[0.5, 1.0],
variances=[1.0, 1.0, 1.0, 1.0],
stride=[16.0, 16.0],
offset=0.5)
self.im_shape = np.array([[64, 64]]).astype('float32')
num_anchors = self.anchors.shape[2]
self.scores = np.random.random(
(batch_size, num_anchors, layer_h, layer_w)).astype('float32')
self.bbox_deltas = np.random.random(
(batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32')
def init_test_output(self):
self.rpn_rois, self.rpn_roi_probs, self.rois_num = generate_proposals_v2_in_python(
self.scores, self.bbox_deltas, self.im_shape, self.anchors,
self.variances, self.pre_nms_topN, self.post_nms_topN,
self.nms_thresh, self.min_size, self.eta)
class TestGenerateProposalsV2OutLodOp(TestGenerateProposalsV2Op):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImShape': self.im_shape.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta,
'return_rois_num': True
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
'RpnRoisNum': (np.asarray(
self.rois_num, dtype=np.int32))
}
class TestGenerateProposalsV2OpNoBoxLeft(TestGenerateProposalsV2Op):
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 1000.0
self.eta = 1.
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -673,4 +673,5 @@ STATIC_MODE_TESTING_LIST = [
'test_sgd_op_xpu',
'test_shape_op_xpu',
'test_slice_op_xpu',
'test_generate_proposals_v2_op',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册