“506092146d05aa94d2430c5b3dfb63d699dca858”上不存在“paddle/phi/kernels/nonzero_kernel.h”
未验证 提交 a9919903 编写于 作者: Z zhiboniu 提交者: GitHub

phi_multiclass_nms3 (#44613)

上级 e439d735
...@@ -13,8 +13,10 @@ limitations under the License. */ ...@@ -13,8 +13,10 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/nms_util.h" #include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -609,12 +611,6 @@ class MultiClassNMS3Op : public MultiClassNMS2Op { ...@@ -609,12 +611,6 @@ class MultiClassNMS3Op : public MultiClassNMS2Op {
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap& attrs)
: MultiClassNMS2Op(type, inputs, outputs, attrs) {} : MultiClassNMS2Op(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
MultiClassNMS2Op::InferShape(ctx);
ctx->SetOutputDim("NmsRoisNum", {-1});
}
}; };
class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker { class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
...@@ -633,6 +629,10 @@ class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker { ...@@ -633,6 +629,10 @@ class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(multiclass_nms3,
MultiClassNMSShapeFunctor,
PD_INFER_META(phi::MultiClassNMSInferMeta));
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
multiclass_nms, multiclass_nms,
...@@ -658,7 +658,5 @@ REGISTER_OPERATOR( ...@@ -658,7 +658,5 @@ REGISTER_OPERATOR(
ops::MultiClassNMS3Op, ops::MultiClassNMS3Op,
ops::MultiClassNMS3OpMaker, ops::MultiClassNMS3OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
REGISTER_OP_CPU_KERNEL(multiclass_nms3, MultiClassNMSShapeFunctor);
ops::MultiClassNMSKernel<float>,
ops::MultiClassNMSKernel<double>);
...@@ -1615,6 +1615,15 @@ ...@@ -1615,6 +1615,15 @@
func : multi_dot func : multi_dot
backward : multi_dot_grad backward : multi_dot_grad
- api : multiclass_nms3
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
infer_meta :
func : MultiClassNMSInferMeta
kernel :
func : multiclass_nms3
optional : rois_num
# multinomial # multinomial
- api : multinomial - api : multinomial
args : (Tensor x, int num_samples, bool replacement) args : (Tensor x, int num_samples, bool replacement)
......
...@@ -743,6 +743,99 @@ void LinspaceInferMeta(const MetaTensor& start, ...@@ -743,6 +743,99 @@ void LinspaceInferMeta(const MetaTensor& start,
LinspaceRawInferMeta(start, stop, number, out); LinspaceRawInferMeta(start, stop, number, out);
} }
void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
MetaTensor* out,
MetaTensor* index,
MetaTensor* nms_rois_num,
MetaConfig config) {
auto box_dims = bboxes.dims();
auto score_dims = scores.dims();
auto score_size = score_dims.size();
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
score_size == 2 || score_size == 3,
true,
errors::InvalidArgument("The rank of Input(Scores) must be 2 or 3"
". But received rank = %d",
score_size));
PADDLE_ENFORCE_EQ(
box_dims.size(),
3,
errors::InvalidArgument("The rank of Input(BBoxes) must be 3"
". But received rank = %d",
box_dims.size()));
if (score_size == 3) {
PADDLE_ENFORCE_EQ(box_dims[2] == 4 || box_dims[2] == 8 ||
box_dims[2] == 16 || box_dims[2] == 24 ||
box_dims[2] == 32,
true,
errors::InvalidArgument(
"The last dimension of Input"
"(BBoxes) must be 4 or 8, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax] or "
"4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or "
"8 points: [xi, yi] i= 1,2,...,8 or "
"12 points: [xi, yi] i= 1,2,...,12 or "
"16 points: [xi, yi] i= 1,2,...,16"));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[2],
errors::InvalidArgument(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)",
box_dims[1],
score_dims[2]));
} else {
PADDLE_ENFORCE_EQ(box_dims[2],
4,
errors::InvalidArgument(
"The last dimension of Input"
"(BBoxes) must be 4. But received dimension = %d",
box_dims[2]));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[1],
errors::InvalidArgument(
"The 2nd dimension of Input"
"(BBoxes) must be equal to the 2nd dimension of Input(Scores). "
"But received box dimension = %d, score dimension = %d",
box_dims[1],
score_dims[1]));
}
}
PADDLE_ENFORCE_NE(out,
nullptr,
errors::InvalidArgument(
"The out in MultiClassNMSInferMeta can't be nullptr."));
PADDLE_ENFORCE_NE(
index,
nullptr,
errors::InvalidArgument(
"The index in MultiClassNMSInferMeta can't be nullptr."));
// Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel.
out->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
out->set_dtype(bboxes.dtype());
index->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
index->set_dtype(DataType::INT32);
nms_rois_num->set_dims(phi::make_ddim({-1}));
nms_rois_num->set_dtype(DataType::INT32);
}
void NllLossRawInferMeta(const MetaTensor& input, void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label, const MetaTensor& label,
const MetaTensor& weight, const MetaTensor& weight,
......
...@@ -123,6 +123,21 @@ void LinspaceInferMeta(const MetaTensor& start, ...@@ -123,6 +123,21 @@ void LinspaceInferMeta(const MetaTensor& start,
DataType dtype, DataType dtype,
MetaTensor* out); MetaTensor* out);
void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
MetaTensor* out,
MetaTensor* index,
MetaTensor* nms_rois_num,
MetaConfig config = MetaConfig());
void NllLossRawInferMeta(const MetaTensor& input, void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label, const MetaTensor& label,
const MetaTensor& weight, const MetaTensor& weight,
......
...@@ -80,7 +80,8 @@ set(COMMON_KERNEL_DEPS ...@@ -80,7 +80,8 @@ set(COMMON_KERNEL_DEPS
lod_utils lod_utils
custom_kernel custom_kernel
string_infermeta string_infermeta
utf8proc) utf8proc
gpc)
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final}) copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/multiclass_nms3_kernel.h"
#include "paddle/fluid/operators/detection/gpc.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
using gpc::gpc_free_polygon;
using gpc::gpc_polygon_clip;
template <class T>
class Point_ {
public:
// default constructor
Point_() {}
Point_(T _x, T _y) {}
Point_(const Point_& pt) {}
Point_& operator=(const Point_& pt);
// conversion to another data type
// template<typename _T> operator Point_<_T>() const;
// conversion to the old-style C structures
// operator Vec<T, 2>() const;
// checks whether the point is inside the specified rectangle
// bool inside(const Rect_<T>& r) const;
T x; //!< x coordinate of the point
T y; //!< y coordinate of the point
};
template <class T>
void Array2PointVec(const T* box,
const size_t box_size,
std::vector<Point_<T>>* vec) {
size_t pts_num = box_size / 2;
(*vec).resize(pts_num);
for (size_t i = 0; i < pts_num; i++) {
(*vec).at(i).x = box[2 * i];
(*vec).at(i).y = box[2 * i + 1];
}
}
template <class T>
void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly) {
size_t pts_num = box_size / 2;
(*poly).num_contours = 1;
(*poly).hole = reinterpret_cast<int*>(malloc(sizeof(int)));
(*poly).hole[0] = 0;
(*poly).contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
(*poly).contour->num_vertices = pts_num;
(*poly).contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
(*poly).contour->vertex[i].x = box[2 * i];
(*poly).contour->vertex[i].y = box[2 * i + 1];
}
}
template <class T>
void PointVec2Poly(const std::vector<Point_<T>>& vec, gpc::gpc_polygon* poly) {
int pts_num = vec.size();
(*poly).num_contours = 1;
(*poly).hole = reinterpret_cast<int*>(malloc(sizeof(int)));
(*poly).hole[0] = 0;
(*poly).contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
(*poly).contour->num_vertices = pts_num;
(*poly).contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
(*poly).contour->vertex[i].x = vec[i].x;
(*poly).contour->vertex[i].y = vec[i].y;
}
}
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>* vec) {
int pts_num = contour.num_vertices;
(*vec).resize(pts_num);
for (int i = 0; i < pts_num; i++) {
(*vec).at(i).x = contour.vertex[i].x;
(*vec).at(i).y = contour.vertex[i].y;
}
}
template <class T>
T GetContourArea(const std::vector<Point_<T>>& vec) {
size_t pts_num = vec.size();
if (pts_num < 3) return T(0.);
T area = T(0.);
for (size_t i = 0; i < pts_num; ++i) {
area += vec[i].x * vec[(i + 1) % pts_num].y -
vec[i].y * vec[(i + 1) % pts_num].x;
}
return std::fabs(area / 2.0);
}
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
std::vector<Point_<T>> vec;
Array2PointVec<T>(box, box_size, &vec);
return GetContourArea<T>(vec);
}
template <class T>
T PolyOverlapArea(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized) {
gpc::gpc_polygon poly1;
gpc::gpc_polygon poly2;
Array2Poly<T>(box1, box_size, &poly1);
Array2Poly<T>(box2, box_size, &poly2);
gpc::gpc_polygon respoly;
gpc::gpc_op op = gpc::GPC_INT;
gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly);
T inter_area = T(0.);
int contour_num = respoly.num_contours;
for (int i = 0; i < contour_num; ++i) {
std::vector<Point_<T>> resvec;
Poly2PointVec<T>(respoly.contour[i], &resvec);
// inter_area += std::fabs(cv::contourArea(resvec)) + 0.5f *
// (cv::arcLength(resvec, true));
inter_area += GetContourArea<T>(resvec);
}
gpc::gpc_free_polygon(&poly1);
gpc::gpc_free_polygon(&poly2);
gpc::gpc_free_polygon(&respoly);
return inter_area;
}
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
static inline void GetMaxScoreIndex(
const std::vector<T>& scores,
const T threshold,
int top_k,
std::vector<std::pair<T, int>>* sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(),
sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
template <class T>
static inline T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static inline T JaccardOverlap(const T* box1,
const T* box2,
const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
T PolyIoU(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized) {
T bbox1_area = PolyArea<T>(box1, box_size, normalized);
T bbox2_area = PolyArea<T>(box2, box_size, normalized);
T inter_area = PolyOverlapArea<T>(box1, box2, box_size, normalized);
if (bbox1_area == 0 || bbox2_area == 0 || inter_area == 0) {
// If coordinate values are invalid
// if area size <= 0, return 0.
return T(0.);
} else {
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
inline std::vector<size_t> GetNmsLodFromRoisNum(const DenseTensor* rois_num) {
std::vector<size_t> rois_lod;
auto* rois_num_data = rois_num->data<int>();
rois_lod.push_back(static_cast<size_t>(0));
for (int i = 0; i < rois_num->numel(); ++i) {
rois_lod.push_back(rois_lod.back() + static_cast<size_t>(rois_num_data[i]));
}
return rois_lod;
}
template <typename T, typename Context>
void SliceOneClass(const Context& ctx,
const DenseTensor& items,
const int class_id,
DenseTensor* one_class_item) {
// T* item_data = one_class_item->mutable_data<T>(ctx.GetPlace());
T* item_data = ctx.template Alloc<T>(one_class_item);
const T* items_data = items.data<T>();
const int64_t num_item = items.dims()[0];
const int class_num = items.dims()[1];
if (items.dims().size() == 3) {
int item_size = items.dims()[2];
for (int i = 0; i < num_item; ++i) {
std::memcpy(item_data + i * item_size,
items_data + i * class_num * item_size + class_id * item_size,
sizeof(T) * item_size);
}
} else {
for (int i = 0; i < num_item; ++i) {
item_data[i] = items_data[i * class_num + class_id];
}
}
}
template <typename T>
void NMSFast(const DenseTensor& bbox,
const DenseTensor& scores,
const T score_threshold,
const T nms_threshold,
const T eta,
const int64_t top_k,
std::vector<int>* selected_indices,
const bool normalized) {
// The total boxes for each instance.
int64_t num_boxes = bbox.dims()[0];
// 4: [xmin ymin xmax ymax]
// 8: [x1 y1 x2 y2 x3 y3 x4 y4]
// 16, 24, or 32: [x1 y1 x2 y2 ... xn yn], n = 8, 12 or 16
int64_t box_size = bbox.dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex<T>(scores_data, score_threshold, top_k, &sorted_indices);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox.data<T>();
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = T(0.);
// 4: [xmin ymin xmax ymax]
if (box_size == 4) {
overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size,
normalized);
}
// 8: [x1 y1 x2 y2 x3 y3 x4 y4] or 16, 24, 32
if (box_size == 8 || box_size == 16 || box_size == 24 ||
box_size == 32) {
overlap = PolyIoU<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size,
box_size,
normalized);
}
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <typename T, typename Context>
void MultiClassNMS(const Context& ctx,
const DenseTensor& scores,
const DenseTensor& bboxes,
const int scores_size,
float scorethreshold,
int nms_top_k,
int keep_top_k,
float nmsthreshold,
bool normalized,
float nmseta,
int background_label,
std::map<int, std::vector<int>>* indices,
int* num_nmsed_out) {
T nms_threshold = static_cast<T>(nmsthreshold);
T nms_eta = static_cast<T>(nmseta);
T score_threshold = static_cast<T>(scorethreshold);
int num_det = 0;
int64_t class_num = scores_size == 3 ? scores.dims()[0] : scores.dims()[1];
DenseTensor bbox_slice, score_slice;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
if (scores_size == 3) {
score_slice = scores.Slice(c, c + 1);
bbox_slice = bboxes;
} else {
score_slice.Resize({scores.dims()[0], 1});
bbox_slice.Resize({scores.dims()[0], 4});
SliceOneClass<T, Context>(ctx, scores, c, &score_slice);
SliceOneClass<T, Context>(ctx, bboxes, c, &bbox_slice);
}
NMSFast<T>(bbox_slice,
score_slice,
score_threshold,
nms_threshold,
nms_eta,
nms_top_k,
&((*indices)[c]),
normalized);
if (scores_size == 2) {
std::stable_sort((*indices)[c].begin(), (*indices)[c].end());
}
num_det += (*indices)[c].size();
}
*num_nmsed_out = num_det;
const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) {
const T* sdata;
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) {
int label = it.first;
if (scores_size == 3) {
sdata = scores_data + label * scores.dims()[1];
} else {
score_slice.Resize({scores.dims()[0], 1});
SliceOneClass<T, Context>(ctx, scores, label, &score_slice);
sdata = score_slice.data<T>();
}
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
score_index_pairs.push_back(
std::make_pair(sdata[idx], std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(),
score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
if (scores_size == 2) {
for (const auto& it : new_indices) {
int label = it.first;
std::stable_sort(new_indices[label].begin(), new_indices[label].end());
}
}
new_indices.swap(*indices);
*num_nmsed_out = keep_top_k;
}
}
template <typename T, typename Context>
void MultiClassOutput(const Context& ctx,
const DenseTensor& scores,
const DenseTensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
const int scores_size,
DenseTensor* out,
int* oindices = nullptr,
const int offset = 0) {
int64_t class_num = scores.dims()[1];
int64_t predict_dim = scores.dims()[1];
int64_t box_size = bboxes.dims()[1];
if (scores_size == 2) {
box_size = bboxes.dims()[2];
}
int64_t out_dim = box_size + 2;
auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>();
auto* odata = out->data<T>();
const T* sdata;
DenseTensor bbox;
bbox.Resize({scores.dims()[0], box_size});
int count = 0;
for (const auto& it : selected_indices) {
int label = it.first;
const std::vector<int>& indices = it.second;
if (scores_size == 2) {
SliceOneClass<T, Context>(ctx, bboxes, label, &bbox);
} else {
sdata = scores_data + label * predict_dim;
}
for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j];
odata[count * out_dim] = label; // label
const T* bdata;
if (scores_size == 3) {
bdata = bboxes_data + idx * box_size;
odata[count * out_dim + 1] = sdata[idx]; // score
if (oindices != nullptr) {
oindices[count] = offset + idx;
}
} else {
bdata = bbox.data<T>() + idx * box_size;
odata[count * out_dim + 1] = *(scores_data + idx * class_num + label);
if (oindices != nullptr) {
oindices[count] = offset + idx * class_num + label;
}
}
// xmin, ymin, xmax, ymax or multi-points coordinates
std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
count++;
}
}
}
template <typename T, typename Context>
void MultiClassNMSKernel(const Context& ctx,
const DenseTensor& bboxes,
const DenseTensor& scores,
const paddle::optional<DenseTensor>& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
DenseTensor* out,
DenseTensor* index,
DenseTensor* nms_rois_num) {
bool return_index = index != nullptr;
bool has_roisnum = rois_num.get_ptr() != nullptr;
auto score_dims = scores.dims();
auto score_size = score_dims.size();
std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<size_t> batch_starts = {0};
int64_t batch_size = score_dims[0];
int64_t box_dim = bboxes.dims()[2];
int64_t out_dim = box_dim + 2;
int num_nmsed_out = 0;
DenseTensor boxes_slice, scores_slice;
int n = 0;
if (has_roisnum) {
n = score_size == 3 ? batch_size : rois_num.get_ptr()->numel();
} else {
n = score_size == 3 ? batch_size : bboxes.lod().back().size() - 1;
}
for (int i = 0; i < n; ++i) {
std::map<int, std::vector<int>> indices;
if (score_size == 3) {
scores_slice = scores.Slice(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice = bboxes.Slice(i, i + 1);
boxes_slice.Resize({score_dims[2], box_dim});
} else {
std::vector<size_t> boxes_lod;
if (has_roisnum) {
boxes_lod = GetNmsLodFromRoisNum(rois_num.get_ptr());
} else {
boxes_lod = bboxes.lod().back();
}
if (boxes_lod[i] == boxes_lod[i + 1]) {
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back());
continue;
}
scores_slice = scores.Slice(boxes_lod[i], boxes_lod[i + 1]);
boxes_slice = bboxes.Slice(boxes_lod[i], boxes_lod[i + 1]);
}
MultiClassNMS<T, Context>(ctx,
scores_slice,
boxes_slice,
score_size,
score_threshold,
nms_top_k,
keep_top_k,
nms_threshold,
normalized,
nms_eta,
background_label,
&indices,
&num_nmsed_out);
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
int num_kept = batch_starts.back();
if (num_kept == 0) {
if (return_index) {
out->Resize({0, out_dim});
ctx.template Alloc<T>(out);
index->Resize({0, 1});
ctx.template Alloc<int>(index);
} else {
out->Resize({1, 1});
T* od = ctx.template Alloc<T>(out);
od[0] = -1;
batch_starts = {0, 1};
}
} else {
out->Resize({num_kept, out_dim});
ctx.template Alloc<T>(out);
int offset = 0;
int* oindices = nullptr;
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores_slice = scores.Slice(i, i + 1);
boxes_slice = bboxes.Slice(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice.Resize({score_dims[2], box_dim});
if (return_index) {
offset = i * score_dims[2];
}
} else {
std::vector<size_t> boxes_lod;
if (has_roisnum) {
boxes_lod = GetNmsLodFromRoisNum(rois_num.get_ptr());
} else {
boxes_lod = bboxes.lod().back();
}
if (boxes_lod[i] == boxes_lod[i + 1]) continue;
scores_slice = scores.Slice(boxes_lod[i], boxes_lod[i + 1]);
boxes_slice = bboxes.Slice(boxes_lod[i], boxes_lod[i + 1]);
if (return_index) {
offset = boxes_lod[i] * score_dims[1];
}
}
int64_t s = batch_starts[i];
int64_t e = batch_starts[i + 1];
if (e > s) {
DenseTensor nout = out->Slice(s, e);
if (return_index) {
index->Resize({num_kept, 1});
int* output_idx = ctx.template Alloc<int>(index);
oindices = output_idx + s;
}
MultiClassOutput<T, Context>(ctx,
scores_slice,
boxes_slice,
all_indices[i],
score_dims.size(),
&nout,
oindices,
offset);
}
}
}
if (nms_rois_num != nullptr) {
nms_rois_num->Resize({n});
ctx.template Alloc<int>(nms_rois_num);
int* num_data = nms_rois_num->data<int>();
for (int i = 1; i <= n; i++) {
num_data[i - 1] = batch_starts[i] - batch_starts[i - 1];
}
nms_rois_num->Resize({n});
}
}
} // namespace phi
PD_REGISTER_KERNEL(
multiclass_nms3, CPU, ALL_LAYOUT, phi::MultiClassNMSKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MultiClassNMSKernel(const Context& ctx,
const DenseTensor& bboxes,
const DenseTensor& scores,
const paddle::optional<DenseTensor>& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
DenseTensor* out,
DenseTensor* index,
DenseTensor* nms_rois_num);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MultiClassNMS3OpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multiclass_nms3",
{"BBoxes", "Scores", "RoisNum"},
{"score_threshold",
"nms_top_k",
"keep_top_k",
"nms_threshold",
"normalized",
"nms_eta",
"background_label"},
{"Out", "Index", "NmsRoisNum"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(multiclass_nms3,
phi::MultiClassNMS3OpArgumentMapping);
...@@ -1457,6 +1457,7 @@ class OpTest(unittest.TestCase): ...@@ -1457,6 +1457,7 @@ class OpTest(unittest.TestCase):
# see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
if expect_np.size == 0: if expect_np.size == 0:
self.op_test.assertTrue(actual_np.size == 0) # }}} self.op_test.assertTrue(actual_np.size == 0) # }}}
# print("actual_np, expect_np", actual_np, expect_np)
self._compare_numpy(name, actual_np, expect_np) self._compare_numpy(name, actual_np, expect_np)
if isinstance(expect, tuple): if isinstance(expect, tuple):
self._compare_list(name, actual, expect) self._compare_list(name, actual, expect)
......
...@@ -19,7 +19,81 @@ import copy ...@@ -19,7 +19,81 @@ import copy
from op_test import OpTest from op_test import OpTest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard, in_dygraph_mode, _non_static_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle import _C_ops
def multiclass_nms3(bboxes,
scores,
rois_num=None,
score_threshold=0.3,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.3,
normalized=True,
nms_eta=1.,
background_label=-1,
return_index=True,
return_rois_num=True,
name=None):
helper = LayerHelper('multiclass_nms3', **locals())
if in_dygraph_mode():
attrs = (score_threshold, nms_top_k, keep_top_k, nms_threshold,
normalized, nms_eta, background_label)
output, index, nms_rois_num = _C_ops.final_state_multiclass_nms3(
bboxes, scores, rois_num, *attrs)
if not return_index:
index = None
return output, index, nms_rois_num
elif _non_static_mode():
attrs = ('background_label', background_label, 'score_threshold',
score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold',
nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta,
'normalized', normalized)
output, index, nms_rois_num = _C_ops.multiclass_nms3(
bboxes, scores, rois_num, *attrs)
if not return_index:
index = None
return output, index, nms_rois_num
else:
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
index = helper.create_variable_for_type_inference(dtype='int32')
inputs = {'BBoxes': bboxes, 'Scores': scores}
outputs = {'Out': output, 'Index': index}
if rois_num is not None:
inputs['RoisNum'] = rois_num
if return_rois_num:
nms_rois_num = helper.create_variable_for_type_inference(
dtype='int32')
outputs['NmsRoisNum'] = nms_rois_num
helper.append_op(type="multiclass_nms3",
inputs=inputs,
attrs={
'background_label': background_label,
'score_threshold': score_threshold,
'nms_top_k': nms_top_k,
'nms_threshold': nms_threshold,
'keep_top_k': keep_top_k,
'nms_eta': nms_eta,
'normalized': normalized
},
outputs=outputs)
output.stop_gradient = True
index.stop_gradient = True
if not return_index:
index = None
if not return_rois_num:
nms_rois_num = None
return output, nms_rois_num, index
def softmax(x): def softmax(x):
...@@ -541,7 +615,8 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput): ...@@ -541,7 +615,8 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
'normalized': normalized, 'normalized': normalized,
} }
def test_check_output(self):
def test_check_output(self):
self.check_output() self.check_output()
...@@ -590,6 +665,7 @@ class TestMulticlassNMSError(unittest.TestCase): ...@@ -590,6 +665,7 @@ class TestMulticlassNMSError(unittest.TestCase):
class TestMulticlassNMS3Op(TestMulticlassNMS2Op): class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
def setUp(self): def setUp(self):
self.python_api = multiclass_nms3
self.set_argument() self.set_argument()
N = 7 N = 7
M = 1200 M = 1200
...@@ -623,8 +699,8 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op): ...@@ -623,8 +699,8 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
self.op_type = 'multiclass_nms3' self.op_type = 'multiclass_nms3'
self.inputs = {'BBoxes': boxes, 'Scores': scores} self.inputs = {'BBoxes': boxes, 'Scores': scores}
self.outputs = { self.outputs = {
'Out': (nmsed_outs, [lod]), 'Out': nmsed_outs,
'Index': (index_outs, [lod]), 'Index': index_outs,
'NmsRoisNum': np.array(lod).astype('int32') 'NmsRoisNum': np.array(lod).astype('int32')
} }
self.attrs = { self.attrs = {
...@@ -638,7 +714,7 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op): ...@@ -638,7 +714,7 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_eager=True)
class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op): class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op):
...@@ -649,71 +725,6 @@ class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op): ...@@ -649,71 +725,6 @@ class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op):
self.score_threshold = 2.0 self.score_threshold = 2.0
class TestMulticlassNMS3LoDInput(TestMulticlassNMS2LoDInput):
def setUp(self):
self.set_argument()
M = 1200
C = 21
BOX_SIZE = 4
box_lod = [[1200]]
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
score_threshold = self.score_threshold
normalized = False
scores = np.random.random((M, C)).astype('float32')
scores = np.apply_along_axis(softmax, 1, scores)
boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
boxes[:, :, 0] = boxes[:, :, 0] * 10
boxes[:, :, 1] = boxes[:, :, 1] * 10
boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10
det_outs, lod = lod_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k, box_lod,
normalized)
det_outs = np.array(det_outs)
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
det_outs) else det_outs
self.op_type = 'multiclass_nms3'
self.inputs = {
'BBoxes': (boxes, box_lod),
'Scores': (scores, box_lod),
'RoisNum': np.array(box_lod).astype('int32')
}
self.outputs = {
'Out': (nmsed_outs, [lod]),
'NmsRoisNum': np.array(lod).astype('int32')
}
self.attrs = {
'background_label': 0,
'nms_threshold': nms_threshold,
'nms_top_k': nms_top_k,
'keep_top_k': keep_top_k,
'score_threshold': score_threshold,
'nms_eta': 1.0,
'normalized': normalized,
}
def test_check_output(self):
self.check_output()
class TestMulticlassNMS3LoDNoOutput(TestMulticlassNMS3LoDInput):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册