// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iostream>
#include <vector>
#include <string>

#include <opencv2/core/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>

#ifdef _WIN32
#define GLOG_NO_ABBREVIATED_SEVERITIES
#include <windows.h>
#else
#include <dirent.h>
#include <sys/types.h>
#endif

namespace PaddleSolution {
namespace utils {
    // Join a directory and a relative path with the platform separator
    // ('\\' on Windows, '/' elsewhere). No normalization is performed: a
    // `dir` that already ends in a separator yields a doubled separator.
    inline std::string path_join(const std::string& dir,
                                 const std::string& path) {
    #ifdef _WIN32
        const char separator = '\\';
    #else
        const char separator = '/';
    #endif
        return dir + separator + path;
    }
    #ifndef _WIN32
    // Scan `path` (non-recursively) and return path_join(path, name) for every
    // entry whose extension -- the text from the last '.' on, dot included --
    // occurs as a substring of `exts` (e.g. ".jpeg|.jpg").
    // Returns an empty vector if the directory cannot be opened.
    inline std::vector<std::string> get_directory_images(
                        const std::string& path, const std::string& exts) {
        std::vector<std::string> imgs;
        DIR* dir = opendir(path.c_str());
        if (dir == nullptr) {
            // Nothing to release: closedir(NULL) is undefined behaviour.
            return imgs;
        }
        while (const dirent* entry = readdir(dir)) {
            const std::string name(entry->d_name);
            const std::string::size_type dot = name.rfind('.');
            if (dot == std::string::npos) {
                continue;
            }
            const std::string ext = name.substr(dot);
            // "." and ".." (and names ending in '.') produce a bare ".",
            // which would match any dotted extension list.
            if (ext == ".") {
                continue;
            }
            if (exts.find(ext) != std::string::npos) {
                imgs.push_back(path_join(path, name));
            }
        }
        // Release the handle; otherwise every call leaks a file descriptor.
        closedir(dir);
        return imgs;
    }
    #else
    // Windows variant. The extension here is taken WITHOUT the dot (and must
    // be longer than one character) before being searched for in `exts`.
    // Returns an empty vector if the directory cannot be listed.
    inline std::vector<std::string> get_directory_images(
                    const std::string& path, const std::string& exts) {
        std::vector<std::string> imgs;
        const std::string pattern = path_join(path, "*");
        // Explicit ANSI variants: std::string holds narrow chars, so this must
        // not switch to the wide-char API when UNICODE is defined.
        WIN32_FIND_DATAA data;
        HANDLE hFind = FindFirstFileA(pattern.c_str(), &data);
        if (hFind == INVALID_HANDLE_VALUE) {
            return imgs;
        }
        do {
            const std::string fname(data.cFileName);
            const std::string::size_type pos = fname.rfind('.');
            if (pos == std::string::npos) {
                // No extension: skip instead of treating the whole name as
                // one. `continue` still advances via FindNextFileA below.
                continue;
            }
            const std::string ext = fname.substr(pos + 1);
            if (ext.size() > 1 && exts.find(ext) != std::string::npos) {
                imgs.push_back(path_join(path, fname));
            }
        } while (FindNextFileA(hFind, &data) != 0);
        FindClose(hFind);
        return imgs;
    }
    #endif

    // normalize and HWC_BGR -> CHW_RGB
    inline void normalize(cv::Mat& im, float* data, std::vector<float>& fmean,
                          std::vector<float>& fstd) {
        int rh = im.rows;
        int rw = im.cols;
        int rc = im.channels();
        double normf = static_cast<double>(1.0) / 255.0;
        #pragma omp parallel for
        for (int h = 0; h < rh; ++h) {
            const uchar* ptr = im.ptr<uchar>(h);
            int im_index = 0;
            for (int w = 0; w < rw; ++w) {
                for (int c = 0; c < rc; ++c) {
                    int top_index = (c * rh + h) * rw + w;
                    float pixel = static_cast<float>(ptr[im_index++]);
                    pixel = (pixel * normf - fmean[c]) / fstd[c];
                    data[top_index] = pixel;
108 109 110
                }
            }
        }
111
    }

    // flatten a cv::mat
    inline void flatten_mat(cv::Mat& im, float* data) {
        int rh = im.rows;
        int rw = im.cols;
        int rc = im.channels();
118
        #pragma omp parallel for
119 120 121
        for (int h = 0; h < rh; ++h) {
            const uchar* ptr = im.ptr<uchar>(h);
            int im_index = 0;
122
            int top_index = h * rw * rc;
123 124 125 126 127 128 129 130 131
            for (int w = 0; w < rw; ++w) {
                for (int c = 0; c < rc; ++c) {
                    float pixel = static_cast<float>(ptr[im_index++]);
                    data[top_index++] = pixel;
                }
            }
        }
    }

    // Per-pixel argmax over class scores.
    //   out      : shape[0] class planes of shape[1] * shape[2] scores each,
    //              stored class-major (score of class j at pixel i is
    //              out[i + j * shape[1] * shape[2]]).
    //   shape    : {num_classes, height, width}.
    //   mask     : receives each pixel's winning class index, converted to
    //              unsigned char (OpenCV's uchar), so ids above 255 wrap.
    //   scoremap : receives the winning score * 255, truncated and clamped to
    //              [0, 255]; pixels whose label is 0 always get 0.
    // mask and scoremap must already hold at least height * width elements.
    // Ties resolve to the lowest class index.
    inline void argmax(const float* out, const std::vector<int>& shape,
                       std::vector<unsigned char>& mask,
                       std::vector<unsigned char>& scoremap) {
        const int num_classes = shape[0];
        const int out_img_len = shape[1] * shape[2];
        // Pixels are independent and every variable below is loop-local, so
        // a plain parallel-for over pixels is race-free.
        #pragma omp parallel for
        for (int i = 0; i < out_img_len; ++i) {
            int label = 0;
            float max_value = 0.0f;
            if (num_classes > 0) {
                // Seed with class 0's real score rather than a sentinel, so
                // scores below any sentinel (e.g. logits) still compare.
                max_value = out[i];
                for (int j = 1; j < num_classes; ++j) {
                    const float value = out[i + j * out_img_len];
                    if (value > max_value) {
                        max_value = value;
                        label = j;
                    }
                }
            }
            float score = (label == 0) ? 0.0f : max_value * 255;
            // float -> unsigned char is undefined outside [0, 255]; the
            // negated test also maps NaN to 0.
            if (!(score >= 0.0f)) {
                score = 0.0f;
            } else if (score > 255.0f) {
                score = 255.0f;
            }
            mask[i] = static_cast<unsigned char>(label);
            scoremap[i] = static_cast<unsigned char>(score);
        }
    }
}  // namespace utils
}  // namespace PaddleSolution