提交 f3415ec5 编写于 作者: D dangqingqing

Follow comments.

上级 53788640
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -28,12 +28,18 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ...@@ -28,12 +28,18 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DistMat"), PADDLE_ENFORCE(ctx->HasInput("DistMat"),
"Input(DistMat) of BipartiteMatch should not be null."); "Input(DistMat) of BipartiteMatch should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColToRowMatchIndices"),
"Output(ColToRowMatchIndices) of BipartiteMatch should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColToRowMatchDist"),
"Output(ColToRowMatchDist) of BipartiteMatch should not be null.");
auto dims = ctx->GetInputDim("DistMat"); auto dims = ctx->GetInputDim("DistMat");
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2."); PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");
ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchIndices", dims);
ctx->SetOutputDim("ColToRowMatchDis", dims); ctx->SetOutputDim("ColToRowMatchDist", dims);
} }
}; };
...@@ -91,7 +97,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> { ...@@ -91,7 +97,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dist_mat = context.Input<LoDTensor>("DistMat"); auto* dist_mat = context.Input<LoDTensor>("DistMat");
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices"); auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis"); auto* match_dist = context.Output<Tensor>("ColToRowMatchDist");
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>(); auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
...@@ -148,13 +154,13 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -148,13 +154,13 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
"Otherwise, it means B[j] is matched to row " "Otherwise, it means B[j] is matched to row "
"ColToRowMatchIndices[i][j] in i-th instance. The row number of " "ColToRowMatchIndices[i][j] in i-th instance. The row number of "
"i-th instance is saved in ColToRowMatchIndices[i][j]."); "i-th instance is saved in ColToRowMatchIndices[i][j].");
AddOutput("ColToRowMatchDis", AddOutput("ColToRowMatchDist",
"(Tensor) A 2-D Tensor with shape [N, M] in float type. " "(Tensor) A 2-D Tensor with shape [N, M] in float type. "
"N is batch size. If ColToRowMatchIndices[i][j] is -1, " "N is batch size. If ColToRowMatchIndices[i][j] is -1, "
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " "ColToRowMatchDist[i][j] is also -1.0. Otherwise, assumed "
"ColToRowMatchIndices[i][j] = d, and the row offsets of each " "ColToRowMatchIndices[i][j] = d, and the row offsets of each "
"instance are called LoD. Then " "instance are called LoD. Then "
"ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]"); "ColToRowMatchDist[i][j] = DistMat[d+LoD[i]][j]");
AddComment(R"DOC( AddComment(R"DOC(
This operator is a greedy bipartite matching algorithm, which is used to This operator is a greedy bipartite matching algorithm, which is used to
obtain the matching with the maximum distance based on the input obtain the matching with the maximum distance based on the input
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -24,25 +24,33 @@ using LoDTensor = framework::LoDTensor; ...@@ -24,25 +24,33 @@ using LoDTensor = framework::LoDTensor;
constexpr int64_t kOutputDim = 6; constexpr int64_t kOutputDim = 6;
constexpr int64_t kBBoxSize = 4; constexpr int64_t kBBoxSize = 4;
class MulticlassNMSOp : public framework::OperatorWithKernel { class MultiClassNMSOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Bboxes"), PADDLE_ENFORCE(ctx->HasInput("BBoxes"),
"Input(Bboxes) of MulticlassNMS should not be null."); "Input(BBoxes) of MultiClassNMS should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scores"), PADDLE_ENFORCE(ctx->HasInput("Scores"),
"Input(Scores) of MulticlassNMS should not be null."); "Input(Scores) of MultiClassNMS should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MultiClassNMS should not be null.");
auto box_dims = ctx->GetInputDim("Bboxes"); auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores"); auto score_dims = ctx->GetInputDim("Scores");
PADDLE_ENFORCE_EQ(box_dims.size(), 2, PADDLE_ENFORCE_EQ(box_dims.size(), 2,
"The rank of Input(Bboxes) must be 3."); "The rank of Input(BBoxes) must be 2.");
PADDLE_ENFORCE_EQ(score_dims.size(), 3, PADDLE_ENFORCE_EQ(score_dims.size(), 3,
"The rank of Input(Scores) must be 3."); "The rank of Input(Scores) must be 3.");
PADDLE_ENFORCE_EQ(box_dims[1], 4); PADDLE_ENFORCE_EQ(box_dims[1], 4,
PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]); "The 2nd dimension of Input(BBoxes) must be 4, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]");
PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2],
"The 1st dimensiong of Input(BBoxes) must be equal to "
"3rd dimension of Input(Scores), which represents the "
"predicted bboxes.");
// Here the box_dims[0] is not the real dimension of output. // Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel. // It will be rewritten in the computing kernel.
...@@ -86,15 +94,16 @@ static inline void GetMaxScoreIndex( ...@@ -86,15 +94,16 @@ static inline void GetMaxScoreIndex(
template <class T> template <class T>
T BBoxArea(const T* box, const bool normalized) { T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) { if (box[2] < box[0] || box[3] < box[1]) {
// If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. // If coordinate values are is invalid
return T(0.); // (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else { } else {
const T w = box[2] - box[0]; const T w = box[2] - box[0];
const T h = box[3] - box[1]; const T h = box[3] - box[1];
if (normalized) { if (normalized) {
return w * h; return w * h;
} else { } else {
// If bbox is not within range [0, 1]. // If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1); return (w + 1) * (h + 1);
} }
} }
...@@ -121,7 +130,7 @@ static inline T JaccardOverlap(const T* box1, const T* box2, ...@@ -121,7 +130,7 @@ static inline T JaccardOverlap(const T* box1, const T* box2,
} }
template <typename T> template <typename T>
class MulticlassNMSKernel : public framework::OpKernel<T> { class MultiClassNMSKernel : public framework::OpKernel<T> {
public: public:
void NMSFast(const Tensor& bbox, const Tensor& scores, void NMSFast(const Tensor& bbox, const Tensor& scores,
const T score_threshold, const T nms_threshold, const T eta, const T score_threshold, const T nms_threshold, const T eta,
...@@ -163,10 +172,10 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -163,10 +172,10 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
} }
} }
void MulticlassNMS(const framework::ExecutionContext& ctx, void MultiClassNMS(const framework::ExecutionContext& ctx,
const Tensor& scores, const Tensor& bboxes, const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>* indices, std::map<int, std::vector<int>>& indices,
int* num_nmsed_out) const { int& num_nmsed_out) const {
int64_t background_label = ctx.Attr<int>("background_label"); int64_t background_label = ctx.Attr<int>("background_label");
int64_t nms_top_k = ctx.Attr<int>("nms_top_k"); int64_t nms_top_k = ctx.Attr<int>("nms_top_k");
int64_t keep_top_k = ctx.Attr<int>("keep_top_k"); int64_t keep_top_k = ctx.Attr<int>("keep_top_k");
...@@ -181,15 +190,15 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -181,15 +190,15 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
if (c == background_label) continue; if (c == background_label) continue;
Tensor score = scores.Slice(c, c + 1); Tensor score = scores.Slice(c, c + 1);
NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k,
&((*indices)[c])); &(indices[c]));
num_det += (*indices)[c].size(); num_det += indices[c].size();
} }
*num_nmsed_out = num_det; num_nmsed_out = num_det;
const T* scores_data = scores.data<T>(); const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) { if (keep_top_k > -1 && num_det > keep_top_k) {
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs; std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) { for (const auto& it : indices) {
int label = it.first; int label = it.first;
const T* sdata = scores_data + label * predict_dim; const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& label_indices = it.second; const std::vector<int>& label_indices = it.second;
...@@ -212,12 +221,12 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -212,12 +221,12 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
int idx = score_index_pairs[j].second.second; int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx); new_indices[label].push_back(idx);
} }
new_indices.swap(*indices); new_indices.swap(indices);
*num_nmsed_out = keep_top_k; num_nmsed_out = keep_top_k;
} }
} }
void MulticlassOutput(const Tensor& scores, const Tensor& bboxes, void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>& selected_indices, std::map<int, std::vector<int>>& selected_indices,
Tensor* outs) const { Tensor* outs) const {
int predict_dim = scores.dims()[1]; int predict_dim = scores.dims()[1];
...@@ -229,23 +238,21 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -229,23 +238,21 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
for (const auto& it : selected_indices) { for (const auto& it : selected_indices) {
int label = it.first; int label = it.first;
const T* sdata = scores_data + label * predict_dim; const T* sdata = scores_data + label * predict_dim;
std::vector<int> indices = it.second; const std::vector<int>& indices = it.second;
for (int j = 0; j < indices.size(); ++j) { for (int j = 0; j < indices.size(); ++j) {
int idx = indices[j]; int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize; const T* bdata = bboxes_data + idx * kBBoxSize;
odata[count * kOutputDim] = label; // label odata[count * kOutputDim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score odata[count * kOutputDim + 1] = sdata[idx]; // score
odata[count * kOutputDim + 2] = bdata[0]; // xmin // xmin, ymin, xmax, ymax
odata[count * kOutputDim + 3] = bdata[1]; // ymin std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T));
odata[count * kOutputDim + 4] = bdata[2]; // xmax
odata[count * kOutputDim + 5] = bdata[3]; // ymax
count++; count++;
} }
} }
} }
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* boxes = ctx.Input<Tensor>("Bboxes"); auto* boxes = ctx.Input<Tensor>("BBoxes");
auto* scores = ctx.Input<Tensor>("Scores"); auto* scores = ctx.Input<Tensor>("Scores");
auto* outs = ctx.Output<LoDTensor>("Out"); auto* outs = ctx.Output<LoDTensor>("Out");
...@@ -262,7 +269,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -262,7 +269,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
ins_score.Resize({class_num, predict_dim}); ins_score.Resize({class_num, predict_dim});
std::map<int, std::vector<int>> indices; std::map<int, std::vector<int>> indices;
int num_nmsed_out = 0; int num_nmsed_out = 0;
MulticlassNMS(ctx, ins_score, *boxes, &indices, &num_nmsed_out); MultiClassNMS(ctx, ins_score, *boxes, indices, num_nmsed_out);
all_indices.push_back(indices); all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out); batch_starts.push_back(batch_starts.back() + num_nmsed_out);
} }
...@@ -280,7 +287,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -280,7 +287,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
int64_t e = batch_starts[i + 1]; int64_t e = batch_starts[i + 1];
if (e > s) { if (e > s) {
Tensor out = outs->Slice(s, e); Tensor out = outs->Slice(s, e);
MulticlassOutput(ins_score, *boxes, all_indices[i], &out); MultiClassOutput(ins_score, *boxes, all_indices[i], &out);
} }
} }
} }
...@@ -292,28 +299,31 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -292,28 +299,31 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
} }
}; };
class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulticlassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Bboxes", AddInput("BBoxes",
"(Tensor) A 2-D Tensor with shape [M, 4] represents the location " "(Tensor) A 2-D Tensor with shape [M, 4] represents the "
"predictions with M bboxes. 4 is the number of " "predicted locations of M bounding bboxes. Each bounding box "
"each location coordinates."); "has four coordinate values and the layout is "
"[xmin, ymin, xmax, ymax].");
AddInput("Scores", AddInput("Scores",
"(Tensor) A 3-D Tensor with shape [N, C, M] represents the " "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
"confidence predictions. N is the batch size, C is the class " "predicted confidence predictions. N is the batch size, C is the "
"number, M is number of predictions for each class, which is " "class number, M is number of bounding boxes. For each category "
"the same with Bboxes."); "there are total M scores which corresponding M bounding boxes. "
" Please note, M is equal to the 1st dimension of BBoxes. ");
AddAttr<int>( AddAttr<int>(
"background_label", "background_label",
"(int64_t, defalut: 0) " "(int64_t, defalut: 0) "
"The index of background label, the background label will be ignored.") "The index of background label, the background label will be ignored. "
"If set to -1, then all categories will be considered.")
.SetDefault(0); .SetDefault(0);
AddAttr<float>("score_threshold", AddAttr<float>("score_threshold",
"(float) " "(float) "
"Only consider detections whose confidences are larger than " "Threshold to filter out bounding boxes with low "
"a threshold. If not provided, consider all boxes."); "confidence score. If not provided, consider all boxes.");
AddAttr<int>("nms_top_k", AddAttr<int>("nms_top_k",
"(int64_t) " "(int64_t) "
"Maximum number of detections to be kept according to the " "Maximum number of detections to be kept according to the "
...@@ -368,8 +378,8 @@ value which is -1. ...@@ -368,8 +378,8 @@ value which is -1.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(multiclass_nms, ops::MulticlassNMSOp, REGISTER_OPERATOR(multiclass_nms, ops::MultiClassNMSOp,
ops::MulticlassNMSOpMaker, ops::MultiClassNMSOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MulticlassNMSKernel<float>, REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MultiClassNMSKernel<float>,
ops::MulticlassNMSKernel<double>); ops::MultiClassNMSKernel<double>);
...@@ -72,7 +72,7 @@ class TestBipartiteMatchOpWithLoD(OpTest): ...@@ -72,7 +72,7 @@ class TestBipartiteMatchOpWithLoD(OpTest):
self.inputs = {'DistMat': (dist, lod)} self.inputs = {'DistMat': (dist, lod)}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': (match_indices), 'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dist), 'ColToRowMatchDist': (match_dist),
} }
def test_check_output(self): def test_check_output(self):
...@@ -89,7 +89,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): ...@@ -89,7 +89,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
self.inputs = {'DistMat': dist} self.inputs = {'DistMat': dist}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': match_indices, 'ColToRowMatchIndices': match_indices,
'ColToRowMatchDis': match_dist, 'ColToRowMatchDist': match_dist,
} }
def test_check_output(self): def test_check_output(self):
......
...@@ -190,7 +190,7 @@ class TestMulticlassNMSOp(OpTest): ...@@ -190,7 +190,7 @@ class TestMulticlassNMSOp(OpTest):
nmsed_outs = np.array(nmsed_outs).astype('float32') nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'multiclass_nms' self.op_type = 'multiclass_nms'
self.inputs = {'Bboxes': boxes, 'Scores': scores} self.inputs = {'BBoxes': boxes, 'Scores': scores}
self.outputs = {'Out': (nmsed_outs, [lod])} self.outputs = {'Out': (nmsed_outs, [lod])}
self.attrs = { self.attrs = {
'background_label': 0, 'background_label': 0,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册