699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 1) // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 2) // 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 3) // Licensed under the Apache License, Version 2.0 (the "License"); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 4) // you may not use this file except in compliance with the License. 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 5) // You may obtain a copy of the License at 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 6) // 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 7) // http://www.apache.org/licenses/LICENSE-2.0 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 8) // 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 9) // Unless required by applicable law or agreed to in writing, software 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 10) // distributed under the License is distributed on an "AS IS" BASIS, 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 11) // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 12) // See the License for the specific language governing permissions and 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 13) // limitations under the License. 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 14) de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 15) #include "lite/kernels/host/multiclass_nms_compute.h" de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 16) #include de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 17) #include de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 18) #include 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 19) 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 20) namespace paddle { 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 21) namespace lite { de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 22) namespace kernels { de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 23) namespace host { 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 24) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 25) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 26) bool SortScorePairDescend(const std::pair& pair1, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 27) const std::pair& pair2) { 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 28) return pair1.first > pair2.first; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 29) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 30) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 31) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 32) static void GetMaxScoreIndex(const std::vector& scores, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 33) const T threshold, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 34) int top_k, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 35) std::vector>* sorted_indices) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 36) for (size_t i = 0; i < scores.size(); ++i) { 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 37) if (scores[i] > threshold) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 38) sorted_indices->push_back(std::make_pair(scores[i], i)); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 39) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 40) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 41) // Sort the score pair according to the scores in descending order deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 42) std::stable_sort(sorted_indices->begin(), deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 43) sorted_indices->end(), deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 44) SortScorePairDescend); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 45) // Keep top_k scores if needed. deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 46) if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 47) sorted_indices->resize(top_k); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 48) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 49) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 50) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 51) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 52) static T BBoxArea(const T* box, const bool normalized) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 53) if (box[2] < box[0] || box[3] < box[1]) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 54) // If coordinate values are is invalid deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 55) // (e.g. xmax < xmin or ymax < ymin), return 0. deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 56) return static_cast(0.); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 57) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 58) const T w = box[2] - box[0]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 59) const T h = box[3] - box[1]; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 60) if (normalized) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 61) return w * h; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 62) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 63) // If coordinate values are not within range [0, 1]. deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 64) return (w + 1) * (h + 1); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 65) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 66) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 67) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 68) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 69) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 70) static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 71) if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 72) box2[3] < box1[1]) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 73) return static_cast(0.); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 74) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 75) const T inter_xmin = std::max(box1[0], box2[0]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 76) const T inter_ymin = std::max(box1[1], box2[1]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 77) const T inter_xmax = std::min(box1[2], box2[2]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 78) const T inter_ymax = std::min(box1[3], box2[3]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 79) T norm = normalized ? static_cast(0.) : static_cast(1.); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 80) T inter_w = inter_xmax - inter_xmin + norm; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 81) T inter_h = inter_ymax - inter_ymin + norm; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 82) const T inter_area = inter_w * inter_h; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 83) const T bbox1_area = BBoxArea(box1, normalized); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 84) const T bbox2_area = BBoxArea(box2, normalized); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 85) return inter_area / (bbox1_area + bbox2_area - inter_area); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 86) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 87) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 88) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 89) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 90) T PolyIoU(const T* box1, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 91) const T* box2, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 92) const size_t box_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 93) const bool normalized) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 94) LOG(FATAL) << "PolyIoU not implement."; b80194db lite/kernels/host/multiclass_nms_compute.cc (huzhiqiang 2020-04-03 20:18:11 +0800 95) return *box1; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 96) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 97) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 98) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 99) void SliceOneClass(const Tensor& items, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 100) const int class_id, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 101) Tensor* one_class_item) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 102) T* item_data = one_class_item->mutable_data(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 103) const T* items_data = items.data(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 104) const int64_t num_item = items.dims()[0]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 105) const int64_t class_num = items.dims()[1]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 106) if (items.dims().size() == 3) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 107) int64_t item_size = items.dims()[2]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 108) for (int i = 0; i < num_item; ++i) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 109) std::memcpy(item_data + i * item_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 110) items_data + i * class_num * item_size + class_id * item_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 111) sizeof(T) * item_size); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 112) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 113) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 114) for (int i = 0; i < num_item; ++i) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 115) item_data[i] = items_data[i * class_num + class_id]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 116) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 117) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 118) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 119) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 120) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 121) void NMSFast(const Tensor& bbox, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 122) const Tensor& scores, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 123) const T score_threshold, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 124) const T nms_threshold, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 125) const T eta, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 126) const int64_t top_k, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 127) std::vector* selected_indices, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 128) const bool normalized) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 129) // The total boxes for each instance. deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 130) int64_t num_boxes = bbox.dims()[0]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 131) // 4: [xmin ymin xmax ymax] deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 132) // 8: [x1 y1 x2 y2 x3 y3 x4 y4] deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 133) // 16, 24, or 32: [x1 y1 x2 y2 ... xn yn], n = 8, 12 or 16 deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 134) int64_t box_size = bbox.dims()[1]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 135) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 136) std::vector scores_data(num_boxes); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 137) std::copy_n(scores.data(), num_boxes, scores_data.begin()); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 138) std::vector> sorted_indices; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 139) GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 140) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 141) selected_indices->clear(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 142) T adaptive_threshold = nms_threshold; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 143) const T* bbox_data = bbox.data(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 144) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 145) while (sorted_indices.size() != 0) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 146) const int idx = sorted_indices.front().second; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 147) bool keep = true; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 148) for (size_t k = 0; k < selected_indices->size(); ++k) { 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 149) if (keep) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 150) const int kept_idx = (*selected_indices)[k]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 151) T overlap = T(0.); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 152) // 4: [xmin ymin xmax ymax] deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 153) if (box_size == 4) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 154) overlap = JaccardOverlap(bbox_data + idx * box_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 155) bbox_data + kept_idx * box_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 156) normalized); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 157) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 158) // 8: [x1 y1 x2 y2 x3 y3 x4 y4] or 16, 24, 32 deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 159) if (box_size == 8 || box_size == 16 || box_size == 24 || deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 160) box_size == 32) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 161) overlap = PolyIoU(bbox_data + idx * box_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 162) bbox_data + kept_idx * box_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 163) box_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 164) normalized); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 165) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 166) keep = overlap <= adaptive_threshold; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 167) } else { 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 168) break; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 169) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 170) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 171) if (keep) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 172) selected_indices->push_back(idx); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 173) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 174) sorted_indices.erase(sorted_indices.begin()); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 175) if (keep && eta < 1 && adaptive_threshold > 0.5) { 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 176) adaptive_threshold *= eta; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 177) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 178) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 179) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 180) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 181) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 182) void MultiClassNMS(const operators::MulticlassNmsParam& param, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 183) const Tensor& scores, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 184) const Tensor& bboxes, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 185) const int scores_size, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 186) std::map>* indices, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 187) int* num_nmsed_out) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 188) int64_t background_label = param.background_label; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 189) int64_t nms_top_k = param.nms_top_k; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 190) int64_t keep_top_k = param.keep_top_k; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 191) bool normalized = param.normalized; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 192) T nms_threshold = static_cast(param.nms_threshold); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 193) T nms_eta = static_cast(param.nms_eta); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 194) T score_threshold = static_cast(param.score_threshold); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 195) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 196) int num_det = 0; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 197) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 198) int64_t class_num = scores_size == 3 ? scores.dims()[0] : scores.dims()[1]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 199) Tensor bbox_slice, score_slice; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 200) for (int64_t c = 0; c < class_num; ++c) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 201) if (c == background_label) continue; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 202) if (scores_size == 3) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 203) score_slice = scores.Slice(c, c + 1); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 204) bbox_slice = bboxes; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 205) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 206) score_slice.Resize({scores.dims()[0], 1}); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 207) bbox_slice.Resize({scores.dims()[0], 4}); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 208) SliceOneClass(scores, c, &score_slice); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 209) SliceOneClass(bboxes, c, &bbox_slice); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 210) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 211) NMSFast(bbox_slice, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 212) score_slice, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 213) score_threshold, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 214) nms_threshold, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 215) nms_eta, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 216) nms_top_k, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 217) &((*indices)[c]), deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 218) normalized); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 219) if (scores_size == 2) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 220) std::stable_sort((*indices)[c].begin(), (*indices)[c].end()); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 221) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 222) num_det += (*indices)[c].size(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 223) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 224) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 225) *num_nmsed_out = num_det; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 226) const T* scores_data = scores.data(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 227) if (keep_top_k > -1 && num_det > keep_top_k) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 228) const T* sdata; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 229) std::vector>> score_index_pairs; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 230) for (const auto& it : *indices) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 231) int label = it.first; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 232) if (scores_size == 3) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 233) sdata = scores_data + label * scores.dims()[1]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 234) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 235) score_slice.Resize({scores.dims()[0], 1}); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 236) SliceOneClass(scores, label, &score_slice); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 237) sdata = score_slice.data(); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 238) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 239) const std::vector& label_indices = it.second; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 240) for (size_t j = 0; j < label_indices.size(); ++j) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 241) int idx = label_indices[j]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 242) score_index_pairs.push_back( deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 243) std::make_pair(sdata[idx], std::make_pair(label, idx))); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 244) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 245) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 246) // Keep top k results per image. deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 247) std::stable_sort(score_index_pairs.begin(), deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 248) score_index_pairs.end(), deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 249) SortScorePairDescend>); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 250) score_index_pairs.resize(keep_top_k); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 251) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 252) // Store the new indices. deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 253) std::map> new_indices; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 254) for (size_t j = 0; j < score_index_pairs.size(); ++j) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 255) int label = score_index_pairs[j].second.first; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 256) int idx = score_index_pairs[j].second.second; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 257) new_indices[label].push_back(idx); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 258) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 259) if (scores_size == 2) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 260) for (const auto& it : new_indices) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 261) int label = it.first; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 262) std::stable_sort(new_indices[label].begin(), new_indices[label].end()); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 263) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 264) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 265) new_indices.swap(*indices); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 266) *num_nmsed_out = keep_top_k; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 267) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 268) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 269) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 270) template deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 271) void MultiClassOutput(const Tensor& scores, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 272) const Tensor& bboxes, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 273) const std::map>& selected_indices, deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 274) const int scores_size, e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 275) Tensor* outs, e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 276) int* oindices = nullptr, e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 277) const int offset = 0) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 278) int64_t class_num = scores.dims()[1]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 279) int64_t predict_dim = scores.dims()[1]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 280) int64_t box_size = bboxes.dims()[1]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 281) if (scores_size == 2) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 282) box_size = bboxes.dims()[2]; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 283) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 284) int64_t out_dim = box_size + 2; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 285) auto* scores_data = scores.data(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 286) auto* bboxes_data = bboxes.data(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 287) auto* odata = outs->mutable_data(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 288) const T* sdata; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 289) Tensor bbox; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 290) bbox.Resize({scores.dims()[0], box_size}); 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 291) int count = 0; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 292) for (const auto& it : selected_indices) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 293) int label = it.first; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 294) const std::vector& indices = it.second; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 295) if (scores_size == 2) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 296) SliceOneClass(bboxes, label, &bbox); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 297) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 298) sdata = scores_data + label * predict_dim; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 299) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 300) for (size_t j = 0; j < indices.size(); ++j) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 301) int idx = indices[j]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 302) odata[count * out_dim] = label; // label deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 303) const T* bdata; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 304) if (scores_size == 3) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 305) bdata = bboxes_data + idx * box_size; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 306) odata[count * out_dim + 1] = sdata[idx]; // score e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 307) if (oindices != nullptr) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 308) oindices[count] = offset + idx; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 309) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 310) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 311) bdata = bbox.data() + idx * box_size; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 312) odata[count * out_dim + 1] = *(scores_data + idx * class_num + label); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 313) if (oindices != nullptr) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 314) oindices[count] = offset + idx * class_num + label; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 315) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 316) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 317) // xmin, ymin, xmax, ymax or multi-points coordinates deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 318) std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T)); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 319) count++; 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 320) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 321) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 322) } 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 323) de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 324) void MulticlassNmsCompute::Run() { de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 325) auto& param = Param(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 326) auto* boxes = param.bboxes; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 327) auto* scores = param.scores; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 328) auto* outs = param.out; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 329) bool return_index = param.index ? true : false; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 330) auto* index = param.index; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 331) auto score_dims = scores->dims(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 332) auto score_size = score_dims.size(); de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 333) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 334) std::vector>> all_indices; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 335) std::vector batch_starts = {0}; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 336) int64_t batch_size = score_dims[0]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 337) int64_t box_dim = boxes->dims()[2]; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 338) int64_t out_dim = box_dim + 2; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 339) int num_nmsed_out = 0; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 340) Tensor boxes_slice, scores_slice; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 341) int n = score_size == 3 ? batch_size : boxes->lod().back().size() - 1; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 342) for (int i = 0; i < n; ++i) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 343) if (score_size == 3) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 344) scores_slice = scores->Slice(i, i + 1); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 345) scores_slice.Resize({score_dims[1], score_dims[2]}); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 346) boxes_slice = boxes->Slice(i, i + 1); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 347) boxes_slice.Resize({score_dims[2], box_dim}); de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 348) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 349) auto boxes_lod = boxes->lod().back(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 350) scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 351) boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]); de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 352) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 353) std::map> indices; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 354) MultiClassNMS( deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 355) param, scores_slice, boxes_slice, score_size, &indices, &num_nmsed_out); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 356) all_indices.push_back(indices); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 357) batch_starts.push_back(batch_starts.back() + num_nmsed_out); de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 358) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 359) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 360) uint64_t num_kept = batch_starts.back(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 361) if (num_kept == 0) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 362) if (return_index) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 363) outs->Resize({0, out_dim}); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 364) index->Resize({0, 1}); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 365) } else { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 366) outs->Resize({1, 1}); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 367) float* od = outs->mutable_data(); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 368) od[0] = -1; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 369) batch_starts = {0, 1}; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 370) } de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 371) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 372) outs->Resize({static_cast(num_kept), out_dim}); 00000000 lite/kernels/host/multiclass_nms_compute.cc (Not Committed Yet 2020-04-15 02:36:38 +0000 373) <<<<<<< HEAD 99f9b310 lite/kernels/host/multiclass_nms_compute.cc (dingminghui 2020-04-07 16:59:12 +0800 374) (void)outs->mutable_data(); 00000000 lite/kernels/host/multiclass_nms_compute.cc (Not Committed Yet 2020-04-15 02:36:38 +0000 375) ======= d571eb4e lite/kernels/host/multiclass_nms_compute.cc (zhupengyang 2020-04-02 10:36:08 +0800 376) outs->mutable_data(); 00000000 lite/kernels/host/multiclass_nms_compute.cc (Not Committed Yet 2020-04-15 02:36:38 +0000 377) >>>>>>> upstream/develop e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 378) int offset = 0; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 379) int* oindices = nullptr; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 380) for (int i = 0; i < n; ++i) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 381) if (score_size == 3) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 382) scores_slice = scores->Slice(i, i + 1); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 383) boxes_slice = boxes->Slice(i, i + 1); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 384) scores_slice.Resize({score_dims[1], score_dims[2]}); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 385) boxes_slice.Resize({score_dims[2], box_dim}); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 386) if (return_index) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 387) offset = i * score_dims[2]; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 388) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 389) } else { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 390) auto boxes_lod = boxes->lod().back(); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 391) scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 392) boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 393) if (return_index) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 394) offset = boxes_lod[i] * score_dims[1]; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 395) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 396) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 397) int64_t s = static_cast(batch_starts[i]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 398) int64_t e = static_cast(batch_starts[i + 1]); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 399) if (e > s) { deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 400) Tensor out = outs->Slice(s, e); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 401) if (return_index) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 402) index->Resize({static_cast(num_kept), 1}); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 403) int* output_idx = index->mutable_data(); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 404) oindices = output_idx + s; e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 405) } e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 406) MultiClassOutput(scores_slice, e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 407) boxes_slice, e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 408) all_indices[i], e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 409) score_dims.size(), e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 410) &out, e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 411) oindices, e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 412) offset); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 413) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 414) } de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 415) } de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 416) deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 417) LoD lod; deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 418) lod.emplace_back(batch_starts); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 419) if (return_index) { e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 420) index->set_lod(lod); e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 421) } deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 422) outs->set_lod(lod); deaddf9d lite/kernels/host/multiclass_nms_compute.cc (juncaipeng 2019-09-03 15:39:16 +0800 423) } de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 424) } // namespace host de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 425) } // namespace kernels 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 426) } // namespace lite 699d6cd0 lite/arm/math/multiclass_nms.cc (Yan Chunwei 2019-08-16 22:39:39 +0800 427) } // namespace paddle de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 428) de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 429) REGISTER_LITE_KERNEL(multiclass_nms, de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 430) kHost, de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 431) kFloat, de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 432) kNCHW, de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 433) paddle::lite::kernels::host::MulticlassNmsCompute, de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 434) def) de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 435) .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))}) de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 436) .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))}) de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 437) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 438) .Finalize(); 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 439) 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 440) REGISTER_LITE_KERNEL(multiclass_nms2, 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 441) kHost, 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 442) kFloat, 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 443) kNCHW, 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 444) paddle::lite::kernels::host::MulticlassNmsCompute, 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 445) def) 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 446) .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))}) 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 447) .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))}) 0679feed lite/kernels/host/multiclass_nms_compute.cc (yiicy 2020-02-20 11:39:43 +0800 448) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 449) .BindOutput("Index", e1c4adfd lite/kernels/host/multiclass_nms_compute.cc (yiicy 2019-12-24 16:10:19 +0800 450) {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) de43e479 lite/kernels/host/multiclass_nms_compute.cc (Wilber 2019-08-29 21:24:46 +0800 451) .Finalize();