未验证 提交 02414aac 编写于 作者: Z zhiboniu 提交者: GitHub

Phi matrixnums (#44437)

* phi_matrix_nms

* remove old kernels and add optest check_eager

* reoder args

* reoder args in infermate

* update

* get back legacy dygraph
上级 c770053c
...@@ -11,9 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/detection/nms_util.h" #include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -25,55 +27,6 @@ class MatrixNMSOp : public framework::OperatorWithKernel { ...@@ -25,55 +27,6 @@ class MatrixNMSOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("BBoxes"), "Input", "BBoxes", "MatrixNMS");
OP_INOUT_CHECK(ctx->HasInput("Scores"), "Input", "Scores", "MatrixNMS");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixNMS");
auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores");
auto score_size = score_dims.size();
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(score_size == 3,
true,
platform::errors::InvalidArgument(
"The rank of Input(Scores) must be 3. "
"But received rank = %d.",
score_size));
PADDLE_ENFORCE_EQ(box_dims.size(),
3,
platform::errors::InvalidArgument(
"The rank of Input(BBoxes) must be 3."
"But received rank = %d.",
box_dims.size()));
PADDLE_ENFORCE_EQ(box_dims[2] == 4,
true,
platform::errors::InvalidArgument(
"The last dimension of Input (BBoxes) must be 4, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]."));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[2],
platform::errors::InvalidArgument(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)",
box_dims[1],
score_dims[2]));
}
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
ctx->SetOutputDim("Index", {box_dims[1], 1});
if (ctx->HasOutput("RoisNum")) {
ctx->SetOutputDim("RoisNum", {-1});
}
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1));
ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1));
}
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -83,266 +36,6 @@ class MatrixNMSOp : public framework::OperatorWithKernel { ...@@ -83,266 +36,6 @@ class MatrixNMSOp : public framework::OperatorWithKernel {
} }
}; };
template <typename T, bool gaussian>
struct decay_score;
template <typename T>
struct decay_score<T, true> {
T operator()(T iou, T max_iou, T sigma) {
return std::exp((max_iou * max_iou - iou * iou) * sigma);
}
};
template <typename T>
struct decay_score<T, false> {
T operator()(T iou, T max_iou, T sigma) {
return (1. - iou) / (1. - max_iou);
}
};
template <typename T, bool gaussian>
void NMSMatrix(const Tensor& bbox,
const Tensor& scores,
const T score_threshold,
const T post_threshold,
const float sigma,
const int64_t top_k,
const bool normalized,
std::vector<int>* selected_indices,
std::vector<T>* decayed_scores) {
int64_t num_boxes = bbox.dims()[0];
int64_t box_size = bbox.dims()[1];
auto score_ptr = scores.data<T>();
auto bbox_ptr = bbox.data<T>();
std::vector<int32_t> perm(num_boxes);
std::iota(perm.begin(), perm.end(), 0);
auto end = std::remove_if(
perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) {
return score_ptr[idx] <= score_threshold;
});
auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) {
return score_ptr[lhs] > score_ptr[rhs];
};
int64_t num_pre = std::distance(perm.begin(), end);
if (num_pre <= 0) {
return;
}
if (top_k > -1 && num_pre > top_k) {
num_pre = top_k;
}
std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn);
std::vector<T> iou_matrix((num_pre * (num_pre - 1)) >> 1);
std::vector<T> iou_max(num_pre);
iou_max[0] = 0.;
for (int64_t i = 1; i < num_pre; i++) {
T max_iou = 0.;
auto idx_a = perm[i];
for (int64_t j = 0; j < i; j++) {
auto idx_b = perm[j];
auto iou = JaccardOverlap<T>(
bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized);
max_iou = std::max(max_iou, iou);
iou_matrix[i * (i - 1) / 2 + j] = iou;
}
iou_max[i] = max_iou;
}
if (score_ptr[perm[0]] > post_threshold) {
selected_indices->push_back(perm[0]);
decayed_scores->push_back(score_ptr[perm[0]]);
}
decay_score<T, gaussian> decay_fn;
for (int64_t i = 1; i < num_pre; i++) {
T min_decay = 1.;
for (int64_t j = 0; j < i; j++) {
auto max_iou = iou_max[j];
auto iou = iou_matrix[i * (i - 1) / 2 + j];
auto decay = decay_fn(iou, max_iou, sigma);
min_decay = std::min(min_decay, decay);
}
auto ds = min_decay * score_ptr[perm[i]];
if (ds <= post_threshold) continue;
selected_indices->push_back(perm[i]);
decayed_scores->push_back(ds);
}
}
template <typename T>
class MatrixNMSKernel : public framework::OpKernel<T> {
public:
size_t MultiClassMatrixNMS(const Tensor& scores,
const Tensor& bboxes,
std::vector<T>* out,
std::vector<int>* indices,
int start,
int64_t background_label,
int64_t nms_top_k,
int64_t keep_top_k,
bool normalized,
T score_threshold,
T post_threshold,
bool use_gaussian,
float gaussian_sigma) const {
std::vector<int> all_indices;
std::vector<T> all_scores;
std::vector<T> all_classes;
all_indices.reserve(scores.numel());
all_scores.reserve(scores.numel());
all_classes.reserve(scores.numel());
size_t num_det = 0;
auto class_num = scores.dims()[0];
Tensor score_slice;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
score_slice = scores.Slice(c, c + 1);
if (use_gaussian) {
NMSMatrix<T, true>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
} else {
NMSMatrix<T, false>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
}
for (size_t i = 0; i < all_indices.size() - num_det; i++) {
all_classes.push_back(static_cast<T>(c));
}
num_det = all_indices.size();
}
if (num_det <= 0) {
return num_det;
}
if (keep_top_k > -1) {
auto k = static_cast<size_t>(keep_top_k);
if (num_det > k) num_det = k;
}
std::vector<int32_t> perm(all_indices.size());
std::iota(perm.begin(), perm.end(), 0);
std::partial_sort(perm.begin(),
perm.begin() + num_det,
perm.end(),
[&all_scores](int lhs, int rhs) {
return all_scores[lhs] > all_scores[rhs];
});
for (size_t i = 0; i < num_det; i++) {
auto p = perm[i];
auto idx = all_indices[p];
auto cls = all_classes[p];
auto score = all_scores[p];
auto bbox = bboxes.data<T>() + idx * bboxes.dims()[1];
(*indices).push_back(start + idx);
(*out).push_back(cls);
(*out).push_back(score);
for (int j = 0; j < bboxes.dims()[1]; j++) {
(*out).push_back(bbox[j]);
}
}
return num_det;
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* boxes = ctx.Input<LoDTensor>("BBoxes");
auto* scores = ctx.Input<LoDTensor>("Scores");
auto* outs = ctx.Output<LoDTensor>("Out");
auto* index = ctx.Output<LoDTensor>("Index");
auto background_label = ctx.Attr<int>("background_label");
auto nms_top_k = ctx.Attr<int>("nms_top_k");
auto keep_top_k = ctx.Attr<int>("keep_top_k");
auto normalized = ctx.Attr<bool>("normalized");
auto score_threshold = ctx.Attr<float>("score_threshold");
auto post_threshold = ctx.Attr<float>("post_threshold");
auto use_gaussian = ctx.Attr<bool>("use_gaussian");
auto gaussian_sigma = ctx.Attr<float>("gaussian_sigma");
auto score_dims = scores->dims();
auto batch_size = score_dims[0];
auto num_boxes = score_dims[2];
auto box_dim = boxes->dims()[2];
auto out_dim = box_dim + 2;
Tensor boxes_slice, scores_slice;
size_t num_out = 0;
std::vector<size_t> offsets = {0};
std::vector<T> detections;
std::vector<int> indices;
std::vector<int> num_per_batch;
detections.reserve(out_dim * num_boxes * batch_size);
indices.reserve(num_boxes * batch_size);
num_per_batch.reserve(batch_size);
for (int i = 0; i < batch_size; ++i) {
scores_slice = scores->Slice(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice = boxes->Slice(i, i + 1);
boxes_slice.Resize({score_dims[2], box_dim});
int start = i * score_dims[2];
num_out = MultiClassMatrixNMS(scores_slice,
boxes_slice,
&detections,
&indices,
start,
background_label,
nms_top_k,
keep_top_k,
normalized,
score_threshold,
post_threshold,
use_gaussian,
gaussian_sigma);
offsets.push_back(offsets.back() + num_out);
num_per_batch.emplace_back(num_out);
}
int64_t num_kept = offsets.back();
if (num_kept == 0) {
outs->mutable_data<T>({0, out_dim}, ctx.GetPlace());
index->mutable_data<int>({0, 1}, ctx.GetPlace());
} else {
outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
index->mutable_data<int>({num_kept, 1}, ctx.GetPlace());
std::copy(detections.begin(), detections.end(), outs->data<T>());
std::copy(indices.begin(), indices.end(), index->data<int>());
}
if (ctx.HasOutput("RoisNum")) {
auto* rois_num = ctx.Output<Tensor>("RoisNum");
rois_num->mutable_data<int>({batch_size}, ctx.GetPlace());
std::copy(
num_per_batch.begin(), num_per_batch.end(), rois_num->data<int>());
}
framework::LoD lod;
lod.emplace_back(offsets);
outs->set_lod(lod);
index->set_lod(lod);
}
};
class MatrixNMSOpMaker : public framework::OpProtoAndCheckerMaker { class MatrixNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -435,16 +128,19 @@ https://arxiv.org/abs/2003.10152 ...@@ -435,16 +128,19 @@ https://arxiv.org/abs/2003.10152
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(matrix_nms,
MatrixNMSInferShapeFunctor,
PD_INFER_META(phi::MatrixNMSInferMeta));
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
matrix_nms, matrix_nms,
ops::MatrixNMSOp, ops::MatrixNMSOp,
ops::MatrixNMSOpMaker, ops::MatrixNMSOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
REGISTER_OP_CPU_KERNEL(matrix_nms, MatrixNMSInferShapeFunctor);
ops::MatrixNMSKernel<float>,
ops::MatrixNMSKernel<double>);
REGISTER_OP_VERSION(matrix_nms) REGISTER_OP_VERSION(matrix_nms)
.AddCheckpoint(R"ROC(Upgrade matrix_nms: add a new output [RoisNum].)ROC", .AddCheckpoint(R"ROC(Upgrade matrix_nms: add a new output [RoisNum].)ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput( paddle::framework::compatible::OpVersionDesc().NewOutput(
......
...@@ -1501,6 +1501,14 @@ ...@@ -1501,6 +1501,14 @@
func : matmul func : matmul
backward : matmul_grad backward : matmul_grad
- api : matrix_nms
args : (Tensor bboxes, Tensor scores, float score_threshold, int nms_top_k, int keep_top_k, float post_threshold=0., bool use_gaussian = false, float gaussian_sigma = 2.0, int background_label = 0, bool normalized = true)
output : Tensor(out), Tensor(index), Tensor(roisnum)
infer_meta :
func : MatrixNMSInferMeta
kernel :
func : matrix_nms
# matrix_power # matrix_power
- api : matrix_power - api : matrix_power
args : (Tensor x, int n) args : (Tensor x, int n)
......
...@@ -1687,6 +1687,64 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x, ...@@ -1687,6 +1687,64 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
out->share_lod(x); out->share_lod(x);
} }
void MatrixNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
float score_threshold,
int nms_top_k,
int keep_top_k,
float post_threshold,
bool use_gaussian,
float gaussian_sigma,
int background_label,
bool normalized,
MetaTensor* out,
MetaTensor* index,
MetaTensor* roisnum,
MetaConfig config) {
auto box_dims = bboxes.dims();
auto score_dims = scores.dims();
auto score_size = score_dims.size();
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
score_size == 3,
true,
errors::InvalidArgument("The rank of Input(Scores) must be 3. "
"But received rank = %d.",
score_size));
PADDLE_ENFORCE_EQ(
box_dims.size(),
3,
errors::InvalidArgument("The rank of Input(BBoxes) must be 3."
"But received rank = %d.",
box_dims.size()));
PADDLE_ENFORCE_EQ(box_dims[2] == 4,
true,
errors::InvalidArgument(
"The last dimension of Input (BBoxes) must be 4, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]."));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[2],
errors::InvalidArgument(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)",
box_dims[1],
score_dims[2]));
}
out->set_dims({box_dims[1], box_dims[2] + 2});
out->set_dtype(bboxes.dtype());
index->set_dims({box_dims[1], 1});
index->set_dtype(phi::DataType::INT32);
if (roisnum != nullptr) {
roisnum->set_dims({-1});
roisnum->set_dtype(phi::DataType::INT32);
}
}
void MatrixRankTolInferMeta(const MetaTensor& x, void MatrixRankTolInferMeta(const MetaTensor& x,
const MetaTensor& atol_tensor, const MetaTensor& atol_tensor,
bool use_default_tol, bool use_default_tol,
......
...@@ -249,6 +249,21 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x, ...@@ -249,6 +249,21 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
int y_num_col_dims, int y_num_col_dims,
MetaTensor* out); MetaTensor* out);
void MatrixNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
float score_threshold,
int nms_top_k,
int keep_top_k,
float post_threshold,
bool use_gaussian,
float gaussian_sigma,
int background_label,
bool normalized,
MetaTensor* out,
MetaTensor* index,
MetaTensor* roisnum,
MetaConfig config = MetaConfig());
void MatrixRankTolInferMeta(const MetaTensor& x, void MatrixRankTolInferMeta(const MetaTensor& x,
const MetaTensor& atol_tensor, const MetaTensor& atol_tensor,
bool use_default_tol, bool use_default_tol,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/matrix_nms_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <class T>
static inline T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static inline T JaccardOverlap(const T* box1,
const T* box2,
const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T, bool gaussian>
struct decay_score;
template <typename T>
struct decay_score<T, true> {
T operator()(T iou, T max_iou, T sigma) {
return std::exp((max_iou * max_iou - iou * iou) * sigma);
}
};
template <typename T>
struct decay_score<T, false> {
T operator()(T iou, T max_iou, T sigma) {
return (1. - iou) / (1. - max_iou);
}
};
template <typename T, bool gaussian>
void NMSMatrix(const DenseTensor& bbox,
const DenseTensor& scores,
const T score_threshold,
const T post_threshold,
const float sigma,
const int64_t top_k,
const bool normalized,
std::vector<int>* selected_indices,
std::vector<T>* decayed_scores) {
int64_t num_boxes = bbox.dims()[0];
int64_t box_size = bbox.dims()[1];
auto score_ptr = scores.data<T>();
auto bbox_ptr = bbox.data<T>();
std::vector<int32_t> perm(num_boxes);
std::iota(perm.begin(), perm.end(), 0);
auto end = std::remove_if(
perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) {
return score_ptr[idx] <= score_threshold;
});
auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) {
return score_ptr[lhs] > score_ptr[rhs];
};
int64_t num_pre = std::distance(perm.begin(), end);
if (num_pre <= 0) {
return;
}
if (top_k > -1 && num_pre > top_k) {
num_pre = top_k;
}
std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn);
std::vector<T> iou_matrix((num_pre * (num_pre - 1)) >> 1);
std::vector<T> iou_max(num_pre);
iou_max[0] = 0.;
for (int64_t i = 1; i < num_pre; i++) {
T max_iou = 0.;
auto idx_a = perm[i];
for (int64_t j = 0; j < i; j++) {
auto idx_b = perm[j];
auto iou = JaccardOverlap<T>(
bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized);
max_iou = std::max(max_iou, iou);
iou_matrix[i * (i - 1) / 2 + j] = iou;
}
iou_max[i] = max_iou;
}
if (score_ptr[perm[0]] > post_threshold) {
selected_indices->push_back(perm[0]);
decayed_scores->push_back(score_ptr[perm[0]]);
}
decay_score<T, gaussian> decay_fn;
for (int64_t i = 1; i < num_pre; i++) {
T min_decay = 1.;
for (int64_t j = 0; j < i; j++) {
auto max_iou = iou_max[j];
auto iou = iou_matrix[i * (i - 1) / 2 + j];
auto decay = decay_fn(iou, max_iou, sigma);
min_decay = std::min(min_decay, decay);
}
auto ds = min_decay * score_ptr[perm[i]];
if (ds <= post_threshold) continue;
selected_indices->push_back(perm[i]);
decayed_scores->push_back(ds);
}
}
template <typename T>
size_t MultiClassMatrixNMS(const DenseTensor& scores,
const DenseTensor& bboxes,
std::vector<T>* out,
std::vector<int>* indices,
int start,
int64_t background_label,
int64_t nms_top_k,
int64_t keep_top_k,
bool normalized,
T score_threshold,
T post_threshold,
bool use_gaussian,
float gaussian_sigma) {
std::vector<int> all_indices;
std::vector<T> all_scores;
std::vector<T> all_classes;
all_indices.reserve(scores.numel());
all_scores.reserve(scores.numel());
all_classes.reserve(scores.numel());
size_t num_det = 0;
auto class_num = scores.dims()[0];
DenseTensor score_slice;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
score_slice = scores.Slice(c, c + 1);
if (use_gaussian) {
NMSMatrix<T, true>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
} else {
NMSMatrix<T, false>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
}
for (size_t i = 0; i < all_indices.size() - num_det; i++) {
all_classes.push_back(static_cast<T>(c));
}
num_det = all_indices.size();
}
if (num_det <= 0) {
return num_det;
}
if (keep_top_k > -1) {
auto k = static_cast<size_t>(keep_top_k);
if (num_det > k) num_det = k;
}
std::vector<int32_t> perm(all_indices.size());
std::iota(perm.begin(), perm.end(), 0);
std::partial_sort(perm.begin(),
perm.begin() + num_det,
perm.end(),
[&all_scores](int lhs, int rhs) {
return all_scores[lhs] > all_scores[rhs];
});
for (size_t i = 0; i < num_det; i++) {
auto p = perm[i];
auto idx = all_indices[p];
auto cls = all_classes[p];
auto score = all_scores[p];
auto bbox = bboxes.data<T>() + idx * bboxes.dims()[1];
(*indices).push_back(start + idx);
(*out).push_back(cls);
(*out).push_back(score);
for (int j = 0; j < bboxes.dims()[1]; j++) {
(*out).push_back(bbox[j]);
}
}
return num_det;
}
template <typename T, typename Context>
void MatrixNMSKernel(const Context& ctx,
const DenseTensor& bboxes,
const DenseTensor& scores,
float score_threshold,
int nms_top_k,
int keep_top_k,
float post_threshold,
bool use_gaussian,
float gaussian_sigma,
int background_label,
bool normalized,
DenseTensor* out,
DenseTensor* index,
DenseTensor* roisnum) {
auto score_dims = scores.dims();
auto batch_size = score_dims[0];
auto num_boxes = score_dims[2];
auto box_dim = bboxes.dims()[2];
auto out_dim = box_dim + 2;
DenseTensor boxes_slice, scores_slice;
size_t num_out = 0;
std::vector<size_t> offsets = {0};
std::vector<T> detections;
std::vector<int> indices;
std::vector<int> num_per_batch;
detections.reserve(out_dim * num_boxes * batch_size);
indices.reserve(num_boxes * batch_size);
num_per_batch.reserve(batch_size);
for (int i = 0; i < batch_size; ++i) {
scores_slice = scores.Slice(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice = bboxes.Slice(i, i + 1);
boxes_slice.Resize({score_dims[2], box_dim});
int start = i * score_dims[2];
num_out = MultiClassMatrixNMS(scores_slice,
boxes_slice,
&detections,
&indices,
start,
background_label,
nms_top_k,
keep_top_k,
normalized,
static_cast<T>(score_threshold),
static_cast<T>(post_threshold),
use_gaussian,
gaussian_sigma);
offsets.push_back(offsets.back() + num_out);
num_per_batch.emplace_back(num_out);
}
int64_t num_kept = offsets.back();
if (num_kept == 0) {
out->Resize(phi::make_ddim({0, out_dim}));
ctx.template Alloc<T>(out);
index->Resize(phi::make_ddim({0, 1}));
ctx.template Alloc<int>(index);
} else {
out->Resize(phi::make_ddim({num_kept, out_dim}));
ctx.template Alloc<T>(out);
index->Resize(phi::make_ddim({num_kept, 1}));
ctx.template Alloc<int>(index);
std::copy(detections.begin(), detections.end(), out->data<T>());
std::copy(indices.begin(), indices.end(), index->data<int>());
}
if (roisnum != nullptr) {
roisnum->Resize(phi::make_ddim({batch_size}));
ctx.template Alloc<int>(roisnum);
std::copy(num_per_batch.begin(), num_per_batch.end(), roisnum->data<int>());
}
}
} // namespace phi
PD_REGISTER_KERNEL(
matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MatrixNMSKernel(const Context& ctx,
const DenseTensor& bboxes,
const DenseTensor& scores,
float score_threshold,
int nms_top_k,
int keep_top_k,
float post_threshold,
bool use_gaussian,
float gaussian_sigma,
int background_label,
bool normalized,
DenseTensor* out,
DenseTensor* index,
DenseTensor* roisnum);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MatrixNMSOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("matrix_nms",
{"BBoxes", "Scores"},
{"score_threshold",
"nms_top_k",
"keep_top_k",
"post_threshold",
"use_gaussian",
"gaussian_sigma",
"background_label",
"normalized"},
{"Out", "Index", "RoisNum"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(matrix_nms, phi::MatrixNMSOpArgumentMapping);
...@@ -3642,6 +3642,16 @@ def matrix_nms(bboxes, ...@@ -3642,6 +3642,16 @@ def matrix_nms(bboxes,
keep_top_k=200, keep_top_k=200,
normalized=False) normalized=False)
""" """
if in_dygraph_mode():
attrs = (score_threshold, nms_top_k, keep_top_k, post_threshold,
use_gaussian, gaussian_sigma, background_label, normalized)
out, index = _C_ops.final_state_matrix_nms(bboxes, scores, *attrs)
if return_index:
return out, index
else:
return out
check_variable_and_dtype(bboxes, 'BBoxes', ['float32', 'float64'], check_variable_and_dtype(bboxes, 'BBoxes', ['float32', 'float64'],
'matrix_nms') 'matrix_nms')
check_variable_and_dtype(scores, 'Scores', ['float32', 'float64'], check_variable_and_dtype(scores, 'Scores', ['float32', 'float64'],
...@@ -3664,13 +3674,13 @@ def matrix_nms(bboxes, ...@@ -3664,13 +3674,13 @@ def matrix_nms(bboxes,
'Scores': scores 'Scores': scores
}, },
attrs={ attrs={
'background_label': background_label,
'score_threshold': score_threshold, 'score_threshold': score_threshold,
'post_threshold': post_threshold, 'post_threshold': post_threshold,
'nms_top_k': nms_top_k, 'nms_top_k': nms_top_k,
'gaussian_sigma': gaussian_sigma,
'use_gaussian': use_gaussian,
'keep_top_k': keep_top_k, 'keep_top_k': keep_top_k,
'use_gaussian': use_gaussian,
'gaussian_sigma': gaussian_sigma,
'background_label': background_label,
'normalized': normalized 'normalized': normalized
}, },
outputs={ outputs={
......
...@@ -22,6 +22,29 @@ from paddle.fluid import Program, program_guard ...@@ -22,6 +22,29 @@ from paddle.fluid import Program, program_guard
import paddle import paddle
def python_matrix_nms(bboxes,
scores,
score_threshold,
nms_top_k,
keep_top_k,
post_threshold,
use_gaussian=False,
gaussian_sigma=2.,
background_label=0,
normalized=True,
return_index=True,
return_rois_num=True):
out, rois_num, index = paddle.vision.ops.matrix_nms(
bboxes, scores, score_threshold, post_threshold, nms_top_k, keep_top_k,
use_gaussian, gaussian_sigma, background_label, normalized,
return_index, return_rois_num)
if not return_index:
index = None
if not return_rois_num:
rois_num = None
return out, index, rois_num
def softmax(x): def softmax(x):
# clip to shiftx, otherwise, when calc loss with # clip to shiftx, otherwise, when calc loss with
# log(exp(shiftx)), may get log(0)=INF # log(exp(shiftx)), may get log(0)=INF
...@@ -167,6 +190,7 @@ class TestMatrixNMSOp(OpTest): ...@@ -167,6 +190,7 @@ class TestMatrixNMSOp(OpTest):
def setUp(self): def setUp(self):
self.set_argument() self.set_argument()
self.python_api = python_matrix_nms
N = 7 N = 7
M = 1200 M = 1200
C = 21 C = 21
...@@ -203,23 +227,23 @@ class TestMatrixNMSOp(OpTest): ...@@ -203,23 +227,23 @@ class TestMatrixNMSOp(OpTest):
self.op_type = 'matrix_nms' self.op_type = 'matrix_nms'
self.inputs = {'BBoxes': boxes, 'Scores': scores} self.inputs = {'BBoxes': boxes, 'Scores': scores}
self.outputs = { self.outputs = {
'Out': (nmsed_outs, [lod]), 'Out': nmsed_outs,
'Index': (index_outs[:, None], [lod]), 'Index': index_outs[:, None],
'RoisNum': np.array(lod).astype('int32') 'RoisNum': np.array(lod).astype('int32')
} }
self.attrs = { self.attrs = {
'background_label': 0, 'score_threshold': score_threshold,
'nms_top_k': nms_top_k, 'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k, 'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'post_threshold': post_threshold, 'post_threshold': post_threshold,
'use_gaussian': use_gaussian, 'use_gaussian': use_gaussian,
'gaussian_sigma': gaussian_sigma, 'gaussian_sigma': gaussian_sigma,
'background_label': 0,
'normalized': True, 'normalized': True,
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class TestMatrixNMSOpNoOutput(TestMatrixNMSOp): class TestMatrixNMSOpNoOutput(TestMatrixNMSOp):
...@@ -265,50 +289,51 @@ class TestMatrixNMSError(unittest.TestCase): ...@@ -265,50 +289,51 @@ class TestMatrixNMSError(unittest.TestCase):
# the bboxes type must be Variable # the bboxes type must be Variable
fluid.layers.matrix_nms(bboxes=boxes_np, fluid.layers.matrix_nms(bboxes=boxes_np,
scores=scores_data, scores=scores_data,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
post_threshold=post_threshold) post_threshold=post_threshold,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k)
paddle.vision.ops.matrix_nms(bboxes=boxes_np, paddle.vision.ops.matrix_nms(bboxes=boxes_np,
scores=scores_data, scores=scores_data,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
post_threshold=post_threshold) post_threshold=post_threshold,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k)
def test_scores_Variable(): def test_scores_Variable():
# the scores type must be Variable # the scores type must be Variable
fluid.layers.matrix_nms(bboxes=boxes_data, fluid.layers.matrix_nms(bboxes=boxes_data,
scores=scores_np, scores=scores_np,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
post_threshold=post_threshold) post_threshold=post_threshold,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k)
paddle.vision.ops.matrix_nms(bboxes=boxes_data, paddle.vision.ops.matrix_nms(bboxes=boxes_data,
scores=scores_np, scores=scores_np,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
post_threshold=post_threshold) post_threshold=post_threshold,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k)
def test_empty(): def test_empty():
# when all score are lower than threshold # when all score are lower than threshold
try: try:
fluid.layers.matrix_nms(bboxes=boxes_data, fluid.layers.matrix_nms(bboxes=boxes_data,
scores=scores_data, scores=scores_data,
score_threshold=score_threshold,
post_threshold=post_threshold,
nms_top_k=nms_top_k, nms_top_k=nms_top_k,
keep_top_k=keep_top_k, keep_top_k=keep_top_k)
score_threshold=10.,
post_threshold=post_threshold)
except Exception as e: except Exception as e:
self.fail(e) self.fail(e)
try: try:
paddle.vision.ops.matrix_nms(bboxes=boxes_data, paddle.vision.ops.matrix_nms(
scores=scores_data, bboxes=boxes_data,
nms_top_k=nms_top_k, scores=scores_data,
keep_top_k=keep_top_k, score_threshold=score_threshold,
score_threshold=10., post_threshold=post_threshold,
post_threshold=post_threshold) nms_top_k=nms_top_k,
keep_top_k=keep_top_k)
except Exception as e: except Exception as e:
self.fail(e) self.fail(e)
...@@ -317,20 +342,20 @@ class TestMatrixNMSError(unittest.TestCase): ...@@ -317,20 +342,20 @@ class TestMatrixNMSError(unittest.TestCase):
try: try:
fluid.layers.matrix_nms(bboxes=boxes_data, fluid.layers.matrix_nms(bboxes=boxes_data,
scores=scores_data, scores=scores_data,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
post_threshold=post_threshold) post_threshold=post_threshold,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k)
except Exception as e: except Exception as e:
self.fail(e) self.fail(e)
try: try:
paddle.vision.ops.matrix_nms( paddle.vision.ops.matrix_nms(
bboxes=boxes_data, bboxes=boxes_data,
scores=scores_data, scores=scores_data,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
post_threshold=post_threshold) post_threshold=post_threshold,
nms_top_k=nms_top_k,
keep_top_k=keep_top_k)
except Exception as e: except Exception as e:
self.fail(e) self.fail(e)
...@@ -340,4 +365,5 @@ class TestMatrixNMSError(unittest.TestCase): ...@@ -340,4 +365,5 @@ class TestMatrixNMSError(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
...@@ -1891,6 +1891,16 @@ def matrix_nms(bboxes, ...@@ -1891,6 +1891,16 @@ def matrix_nms(bboxes,
check_type(background_label, 'background_label', int, 'matrix_nms') check_type(background_label, 'background_label', int, 'matrix_nms')
if in_dygraph_mode(): if in_dygraph_mode():
out, index, rois_num = _C_ops.final_state_matrix_nms(
bboxes, scores, score_threshold, nms_top_k, keep_top_k,
post_threshold, use_gaussian, gaussian_sigma, background_label,
normalized)
if not return_index:
index = None
if not return_rois_num:
rois_num = None
return out, rois_num, index
elif _in_legacy_dygraph():
attrs = ('background_label', background_label, 'score_threshold', attrs = ('background_label', background_label, 'score_threshold',
score_threshold, 'post_threshold', post_threshold, 'nms_top_k', score_threshold, 'post_threshold', post_threshold, 'nms_top_k',
nms_top_k, 'gaussian_sigma', gaussian_sigma, 'use_gaussian', nms_top_k, 'gaussian_sigma', gaussian_sigma, 'use_gaussian',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册