// main.cpp
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <opencv2/core/utils/filesystem.hpp>

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <ostream>
#include <string>
#include <vector>

#include <auto_log/autolog.h>
#include <gflags/gflags.h>
#include <include/feature_extracter.h>
#include <include/nms.h>
#include <include/object_detector.h>
#include <include/vector_search.h>
#include <include/yaml_config.h>

using namespace std;
using namespace cv;

// Command-line flags. Both point at the YAML configuration file: `--config`
// is the long form and `-c` the short alias. main() uses --config when it is
// non-empty, otherwise falls back to -c.
DEFINE_string(config, "", "Path of yaml file");
DEFINE_string(c, "", "Path of yaml file");
void DetPredictImage(const std::vector <cv::Mat> &batch_imgs,
                     const std::vector <std::string> &all_img_paths,
D
dongshuilong 已提交
47
                     const int batch_size, Detection::ObjectDetector *det,
D
dongshuilong 已提交
48
                     std::vector <Detection::ObjectResult> &im_result,
49 50 51 52
                     std::vector<int> &im_bbox_num, std::vector<double> &det_t,
                     const bool visual_det = false,
                     const bool run_benchmark = false,
                     const std::string &output_dir = "output") {
D
dongshuilong 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65
    int steps = ceil(float(all_img_paths.size()) / batch_size);
    //   printf("total images = %d, batch_size = %d, total steps = %d\n",
    //                 all_img_paths.size(), batch_size, steps);
    for (int idx = 0; idx < steps; idx++) {
        int left_image_cnt = all_img_paths.size() - idx * batch_size;
        if (left_image_cnt > batch_size) {
            left_image_cnt = batch_size;
        }
        // for (int bs = 0; bs < left_image_cnt; bs++) {
        // std::string image_file_path = all_img_paths.at(idx * batch_size+bs);
        // cv::Mat im = cv::imread(image_file_path, 1);
        // batch_imgs.insert(batch_imgs.end(), im);
        // }
66

D
dongshuilong 已提交
67 68 69 70 71 72 73 74 75 76 77 78
        // Store all detected result
        std::vector <Detection::ObjectResult> result;
        std::vector<int> bbox_num;
        std::vector<double> det_times;
        bool is_rbox = false;
        if (run_benchmark) {
            det->Predict(batch_imgs, 10, 10, &result, &bbox_num, &det_times);
        } else {
            det->Predict(batch_imgs, 0, 1, &result, &bbox_num, &det_times);
            // get labels and colormap
            auto labels = det->GetLabelList();
            auto colormap = Detection::GenerateColorMap(labels.size());
79

D
dongshuilong 已提交
80 81 82 83
            int item_start_idx = 0;
            for (int i = 0; i < left_image_cnt; i++) {
                cv::Mat im = batch_imgs[i];
                int detect_num = 0;
84

D
dongshuilong 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
                for (int j = 0; j < bbox_num[i]; j++) {
                    Detection::ObjectResult item = result[item_start_idx + j];
                    if (item.confidence < det->GetThreshold() || item.class_id == -1) {
                        continue;
                    }
                    detect_num += 1;
                    im_result.push_back(item);
                    if (visual_det) {
                        if (item.rect.size() > 6) {
                            is_rbox = true;
                            printf(
                                    "class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]\n",
                                    item.class_id, item.confidence, item.rect[0], item.rect[1],
                                    item.rect[2], item.rect[3], item.rect[4], item.rect[5],
                                    item.rect[6], item.rect[7]);
                        } else {
                            printf("class=%d confidence=%.4f rect=[%d %d %d %d]\n",
                                   item.class_id, item.confidence, item.rect[0], item.rect[1],
                                   item.rect[2], item.rect[3]);
                        }
                    }
                }
                im_bbox_num.push_back(detect_num);
                item_start_idx = item_start_idx + bbox_num[i];
109

D
dongshuilong 已提交
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
                // Visualization result
                if (visual_det) {
                    std::cout << all_img_paths.at(idx * batch_size + i)
                              << " The number of detected box: " << detect_num
                              << std::endl;
                    cv::Mat vis_img = Detection::VisualizeResult(im, im_result, labels,
                                                                 colormap, is_rbox);
                    std::vector<int> compression_params;
                    compression_params.push_back(CV_IMWRITE_JPEG_QUALITY);
                    compression_params.push_back(95);
                    std::string output_path(output_dir);
                    if (output_dir.rfind(OS_PATH_SEP) != output_dir.size() - 1) {
                        output_path += OS_PATH_SEP;
                    }
                    std::string image_file_path = all_img_paths.at(idx * batch_size + i);
                    output_path +=
                            image_file_path.substr(image_file_path.find_last_of('/') + 1);
                    cv::imwrite(output_path, vis_img, compression_params);
                    printf("Visualized output saved as %s\n", output_path.c_str());
                }
            }
131
        }
D
dongshuilong 已提交
132 133 134
        det_t[0] += det_times[0];
        det_t[1] += det_times[1];
        det_t[2] += det_times[2];
135 136
    }
}
void PrintResult(std::string &img_path,
D
dongshuilong 已提交
139
                 std::vector <Detection::ObjectResult> &det_result,
D
dongshuilong 已提交
140 141
                 std::vector<int> &indeices, VectorSearch &vector_search,
                 SearchResult &search_result) {
D
dongshuilong 已提交
142 143 144 145 146 147 148 149 150
    printf("%s:\n", img_path.c_str());
    for (int i = 0; i < indeices.size(); ++i) {
        int t = indeices[i];
        printf("\tresult%d: bbox[%d, %d, %d, %d], score: %f, label: %s\n", i,
               det_result[t].rect[0], det_result[t].rect[1], det_result[t].rect[2],
               det_result[t].rect[3], det_result[t].confidence,
               vector_search.GetLabel(search_result.I[search_result.return_k * t])
                       .c_str());
    }
D
dongshuilong 已提交
151 152
}

