未验证 提交 fcd4cf7b 编写于 作者: Y Yang Zhang 提交者: GitHub

Add `matrix_nms_op` (#25333)

test=release/1.8
上级 d171f373
......@@ -32,6 +32,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc DEPS gpc)
detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc)
detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc)
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/nms_util.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class MatrixNMSOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("BBoxes"), "Input", "BBoxes", "MatrixNMS");
OP_INOUT_CHECK(ctx->HasInput("Scores"), "Input", "Scores", "MatrixNMS");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixNMS");
auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores");
auto score_size = score_dims.size();
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(score_size == 3, true,
platform::errors::InvalidArgument(
"The rank of Input(Scores) must be 3. "
"But received rank = %d.",
score_size));
PADDLE_ENFORCE_EQ(box_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input(BBoxes) must be 3."
"But received rank = %d.",
box_dims.size()));
PADDLE_ENFORCE_EQ(box_dims[2] == 4, true,
platform::errors::InvalidArgument(
"The last dimension of Input (BBoxes) must be 4, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]."));
PADDLE_ENFORCE_EQ(
box_dims[1], score_dims[2],
platform::errors::InvalidArgument(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)",
box_dims[1], score_dims[2]));
}
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
ctx->SetOutputDim("Index", {box_dims[1], 1});
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1));
ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Scores"),
platform::CPUPlace());
}
};
template <typename T, bool gaussian>
struct decay_score;
template <typename T>
struct decay_score<T, true> {
T operator()(T iou, T max_iou, T sigma) {
return std::exp((max_iou * max_iou - iou * iou) * sigma);
}
};
template <typename T>
struct decay_score<T, false> {
T operator()(T iou, T max_iou, T sigma) {
return (1. - iou) / (1. - max_iou);
}
};
template <typename T, bool gaussian>
void NMSMatrix(const Tensor& bbox, const Tensor& scores,
const T score_threshold, const T post_threshold,
const float sigma, const int64_t top_k, const bool normalized,
std::vector<int>* selected_indices,
std::vector<T>* decayed_scores) {
int64_t num_boxes = bbox.dims()[0];
int64_t box_size = bbox.dims()[1];
auto score_ptr = scores.data<T>();
auto bbox_ptr = bbox.data<T>();
std::vector<int32_t> perm(num_boxes);
std::iota(perm.begin(), perm.end(), 0);
auto end = std::remove_if(perm.begin(), perm.end(),
[&score_ptr, score_threshold](int32_t idx) {
return score_ptr[idx] <= score_threshold;
});
auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) {
return score_ptr[lhs] > score_ptr[rhs];
};
int64_t num_pre = std::distance(perm.begin(), end);
if (num_pre <= 0) {
return;
}
if (top_k > -1 && num_pre > top_k) {
num_pre = top_k;
}
std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn);
std::vector<T> iou_matrix((num_pre * (num_pre - 1)) >> 1);
std::vector<T> iou_max(num_pre);
iou_max[0] = 0.;
for (int64_t i = 1; i < num_pre; i++) {
T max_iou = 0.;
auto idx_a = perm[i];
for (int64_t j = 0; j < i; j++) {
auto idx_b = perm[j];
auto iou = JaccardOverlap<T>(bbox_ptr + idx_a * box_size,
bbox_ptr + idx_b * box_size, normalized);
max_iou = std::max(max_iou, iou);
iou_matrix[i * (i - 1) / 2 + j] = iou;
}
iou_max[i] = max_iou;
}
if (score_ptr[perm[0]] > post_threshold) {
selected_indices->push_back(perm[0]);
decayed_scores->push_back(score_ptr[perm[0]]);
}
decay_score<T, gaussian> decay_fn;
for (int64_t i = 1; i < num_pre; i++) {
T min_decay = 1.;
for (int64_t j = 0; j < i; j++) {
auto max_iou = iou_max[j];
auto iou = iou_matrix[i * (i - 1) / 2 + j];
auto decay = decay_fn(iou, max_iou, sigma);
min_decay = std::min(min_decay, decay);
}
auto ds = min_decay * score_ptr[perm[i]];
if (ds <= post_threshold) continue;
selected_indices->push_back(perm[i]);
decayed_scores->push_back(ds);
}
}
template <typename T>
class MatrixNMSKernel : public framework::OpKernel<T> {
public:
size_t MultiClassMatrixNMS(const Tensor& scores, const Tensor& bboxes,
std::vector<T>* out, std::vector<int>* indices,
int start, int64_t background_label,
int64_t nms_top_k, int64_t keep_top_k,
bool normalized, T score_threshold,
T post_threshold, bool use_gaussian,
float gaussian_sigma) const {
std::vector<int> all_indices;
std::vector<T> all_scores;
std::vector<T> all_classes;
all_indices.reserve(scores.numel());
all_scores.reserve(scores.numel());
all_classes.reserve(scores.numel());
size_t num_det = 0;
auto class_num = scores.dims()[0];
Tensor score_slice;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
score_slice = scores.Slice(c, c + 1);
if (use_gaussian) {
NMSMatrix<T, true>(bboxes, score_slice, score_threshold, post_threshold,
gaussian_sigma, nms_top_k, normalized, &all_indices,
&all_scores);
} else {
NMSMatrix<T, false>(bboxes, score_slice, score_threshold,
post_threshold, gaussian_sigma, nms_top_k,
normalized, &all_indices, &all_scores);
}
for (size_t i = 0; i < all_indices.size() - num_det; i++) {
all_classes.push_back(static_cast<T>(c));
}
num_det = all_indices.size();
}
if (num_det <= 0) {
return num_det;
}
if (keep_top_k > -1) {
auto k = static_cast<size_t>(keep_top_k);
if (num_det > k) num_det = k;
}
std::vector<int32_t> perm(all_indices.size());
std::iota(perm.begin(), perm.end(), 0);
std::partial_sort(perm.begin(), perm.begin() + num_det, perm.end(),
[&all_scores](int lhs, int rhs) {
return all_scores[lhs] > all_scores[rhs];
});
for (size_t i = 0; i < num_det; i++) {
auto p = perm[i];
auto idx = all_indices[p];
auto cls = all_classes[p];
auto score = all_scores[p];
auto bbox = bboxes.data<T>() + idx * bboxes.dims()[1];
(*indices).push_back(start + idx);
(*out).push_back(cls);
(*out).push_back(score);
for (int j = 0; j < bboxes.dims()[1]; j++) {
(*out).push_back(bbox[j]);
}
}
return num_det;
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* boxes = ctx.Input<LoDTensor>("BBoxes");
auto* scores = ctx.Input<LoDTensor>("Scores");
auto* outs = ctx.Output<LoDTensor>("Out");
auto* index = ctx.Output<LoDTensor>("Index");
auto background_label = ctx.Attr<int>("background_label");
auto nms_top_k = ctx.Attr<int>("nms_top_k");
auto keep_top_k = ctx.Attr<int>("keep_top_k");
auto normalized = ctx.Attr<bool>("normalized");
auto score_threshold = ctx.Attr<float>("score_threshold");
auto post_threshold = ctx.Attr<float>("post_threshold");
auto use_gaussian = ctx.Attr<bool>("use_gaussian");
auto gaussian_sigma = ctx.Attr<float>("gaussian_sigma");
auto score_dims = scores->dims();
auto batch_size = score_dims[0];
auto num_boxes = score_dims[2];
auto box_dim = boxes->dims()[2];
auto out_dim = box_dim + 2;
Tensor boxes_slice, scores_slice;
size_t num_out = 0;
std::vector<size_t> offsets = {0};
std::vector<T> detections;
std::vector<int> indices;
detections.reserve(out_dim * num_boxes * batch_size);
indices.reserve(num_boxes * batch_size);
for (int i = 0; i < batch_size; ++i) {
scores_slice = scores->Slice(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice = boxes->Slice(i, i + 1);
boxes_slice.Resize({score_dims[2], box_dim});
int start = i * score_dims[2];
num_out = MultiClassMatrixNMS(
scores_slice, boxes_slice, &detections, &indices, start,
background_label, nms_top_k, keep_top_k, normalized, score_threshold,
post_threshold, use_gaussian, gaussian_sigma);
offsets.push_back(offsets.back() + num_out);
}
int64_t num_kept = offsets.back();
if (num_kept == 0) {
outs->mutable_data<T>({0, out_dim}, ctx.GetPlace());
index->mutable_data<int>({0, 1}, ctx.GetPlace());
} else {
outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
index->mutable_data<int>({num_kept, 1}, ctx.GetPlace());
std::copy(detections.begin(), detections.end(), outs->data<T>());
std::copy(indices.begin(), indices.end(), index->data<int>());
}
framework::LoD lod;
lod.emplace_back(offsets);
outs->set_lod(lod);
index->set_lod(lod);
}
};
class MatrixNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("BBoxes",
"(Tensor) A 3-D Tensor with shape "
"[N, M, 4] represents the predicted locations of M bounding boxes"
", N is the batch size. "
"Each bounding box has four coordinate values and the layout is "
"[xmin, ymin, xmax, ymax], when box size equals to 4.");
AddInput("Scores",
"(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
"predicted confidence predictions. N is the batch size, C is the "
"class number, M is number of bounding boxes. For each category "
"there are total M scores which corresponding M bounding boxes. "
" Please note, M is equal to the 2nd dimension of BBoxes. ");
AddAttr<int>(
"background_label",
"(int, default: 0) "
"The index of background label, the background label will be ignored. "
"If set to -1, then all categories will be considered.")
.SetDefault(0);
AddAttr<float>("score_threshold",
"(float) "
"Threshold to filter out bounding boxes with low "
"confidence score.");
AddAttr<float>("post_threshold",
"(float, default 0.) "
"Threshold to filter out bounding boxes with low "
"confidence score AFTER decaying.")
.SetDefault(0.);
AddAttr<int>("nms_top_k",
"(int64_t) "
"Maximum number of detections to be kept according to the "
"confidences after the filtering detections based on "
"score_threshold");
AddAttr<int>("keep_top_k",
"(int64_t) "
"Number of total bboxes to be kept per image after NMS "
"step. -1 means keeping all bboxes after NMS step.");
AddAttr<bool>("normalized",
"(bool, default true) "
"Whether detections are normalized.")
.SetDefault(true);
AddAttr<bool>("use_gaussian",
"(bool, default false) "
"Whether to use Gaussian as decreasing function.")
.SetDefault(false);
AddAttr<float>("gaussian_sigma",
"(float) "
"Sigma for Gaussian decreasing function, only takes effect ",
"when 'use_gaussian' is enabled.")
.SetDefault(2.);
AddOutput("Out",
"(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax]. "
"the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected bbox.");
AddOutput("Index",
"(LoDTensor) A 2-D LoDTensor with shape [No, 1] represents the "
"index of selected bbox. The index is the absolute index cross "
"batches.");
AddComment(R"DOC(
This operator does multi-class matrix non maximum suppression (NMS) on batched
boxes and scores.
In the NMS step, this operator greedily selects a subset of detection bounding
boxes that have high scores larger than score_threshold, if providing this
threshold, then selects the largest nms_top_k confidences scores if nms_top_k
is larger than -1. Then this operator decays boxes score according to the
Matrix NMS scheme.
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
This operator support multi-class and batched inputs. It applying NMS
independently for each class. The outputs is a 2-D LoDTenosr, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bbox for this image.
For more information on Matrix NMS, please refer to:
https://arxiv.org/abs/2003.10152
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
matrix_nms, ops::MatrixNMSOp, ops::MatrixNMSOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(matrix_nms, ops::MatrixNMSKernel<float>,
ops::MatrixNMSKernel<double>);
......@@ -57,6 +57,7 @@ __all__ = [
'box_clip',
'multiclass_nms',
'locality_aware_nms',
'matrix_nms',
'retinanet_detection_output',
'distribute_fpn_proposals',
'box_decoder_and_assign',
......@@ -3387,6 +3388,133 @@ def locality_aware_nms(bboxes,
return output
def matrix_nms(bboxes,
scores,
score_threshold,
post_threshold,
nms_top_k,
keep_top_k,
use_gaussian=False,
gaussian_sigma=2.,
background_label=0,
normalized=True,
return_index=False,
name=None):
"""
**Matrix NMS**
This operator does matrix non maximum suppression (NMS).
First selects a subset of candidate bounding boxes that have higher scores
than score_threshold (if provided), then the top k candidate is selected if
nms_top_k is larger than -1. Score of the remaining candidate are then
decayed according to the Matrix NMS scheme.
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
Args:
bboxes (Variable): A 3-D Tensor with shape [N, M, 4] represents the
predicted locations of M bounding bboxes,
N is the batch size. Each bounding box has four
coordinate values and the layout is
[xmin, ymin, xmax, ymax], when box size equals to 4.
The data type is float32 or float64.
scores (Variable): A 3-D Tensor with shape [N, C, M]
represents the predicted confidence predictions.
N is the batch size, C is the class number, M is
number of bounding boxes. For each category there
are total M scores which corresponding M bounding
boxes. Please note, M is equal to the 2nd dimension
of BBoxes. The data type is float32 or float64.
score_threshold (float): Threshold to filter out bounding boxes with
low confidence score.
post_threshold (float): Threshold to filter out bounding boxes with
low confidence score AFTER decaying.
nms_top_k (int): Maximum number of detections to be kept according to
the confidences after the filtering detections based
on score_threshold.
keep_top_k (int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
use_gaussian (bool): Use Gaussian as the decay function. Default: False
gaussian_sigma (float): Sigma for Gaussian decay function. Default: 2.0
background_label (int): The index of background label, the background
label will be ignored. If set to -1, then all
categories will be considered. Default: 0
normalized (bool): Whether detections are normalized. Default: True
return_index(bool): Whether return selected index. Default: False
name(str): Name of the matrix nms op. Default: None.
Returns:
A tuple with two Variables: (Out, Index) if return_index is True,
otherwise, one Variable(Out) is returned.
Out (Variable): A 2-D LoDTensor with shape [No, 6] containing the
detection results.
Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax]
(After version 1.3, when no boxes detected, the lod is changed
from {0} to {1})
Index (Variable): A 2-D LoDTensor with shape [No, 1] containing the
selected indices, which are absolute values cross batches.
Examples:
.. code-block:: python
import paddle.fluid as fluid
boxes = fluid.data(name='bboxes', shape=[None,81, 4],
dtype='float32', lod_level=1)
scores = fluid.data(name='scores', shape=[None,81],
dtype='float32', lod_level=1)
out = fluid.layers.matrix_nms(bboxes=boxes,
scores=scores,
background_label=0,
score_threshold=0.5,
post_threshold=0.1,
nms_top_k=400,
keep_top_k=200,
normalized=False)
"""
check_variable_and_dtype(bboxes, 'BBoxes', ['float32', 'float64'],
'matrix_nms')
check_variable_and_dtype(scores, 'Scores', ['float32', 'float64'],
'matrix_nms')
check_type(score_threshold, 'score_threshold', float, 'matrix_nms')
check_type(post_threshold, 'post_threshold', float, 'matrix_nms')
check_type(nms_top_k, 'nums_top_k', int, 'matrix_nms')
check_type(keep_top_k, 'keep_top_k', int, 'matrix_nms')
check_type(normalized, 'normalized', bool, 'matrix_nms')
check_type(use_gaussian, 'use_gaussian', bool, 'matrix_nms')
check_type(gaussian_sigma, 'gaussian_sigma', float, 'matrix_nms')
check_type(background_label, 'background_label', int, 'matrix_nms')
helper = LayerHelper('matrix_nms', **locals())
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
index = helper.create_variable_for_type_inference(dtype='int')
helper.append_op(
type="matrix_nms",
inputs={'BBoxes': bboxes,
'Scores': scores},
attrs={
'background_label': background_label,
'score_threshold': score_threshold,
'post_threshold': post_threshold,
'nms_top_k': nms_top_k,
'gaussian_sigma': gaussian_sigma,
'use_gaussian': use_gaussian,
'keep_top_k': keep_top_k,
'normalized': normalized
},
outputs={'Out': output,
'Index': index})
output.stop_gradient = True
if return_index:
return output, index
else:
return output
def distribute_fpn_proposals(fpn_rois,
min_level,
max_level,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import copy
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def softmax(x):
# clip to shiftx, otherwise, when calc loss with
# log(exp(shiftx)), may get log(0)=INF
shiftx = (x - np.max(x)).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
def iou_matrix(a, b, norm=True):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
pad = not norm and 1 or 0
area_i = np.prod(br_i - tl_i + pad, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2] + pad, axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2] + pad, axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def matrix_nms(boxes,
scores,
score_threshold,
post_threshold=0.,
nms_top_k=400,
normalized=True,
use_gaussian=False,
gaussian_sigma=2.):
all_scores = copy.deepcopy(scores)
all_scores = all_scores.flatten()
selected_indices = np.where(all_scores > score_threshold)[0]
all_scores = all_scores[selected_indices]
sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort')
sorted_scores = all_scores[sorted_indices]
sorted_indices = selected_indices[sorted_indices]
if nms_top_k > -1 and nms_top_k < sorted_indices.shape[0]:
sorted_indices = sorted_indices[:nms_top_k]
sorted_scores = sorted_scores[:nms_top_k]
selected_boxes = boxes[sorted_indices, :]
ious = iou_matrix(selected_boxes, selected_boxes)
ious = np.triu(ious, k=1)
iou_cmax = ious.max(0)
N = iou_cmax.shape[0]
iou_cmax = np.repeat(iou_cmax[:, np.newaxis], N, axis=1)
if use_gaussian:
decay = np.exp((iou_cmax**2 - ious**2) * gaussian_sigma)
else:
decay = (1 - ious) / (1 - iou_cmax)
decay = decay.min(0)
decayed_scores = sorted_scores * decay
if post_threshold > 0.:
inds = np.where(decayed_scores > post_threshold)[0]
selected_boxes = selected_boxes[inds, :]
decayed_scores = decayed_scores[inds]
sorted_indices = sorted_indices[inds]
return decayed_scores, selected_boxes, sorted_indices
def multiclass_nms(boxes, scores, background, score_threshold, post_threshold,
nms_top_k, keep_top_k, normalized, use_gaussian,
gaussian_sigma):
all_boxes = []
all_cls = []
all_scores = []
all_indices = []
for c in range(scores.shape[0]):
if c == background:
continue
decayed_scores, selected_boxes, indices = matrix_nms(
boxes, scores[c], score_threshold, post_threshold, nms_top_k,
normalized, use_gaussian, gaussian_sigma)
all_cls.append(np.full(len(decayed_scores), c, decayed_scores.dtype))
all_boxes.append(selected_boxes)
all_scores.append(decayed_scores)
all_indices.append(indices)
all_cls = np.concatenate(all_cls)
all_boxes = np.concatenate(all_boxes)
all_scores = np.concatenate(all_scores)
all_indices = np.concatenate(all_indices)
all_pred = np.concatenate(
(all_cls[:, np.newaxis], all_scores[:, np.newaxis], all_boxes), axis=1)
num_det = len(all_pred)
if num_det == 0:
return all_pred, np.array([], dtype=np.float32)
inds = np.argsort(-all_scores, axis=0, kind='mergesort')
all_pred = all_pred[inds, :]
all_indices = all_indices[inds]
if keep_top_k > -1 and num_det > keep_top_k:
num_det = keep_top_k
all_pred = all_pred[:keep_top_k, :]
all_indices = all_indices[:keep_top_k]
return all_pred, all_indices
def batched_multiclass_nms(boxes,
scores,
background,
score_threshold,
post_threshold,
nms_top_k,
keep_top_k,
normalized=True,
use_gaussian=False,
gaussian_sigma=2.):
batch_size = scores.shape[0]
det_outs = []
index_outs = []
lod = []
for n in range(batch_size):
nmsed_outs, indices = multiclass_nms(
boxes[n], scores[n], background, score_threshold, post_threshold,
nms_top_k, keep_top_k, normalized, use_gaussian, gaussian_sigma)
nmsed_num = len(nmsed_outs)
lod.append(nmsed_num)
if nmsed_num == 0:
continue
indices += n * scores.shape[2]
det_outs.append(nmsed_outs)
index_outs.append(indices)
if det_outs:
det_outs = np.concatenate(det_outs)
index_outs = np.concatenate(index_outs)
return det_outs, index_outs, lod
class TestMatrixNMSOp(OpTest):
def set_argument(self):
self.post_threshold = 0.
self.use_gaussian = False
def setUp(self):
self.set_argument()
N = 7
M = 1200
C = 21
BOX_SIZE = 4
background = 0
nms_top_k = 400
keep_top_k = 200
score_threshold = 0.01
post_threshold = self.post_threshold
use_gaussian = False
if hasattr(self, 'use_gaussian'):
use_gaussian = self.use_gaussian
gaussian_sigma = 2.
scores = np.random.random((N * M, C)).astype('float32')
scores = np.apply_along_axis(softmax, 1, scores)
scores = np.reshape(scores, (N, M, C))
scores = np.transpose(scores, (0, 2, 1))
boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5
det_outs, index_outs, lod = batched_multiclass_nms(
boxes, scores, background, score_threshold, post_threshold,
nms_top_k, keep_top_k, True, use_gaussian, gaussian_sigma)
empty = len(det_outs) == 0
det_outs = np.array([], dtype=np.float32) if empty else det_outs
index_outs = np.array([], dtype=np.float32) if empty else index_outs
nmsed_outs = det_outs.astype('float32')
self.op_type = 'matrix_nms'
self.inputs = {'BBoxes': boxes, 'Scores': scores}
self.outputs = {
'Out': (nmsed_outs, [lod]),
'Index': (index_outs[:, None], [lod])
}
self.attrs = {
'background_label': 0,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'post_threshold': post_threshold,
'use_gaussian': use_gaussian,
'gaussian_sigma': gaussian_sigma,
'normalized': True,
}
def test_check_output(self):
self.check_output()
class TestMatrixNMSOpNoOutput(TestMatrixNMSOp):
def set_argument(self):
self.post_threshold = 2.0
class TestMatrixNMSOpGaussian(TestMatrixNMSOp):
def set_argument(self):
self.post_threshold = 0.
self.use_gaussian = True
class TestMatrixNMSError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
M = 1200
N = 7
C = 21
BOX_SIZE = 4
nms_top_k = 400
keep_top_k = 200
score_threshold = 0.01
post_threshold = 0.
boxes_np = np.random.random((M, C, BOX_SIZE)).astype('float32')
scores = np.random.random((N * M, C)).astype('float32')
scores = np.apply_along_axis(softmax, 1, scores)
scores = np.reshape(scores, (N, M, C))
scores_np = np.transpose(scores, (0, 2, 1))
boxes_data = fluid.data(
name='bboxes', shape=[M, C, BOX_SIZE], dtype='float32')
scores_data = fluid.data(
name='scores', shape=[N, C, M], dtype='float32')
def test_bboxes_Variable():
# the bboxes type must be Variable
fluid.layers.matrix_nms(
bboxes=boxes_np,
scores=scores_data,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold,
post_threshold=post_threshold)
def test_scores_Variable():
# the scores type must be Variable
fluid.layers.matrix_nms(
bboxes=boxes_data,
scores=scores_np,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold,
post_threshold=post_threshold)
def test_empty():
# when all score are lower than threshold
try:
fluid.layers.matrix_nms(
bboxes=boxes_data,
scores=scores_data,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=10.,
post_threshold=post_threshold)
except Exception as e:
self.fail(e)
def test_coverage():
# cover correct workflow
try:
fluid.layers.matrix_nms(
bboxes=boxes_data,
scores=scores_data,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold,
post_threshold=post_threshold)
except Exception as e:
self.fail(e)
self.assertRaises(TypeError, test_bboxes_Variable)
self.assertRaises(TypeError, test_scores_Variable)
test_coverage()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册