未验证 提交 e1c4adfd 编写于 作者: Y yiicy 提交者: GitHub

[ARM] multiclass_nms op add index output, test=develop (#2654)

上级 d904c9dd
......@@ -271,7 +271,9 @@ void MultiClassOutput(const Tensor& scores,
const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
const int scores_size,
Tensor* outs) {
Tensor* outs,
int* oindices = nullptr,
const int offset = 0) {
int64_t class_num = scores.dims()[1];
int64_t predict_dim = scores.dims()[1];
int64_t box_size = bboxes.dims()[1];
......@@ -301,9 +303,15 @@ void MultiClassOutput(const Tensor& scores,
if (scores_size == 3) {
bdata = bboxes_data + idx * box_size;
odata[count * out_dim + 1] = sdata[idx]; // score
if (oindices != nullptr) {
oindices[count] = offset + idx;
}
} else {
bdata = bbox.data<T>() + idx * box_size;
odata[count * out_dim + 1] = *(scores_data + idx * class_num + label);
if (oindices != nullptr) {
oindices[count] = offset + idx * class_num + label;
}
}
// xmin, ymin, xmax, ymax or multi-points coordinates
std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
......@@ -317,7 +325,8 @@ void MulticlassNmsCompute::Run() {
auto* boxes = param.bboxes;
auto* scores = param.scores;
auto* outs = param.out;
bool return_index = param.index ? true : false;
auto* index = param.index;
auto score_dims = scores->dims();
auto score_size = score_dims.size();
......@@ -349,36 +358,61 @@ void MulticlassNmsCompute::Run() {
uint64_t num_kept = batch_starts.back();
if (num_kept == 0) {
outs->Resize({1, 1});
float* od = outs->mutable_data<float>();
od[0] = -1;
batch_starts = {0, 1};
if (return_index) {
outs->Resize({0, out_dim});
index->Resize({0, 1});
} else {
outs->Resize({1, 1});
float* od = outs->mutable_data<float>();
od[0] = -1;
batch_starts = {0, 1};
}
} else {
outs->Resize({static_cast<int64_t>(num_kept), out_dim});
int offset = 0;
int* oindices = nullptr;
for (int i = 0; i < n; ++i) {
if (score_size == 3) {
scores_slice = scores->Slice<float>(i, i + 1);
boxes_slice = boxes->Slice<float>(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice.Resize({score_dims[2], box_dim});
if (return_index) {
offset = i * score_dims[2];
}
} else {
auto boxes_lod = boxes->lod().back();
scores_slice = scores->Slice<float>(boxes_lod[i], boxes_lod[i + 1]);
boxes_slice = boxes->Slice<float>(boxes_lod[i], boxes_lod[i + 1]);
if (return_index) {
offset = boxes_lod[i] * score_dims[1];
}
}
int64_t s = static_cast<int64_t>(batch_starts[i]);
int64_t e = static_cast<int64_t>(batch_starts[i + 1]);
if (e > s) {
Tensor out = outs->Slice<float>(s, e);
MultiClassOutput<float>(
scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out);
if (return_index) {
index->Resize({static_cast<int64_t>(num_kept), 1});
int* output_idx = index->mutable_data<int>();
oindices = output_idx + s;
}
MultiClassOutput<float>(scores_slice,
boxes_slice,
all_indices[i],
score_dims.size(),
&out,
oindices,
offset);
}
}
}
LoD lod;
lod.emplace_back(batch_starts);
if (return_index) {
index->set_lod(lod);
}
outs->set_lod(lod);
}
} // namespace host
......@@ -395,4 +429,6 @@ REGISTER_LITE_KERNEL(multiclass_nms,
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Index",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.Finalize();
......@@ -58,6 +58,12 @@ bool MulticlassNmsOpLite::AttachImpl(const cpp::OpDesc& opdesc,
auto bboxes_name = opdesc.Input("BBoxes").front();
auto scores_name = opdesc.Input("Scores").front();
auto out_name = opdesc.Output("Out").front();
std::vector<std::string> output_arg_names = opdesc.OutputArgumentNames();
if (std::find(output_arg_names.begin(), output_arg_names.end(), "Index") !=
output_arg_names.end()) {
auto index_name = opdesc.Output("Index").front();
param_.index = GetMutableVar<lite::Tensor>(scope, index_name);
}
param_.bboxes = GetVar<lite::Tensor>(scope, bboxes_name);
param_.scores = GetVar<lite::Tensor>(scope, scores_name);
param_.out = GetMutableVar<lite::Tensor>(scope, out_name);
......
......@@ -592,6 +592,7 @@ struct MulticlassNmsParam {
const lite::Tensor* bboxes{};
const lite::Tensor* scores{};
lite::Tensor* out{};
lite::Tensor* index{};
int background_label{0};
float score_threshold{};
int nms_top_k{};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册