utils.h 1.8 KB
Newer Older
D
dongshuilong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

17
#include <algorithm>
D
dongshuilong 已提交
18
#include <ctime>
L
lubin 已提交
19
#include <include/feature_extractor.h>
20
#include <memory>
D
dongshuilong 已提交
21
#include <numeric>
22 23 24
#include <string>
#include <utility>
#include <vector>
D
dongshuilong 已提交
25 26 27 28 29 30 31 32 33 34 35

namespace PPShiTu {

// Object Detection Result
struct ObjectResult {
  // Rectangle coordinates of detected object: left, right, top, down
  std::vector<int> rect;
  // Class id of detected object
  int class_id;
  // Confidence of detected object
  float confidence;
36 37 38

  // RecModel result
  std::vector<RESULT> rec_result;
D
dongshuilong 已提交
39 40
};

D
dongshuilong 已提交
41 42
/// Non-maximum suppression (per the name) over a set of detections. The
/// definition lives outside this header.
///
/// @param input_boxes   Detections to process. Taken by non-const reference,
///                      so presumably filtered in place — confirm against the
///                      definition.
/// @param nms_threshold Overlap threshold used for suppression.
/// @param rec_nms       Defaults to false. NOTE(review): presumably selects a
///                      variant that takes ObjectResult::rec_result into
///                      account — confirm.
void nms(std::vector<ObjectResult> &input_boxes,
         float nms_threshold,
         bool rec_nms = false);
D
dongshuilong 已提交
43

D
dongshuilong 已提交
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
/// Orders (score, payload) pairs by score, highest first.
///
/// Only the float score (`first`) is compared; the payload (`second`) is
/// ignored, so pairs with equal scores are equivalent. Usable as the comparator
/// for std::sort / std::stable_sort as long as no score is NaN.
///
/// @tparam T     Type of the payload stored in `second`.
/// @param pair1  Left-hand pair.
/// @param pair2  Right-hand pair.
/// @return true iff pair1's score is strictly greater than pair2's.
template <typename T>
static inline bool
SortScorePairDescend(const std::pair<float, T> &pair1,
                     const std::pair<float, T> &pair2) {
  const float lhs_score = pair1.first;
  const float rhs_score = pair2.first;
  return rhs_score < lhs_score;
}

/// Overlap measure between two detections, presumably computed from
/// `a.rect` and `b.rect`.
///
/// NOTE(review): the exact metric (e.g. IoU vs. raw intersection) is defined in
/// the implementation file — confirm before relying on its range.
///
/// @return the overlap value as a float.
float RectOverlap(const ObjectResult &a, const ObjectResult &b);

/// Builds (score, index) pairs for the given detections.
///
/// NOTE(review): behaviour inferred from the signature and the
/// SortScorePairDescend helper — presumably keeps entries whose confidence
/// exceeds `threshold`, sorted by descending score. Confirm in the definition.
///
/// NOTE(review): this is declared `inline` but not defined in this header. An
/// inline function must be defined in every translation unit that odr-uses it,
/// so calling it from any TU other than the one holding its definition is
/// ill-formed (in practice, an undefined-reference link error). Either move the
/// definition into this header or drop `inline` from BOTH the declaration and
/// the definition.
///
/// @param det_result      Detections to scan.
/// @param threshold       Score threshold.
/// @param score_index_vec Output: presumably (score, index into det_result)
///                        pairs.
inline void GetMaxScoreIndex(
    const std::vector<ObjectResult> &det_result, const float threshold,
    std::vector<std::pair<float, int>> &score_index_vec);

/// Runs box-level suppression over detections and reports which ones to keep.
///
/// @param det_result      Detections to filter. NOTE(review): taken by value,
///                        so every call copies the whole vector (including each
///                        element's `rect` and `rec_result`). Changing it to
///                        `const std::vector<ObjectResult> &` must be done in
///                        this declaration and in the definition together,
///                        otherwise the two become distinct overloads.
/// @param score_threshold Score threshold (presumably applied before
///                        suppression).
/// @param nms_threshold   Overlap threshold used for suppression.
/// @param indices         Output: presumably indices into det_result of the
///                        kept detections.
void NMSBoxes(const std::vector<ObjectResult> det_result,
              const float score_threshold,
              const float nms_threshold,
              std::vector<int> &indices);
61
} // namespace PPShiTu