diff --git a/lite/kernels/host/multiclass_nms_compute.cc b/lite/kernels/host/multiclass_nms_compute.cc index 6f6079ef88fd9e61dbacb35c0ca8bdac536288a9..9cbc798d46ecb3cf98159e9b4762c8692ec8c1eb 100644 --- a/lite/kernels/host/multiclass_nms_compute.cc +++ b/lite/kernels/host/multiclass_nms_compute.cc @@ -271,7 +271,9 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, const std::map>& selected_indices, const int scores_size, - Tensor* outs) { + Tensor* outs, + int* oindices = nullptr, + const int offset = 0) { int64_t class_num = scores.dims()[1]; int64_t predict_dim = scores.dims()[1]; int64_t box_size = bboxes.dims()[1]; @@ -301,9 +303,15 @@ void MultiClassOutput(const Tensor& scores, if (scores_size == 3) { bdata = bboxes_data + idx * box_size; odata[count * out_dim + 1] = sdata[idx]; // score + if (oindices != nullptr) { + oindices[count] = offset + idx; + } } else { bdata = bbox.data() + idx * box_size; odata[count * out_dim + 1] = *(scores_data + idx * class_num + label); + if (oindices != nullptr) { + oindices[count] = offset + idx * class_num + label; + } } // xmin, ymin, xmax, ymax or multi-points coordinates std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T)); @@ -317,7 +325,8 @@ void MulticlassNmsCompute::Run() { auto* boxes = param.bboxes; auto* scores = param.scores; auto* outs = param.out; - + bool return_index = param.index ? true : false; + auto* index = param.index; auto score_dims = scores->dims(); auto score_size = score_dims.size(); @@ -349,36 +358,61 @@ void MulticlassNmsCompute::Run() { uint64_t num_kept = batch_starts.back(); if (num_kept == 0) { - outs->Resize({1, 1}); - float* od = outs->mutable_data(); - od[0] = -1; - batch_starts = {0, 1}; + if (return_index) { + outs->Resize({0, out_dim}); + index->Resize({0, 1}); + } else { + outs->Resize({1, 1}); + float* od = outs->mutable_data(); + od[0] = -1; + batch_starts = {0, 1}; + } } else { outs->Resize({static_cast(num_kept), out_dim}); + int offset = 0; + int* oindices = nullptr; for (int i = 0; i < n; ++i) { if (score_size == 3) { scores_slice = scores->Slice(i, i + 1); boxes_slice = boxes->Slice(i, i + 1); scores_slice.Resize({score_dims[1], score_dims[2]}); boxes_slice.Resize({score_dims[2], box_dim}); + if (return_index) { + offset = i * score_dims[2]; + } } else { auto boxes_lod = boxes->lod().back(); scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]); boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]); + if (return_index) { + offset = boxes_lod[i] * score_dims[1]; + } } int64_t s = static_cast(batch_starts[i]); int64_t e = static_cast(batch_starts[i + 1]); if (e > s) { Tensor out = outs->Slice(s, e); - MultiClassOutput( - scores_slice, boxes_slice, all_indices[i], score_dims.size(), &out); + if (return_index) { + index->Resize({static_cast(num_kept), 1}); + int* output_idx = index->mutable_data(); + oindices = output_idx + s; + } + MultiClassOutput(scores_slice, + boxes_slice, + all_indices[i], + score_dims.size(), + &out, + oindices, + offset); } } } LoD lod; lod.emplace_back(batch_starts); - + if (return_index) { + index->set_lod(lod); + } outs->set_lod(lod); } } // namespace host @@ -395,4 +429,6 @@ REGISTER_LITE_KERNEL(multiclass_nms, .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))}) .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Index", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) .Finalize(); diff --git a/lite/operators/multiclass_nms_op.cc b/lite/operators/multiclass_nms_op.cc index b9b0db5ccac6ad4561f2bf71ddf5faed98c40a61..9dba5de4f81a1cba8f66132d89f6321ed76d368c 100644 --- a/lite/operators/multiclass_nms_op.cc +++ b/lite/operators/multiclass_nms_op.cc @@ -58,6 +58,12 @@ bool MulticlassNmsOpLite::AttachImpl(const cpp::OpDesc& opdesc, auto bboxes_name = opdesc.Input("BBoxes").front(); auto scores_name = opdesc.Input("Scores").front(); auto out_name = opdesc.Output("Out").front(); + std::vector output_arg_names = opdesc.OutputArgumentNames(); + if (std::find(output_arg_names.begin(), output_arg_names.end(), "Index") != + output_arg_names.end()) { + auto index_name = opdesc.Output("Index").front(); + param_.index = GetMutableVar(scope, index_name); + } param_.bboxes = GetVar(scope, bboxes_name); param_.scores = GetVar(scope, scores_name); param_.out = GetMutableVar(scope, out_name); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index cfee6a0391d81992069d70e9ac37e0e6594bd305..769c8329f460280303089458e29668c1afa4c5a4 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -592,6 +592,7 @@ struct MulticlassNmsParam { const lite::Tensor* bboxes{}; const lite::Tensor* scores{}; lite::Tensor* out{}; + lite::Tensor* index{}; int background_label{0}; float score_threshold{}; int nms_top_k{};