提交 c3d3bd36 编写于 作者: E eclipsess

add multiclass nms op and test

上级 2004af04
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/kernel/multiclass_nms_kernel.h"
namespace paddle_mobile {
namespace operators {
constexpr int kOutputDim = 6;
constexpr int kBBoxSize = 4;
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
static inline void GetMaxScoreIndex(
const std::vector<T>& scores, const T threshold, int top_k,
std::vector<std::pair<T, int>>* sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
template <class T>
static inline T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static inline T JaccardOverlap(const T* box1, const T* box2,
const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin;
const T inter_h = inter_ymax - inter_ymin;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T>
static inline void NMSFast(const Tensor& bbox, const Tensor& scores,
const T score_threshold, const T nms_threshold,
const T eta, const int64_t top_k,
std::vector<int>* selected_indices) {
// The total boxes for each instance.
int64_t num_boxes = bbox.dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox.dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox.data<T>();
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, true);
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <typename T>
void MultiClassNMS(const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>* indices, int* num_nmsed_out,
const int& background_label, const int& nms_top_k,
const int& keep_top_k, const T& nms_threshold,
const T& nms_eta, const T& score_threshold) {
int64_t class_num = scores.dims()[0];
int64_t predict_dim = scores.dims()[1];
int num_det = 0;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
Tensor score = scores.Slice(c, c + 1);
/// [c] is key
NMSFast<float>(bboxes, score, score_threshold, nms_threshold, nms_eta,
nms_top_k, &((*indices)[c]));
num_det += (*indices)[c].size();
}
*num_nmsed_out = num_det;
const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) {
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) {
int label = it.first;
const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
// PADDLE_ENFORCE_LT(idx, predict_dim);
score_index_pairs.push_back(
std::make_pair(sdata[idx], std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
new_indices.swap(*indices);
*num_nmsed_out = keep_top_k;
}
}
template <typename T>
void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
Tensor* outs) {
int predict_dim = scores.dims()[1];
auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->data<T>();
int count = 0;
for (const auto& it : selected_indices) {
/// one batch
int label = it.first;
const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& indices = it.second;
for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize;
odata[count * kOutputDim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score
// xmin, ymin, xmax, ymax
std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T));
count++;
}
}
}
template <>
void MultiClassNMSKernel<CPU, float>::Compute(
const MultiClassNMSParam& param) const {
const auto* input_bboxes = param.InputBBoxes();
const auto& input_bboxes_dims = input_bboxes->dims();
const auto* input_scores = param.InputScores();
const auto& input_scores_dims = input_scores->dims();
auto* outs = param.Out();
auto background_label = param.BackGroundLabel();
auto nms_top_k = param.NMSTopK();
auto keep_top_k = param.KeepTopK();
auto nms_threshold = param.NMSThreshold();
auto nms_eta = param.NMSEta();
auto score_threshold = param.ScoreThreshold();
int64_t batch_size = input_scores_dims[0];
int64_t class_num = input_scores_dims[1];
int64_t predict_dim = input_scores_dims[2];
int64_t box_dim = input_bboxes_dims[2];
std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<size_t> batch_starts = {0};
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
std::map<int, std::vector<int>> indices;
int num_nmsed_out = 0;
MultiClassNMS<float>(ins_score, ins_boxes, &indices, &num_nmsed_out,
background_label, nms_top_k, keep_top_k, nms_threshold,
nms_eta, score_threshold);
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
int num_kept = batch_starts.back();
if (num_kept == 0) {
float* od = outs->mutable_data<float>({1});
od[0] = -1;
} else {
outs->mutable_data<float>({num_kept, kOutputDim});
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
int64_t s = batch_starts[i];
int64_t e = batch_starts[i + 1];
if (e > s) {
Tensor out = outs->Slice(s, e);
MultiClassOutput<float>(ins_score, ins_boxes, all_indices[i], &out);
}
}
}
// framework::LoD lod;
// lod.emplace_back(batch_starts);
//
// outs->set_lod(lod);
}
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class MultiClassNMSKernel
: public framework::OpKernelBase<DeviceType, MultiClassNMSParam> {
public:
void Compute(const MultiClassNMSParam& param) const;
};
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/multiclass_nms_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void MultiClassNMSOp<Dtype, T>::InferShape() const {
auto input_bboxes_dims = param_.InputBBoxes()->dims();
auto input_scores_dims = param_.InputScores()->dims();
if (input_scores_dims.size() != 3) {
LOG(kLOG_ERROR) << "Input Scores size must be 3";
}
if (input_bboxes_dims[2] != 4) {
LOG(kLOG_ERROR) << "Input BBoxes 2nd dimension must be 4";
}
if (input_bboxes_dims[1] != input_scores_dims[2]) {
LOG(kLOG_ERROR) << "Predict bboxes must be equal";
}
// pre size, will change in Compute.
param_.Out()->Resize(framework::make_ddim({input_bboxes_dims[1], 6}));
}
template class MultiClassNMSOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/multiclass_nms_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class MultiClassNMSOp : public framework::OperatorWithKernel<DeviceType> {
public:
MultiClassNMSOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::MultiClassNMSKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
MultiClassNMSParam param_;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -89,6 +89,16 @@ class OpParam : PaddleMobileObject {
return GetVarValue<T>("TargetBox", inputs, scope);
}
template <typename T>
static T *InputBBoxesFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("BBoxes", inputs, scope);
}
template <typename T>
static T *InputScoresFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Scores", inputs, scope);
}
template <typename T>
static vector<T *> InputMultiFrom(const VariableNameMap &inputs,
const Scope &scope) {
......@@ -527,6 +537,51 @@ class SoftmaxParam : public OpParam {
Tensor *input_x_;
Tensor *out_;
};
class MultiClassNMSParam : public OpParam {
public:
MultiClassNMSParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_bboxes_ = InputBBoxesFrom<Tensor>(inputs, scope);
input_scores_ = InputScoresFrom<Tensor>(inputs, scope);
out_ = OutFrom<Tensor>(outputs, scope);
background_label_ = GetAttr<int>("background_label", attrs);
nms_top_k_ = GetAttr<int>("nms_top_k", attrs);
keep_top_k_ = GetAttr<int>("keep_top_k", attrs);
nms_threshold_ = GetAttr<float>("nms_threshold", attrs);
nms_eta_ = GetAttr<float>("nms_eta", attrs);
score_threshold_ = GetAttr<float>("score_threshold", attrs);
}
const Tensor *InputBBoxes() const { return input_bboxes_; }
const Tensor *InputScores() const { return input_scores_; }
Tensor *Out() const { return out_; }
const int &BackGroundLabel() const { return background_label_; }
const int &NMSTopK() const { return nms_top_k_; }
const int &KeepTopK() const { return keep_top_k_; }
const float &NMSThreshold() const { return nms_threshold_; }
const float &NMSEta() const { return nms_eta_; }
const float &ScoreThreshold() const { return score_threshold_; }
private:
Tensor *input_bboxes_;
Tensor *input_scores_;
Tensor *out_;
int background_label_;
int nms_top_k_;
int keep_top_k_;
float nms_threshold_;
float nms_eta_;
float score_threshold_;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -34,6 +34,10 @@ target_link_libraries(test-priorbox-op paddle-mobile)
ADD_EXECUTABLE(test-boxcoder-op operators/test_box_coder_op.cpp test_helper.h test_include.h)
target_link_libraries(test-boxcoder-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h)
target_link_libraries(test-multiclassnms-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/multiclass_nms_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestMultiClassNMSOp {
public:
explicit TestMultiClassNMSOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (auto block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (auto op : ops) {
if (op->Type() == "multiclass_nms" &&
op->Input("BBoxes")[0] == "box_coder_0.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " BBoxes is : " << op->Input("BBoxes")[0];
DLOG << " Scores is : " << op->Input("Scores")[0];
DLOG << " Out is : " << op->Output("Out")[0];
DLOG << " keep_top_k : "
<< op->GetAttrMap().at("keep_top_k").Get<int>();
DLOG << " background_label : "
<< op->GetAttrMap().at("background_label").Get<int>();
DLOG << " nms_eta : " << op->GetAttrMap().at("nms_eta").Get<float>();
DLOG << " nms_threshold : "
<< op->GetAttrMap().at("nms_threshold").Get<float>();
DLOG << " nms_top_k : "
<< op->GetAttrMap().at("nms_top_k").Get<int>();
DLOG << " score_threshold : "
<< op->GetAttrMap().at("score_threshold").Get<float>();
// DLOG << " variances : " <<
// op->GetAttrMap().at("variances").Get<std::vector<float>>();
// DLOG << " aspect_ratios : " <<
// op->GetAttrMap().at("aspect_ratios").Get<std::vector<float>>();
// DLOG << " min_sizes : " <<
// op->GetAttrMap().at("min_sizes").Get<std::vector<float>>();
// DLOG << " max_sizes : " <<
// op->GetAttrMap().at("max_sizes").Get<std::vector<float>>();
std::shared_ptr<operators::MultiClassNMSOp<Dtype, float>> priorbox =
std::make_shared<operators::MultiClassNMSOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(priorbox);
}
}
}
}
std::shared_ptr<Tensor> predict(const Tensor &t1, const Tensor &t2) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("box_coder_0.tmp_0");
auto tensor_x1 = x1_feed_value->GetMutable<Tensor>();
tensor_x1->ShareDataWith(t1);
Variable *x2_feed_value = scope->Var("transpose_12.tmp_0");
auto tensor_x2 = x2_feed_value->GetMutable<Tensor>();
tensor_x2->ShareDataWith(t2);
Variable *output = scope->Var("detection_output_0.tmp_0");
auto *output_tensor = output->GetMutable<Tensor>();
output_tensor->mutable_data<float>({1917, 6});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict(t1, t2, 0);
return out_tensor;
// return outvars_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict(const Tensor &t1, const Tensor &t2, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestMultiClassNMSOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run MulticlassNMS Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/mobilenet+ssd"));
/// input x (1,3,300,300)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {10, 1917, 4}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
paddle_mobile::framework::Tensor inputx2;
SetupTensor<float>(&inputx2, {10, 21, 1917}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx2_ptr = inputx2.data<float>();
paddle_mobile::framework::TestMultiClassNMSOp<paddle_mobile::CPU>
testMultiClassNMSOp(program);
auto output = testMultiClassNMSOp.predict(inputx1, inputx2);
auto *output_ptr = output->data<float>();
for (int i = 0; i < output->numel(); i++) {
DLOG << output_ptr[i];
}
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册