int main(int argc, char **argv) {
D
dongshuilong 已提交
154 155 156 157 158 159 160 161 162 163 164 165 166
    google::ParseCommandLineFlags(&argc, &argv, true);
    std::string yaml_path = "";
    if (FLAGS_config == "" && FLAGS_c == "") {
        std::cerr << "[ERROR] usage: " << std::endl
                  << argv[0] << " -c $yaml_path" << std::endl
                  << "or:" << std::endl
                  << argv[0] << " -config $yaml_path" << std::endl;
        exit(1);
    } else if (FLAGS_config != "") {
        yaml_path = FLAGS_config;
    } else {
        yaml_path = FLAGS_c;
    }
167

D
dongshuilong 已提交
168 169
    YamlConfig config(yaml_path);
    config.PrintConfigInfo();
D
dongshuilong 已提交
170

D
dongshuilong 已提交
171 172 173 174
    // initialize detector, rec_Model, vector_search
    Feature::FeatureExtracter feature_extracter(config.config_file);
    Detection::ObjectDetector detector(config.config_file);
    VectorSearch searcher(config.config_file);
175

D
dongshuilong 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
    // config
    const int batch_size = config.config_file["Global"]["batch_size"].as<int>();
    bool visual_det = false;
    if (config.config_file["Global"]["visual_det"].IsDefined()) {
        visual_det = config.config_file["Global"]["visual_det"].as<bool>();
    }
    bool run_benchmark = false;
    if (config.config_file["Global"]["benchmark"].IsDefined()) {
        run_benchmark = config.config_file["Global"]["benchmark"].as<bool>();
    }
    int max_det_results = 5;
    if (config.config_file["Global"]["max_det_results"].IsDefined()) {
        max_det_results = config.config_file["Global"]["max_det_results"].as<int>();
    }
    float rec_nms_thresold = 0.05;
    if (config.config_file["Global"]["rec_nms_thresold"].IsDefined()) {
        rec_nms_thresold =
                config.config_file["Global"]["rec_nms_thresold"].as<float>();
194
    }
195

D
dongshuilong 已提交
196 197 198 199 200 201 202 203 204 205 206 207
    // load image_file_path
    std::string path =
            config.config_file["Global"]["infer_imgs"].as<std::string>();
    std::vector <std::string> img_files_list;
    if (cv::utils::fs::isDirectory(path)) {
        std::vector <cv::String> filenames;
        cv::glob(path, filenames);
        for (auto f : filenames) {
            img_files_list.push_back(f);
        }
    } else {
        img_files_list.push_back(path);
208
    }
D
dongshuilong 已提交
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
    std::cout << "img_file_list length: " << img_files_list.size() << std::endl;
    // for time log
    std::vector<double> cls_times = {0, 0, 0};
    std::vector<double> det_times = {0, 0, 0};
    // for read images
    std::vector <cv::Mat> batch_imgs;
    std::vector <std::string> img_paths;
    // for detection
    std::vector <Detection::ObjectResult> det_result;
    std::vector<int> det_bbox_num;
    // for vector search
    std::vector<float> features;
    std::vector<float> feature;
    // for nms
    std::vector<int> indeices;
224

D
dongshuilong 已提交
225 226 227 228 229 230 231 232 233 234
    int warmup_iter = img_files_list.size() > 5 ? 5 : 0;
    for (int idx = 0; idx < img_files_list.size(); ++idx) {
        std::string img_path = img_files_list[idx];
        cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
        if (!srcimg.data) {
            std::cerr << "[ERROR] image read failed! image path: " << img_path
                      << "\n";
            exit(-1);
        }
        cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
235

D
dongshuilong 已提交
236 237
        batch_imgs.push_back(srcimg);
        img_paths.push_back(img_path);
238

D
dongshuilong 已提交
239 240 241
        // step1: get all detection results
        DetPredictImage(batch_imgs, img_paths, batch_size, &detector, det_result,
                        det_bbox_num, det_times, visual_det, run_benchmark);
242

D
dongshuilong 已提交
243 244 245 246 247 248 249 250 251
        // select max_det_results bbox
        if (det_result.size() > max_det_results) {
            det_result.resize(max_det_results);
        }
        // step2: add the whole image for recognition to improve recall
        Detection::ObjectResult result_whole_img = {
                {0, 0, srcimg.cols - 1, srcimg.rows - 1}, 0, 1.0};
        det_result.push_back(result_whole_img);
        det_bbox_num[0] = det_result.size() + 1;
D
dongshuilong 已提交
252

D
dongshuilong 已提交
253 254 255 256 257 258 259 260 261 262
        // step3: extract feature for all boxes in an inmage
        SearchResult search_result;
        for (int j = 0; j < det_result.size(); ++j) {
            int w = det_result[j].rect[2] - det_result[j].rect[0];
            int h = det_result[j].rect[3] - det_result[j].rect[1];
            cv::Rect rect(det_result[j].rect[0], det_result[j].rect[1], w, h);
            cv::Mat crop_img = srcimg(rect);
            feature_extracter.Run(crop_img, feature, cls_times);
            features.insert(features.end(), feature.begin(), feature.end());
        }
D
dongshuilong 已提交
263

D
dongshuilong 已提交
264 265 266 267 268 269 270 271
        // step4: get search result
        search_result = searcher.Search(features.data(), det_result.size());

        // nms for search result
        for (int i = 0; i < det_result.size(); ++i) {
            det_result[i].confidence = search_result.D[search_result.return_k * i];
        }
        NMSBoxes(det_result, searcher.GetThreshold(), rec_nms_thresold, indeices);
D
dongshuilong 已提交
272

D
dongshuilong 已提交
273 274
        // print result
        PrintResult(img_path, det_result, indeices, searcher, search_result);
D
dongshuilong 已提交
275

D
dongshuilong 已提交
276 277 278 279 280 281 282 283 284
        // for postprocess
        batch_imgs.clear();
        img_paths.clear();
        det_bbox_num.clear();
        det_result.clear();
        feature.clear();
        features.clear();
        indeices.clear();
    }
285

D
dongshuilong 已提交
286
    std::string presion = "fp32";
287

D
dongshuilong 已提交
288 289 290 291 292 293 294 295 296 297
    // if (config.use_fp16)
    //   presion = "fp16";
    // if (config.benchmark) {
    //   AutoLogger autolog("Classification", config.use_gpu, config.use_tensorrt,
    //                      config.use_mkldnn, config.cpu_threads, 1,
    //                      "1, 3, 224, 224", presion, cls_times,
    //                      img_files_list.size());
    //   autolog.report();
    // }
    return 0;
298
}