preprocess_op.cpp 5.0 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>

#include <cstring>
#include <fstream>
#include <numeric>

#include <include/preprocess_op.h>

namespace PaddleOCR {

// Splits the interleaved channels of *im into separate planes stored
// back-to-back in `data`: channel c is written to the rows*cols floats that
// start at data + c * rows * cols.
//
// `data` must therefore have room for channels * rows * cols floats.
// NOTE(review): each plane is declared as CV_32FC1, so this presumably expects
// *im to already hold 32-bit float samples -- confirm against callers.
void Permute::Run(const cv::Mat *im, float *data) {
  const int height = im->rows;
  const int width = im->cols;
  const int channels = im->channels();
  for (int c = 0; c < channels; ++c) {
    float *plane_begin = data + c * height * width;
    // The temporary header wraps the caller's buffer; no pixel data is owned.
    cv::extractChannel(*im, cv::Mat(height, width, CV_32FC1, plane_begin), c);
  }
}

// Converts *im in place to 32-bit float and applies, per pixel and per
// channel c in [0, 3):  value = (value - mean[c]) * scale[c].
//
// When is_scale is true the samples are first multiplied by 1/255 during the
// float conversion; otherwise they are converted unscaled.
// `mean` and `scale` are indexed at 0..2, so both need at least three entries.
// NOTE(review): pixels are accessed as cv::Vec3f, so a 3-channel image is
// assumed -- confirm against callers.
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
                    const std::vector<float> &scale, const bool is_scale) {
  const double alpha = is_scale ? 1.0 / 255.0 : 1.0;
  im->convertTo(*im, CV_32FC3, alpha);

  for (int row = 0; row < im->rows; ++row) {
    for (int col = 0; col < im->cols; ++col) {
      cv::Vec3f &pixel = im->at<cv::Vec3f>(row, col);
      for (int c = 0; c < 3; ++c) {
        pixel[c] = (pixel[c] - mean[c]) * scale[c];
      }
    }
  }
}

// Resizes `img` into `resize_img` and reports the applied scale factors.
//
// Non-TensorRT path:
//   1. If the longer side exceeds max_size_len, both sides are scaled by
//      max_size_len / longer_side; otherwise the scale is 1.
//   2. Each scaled side is rounded down to a multiple of 32, with a lower
//      bound of 32.
//   3. The image is resized to that size.
// TensorRT path: the image is always resized to 640x640.
//
// Outputs:
//   resize_img - the resized image.
//   ratio_h    - output height divided by the input height.
//   ratio_w    - output width divided by the input width.
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
                         int max_size_len, float &ratio_h, float &ratio_w,
                         bool use_tensorrt) {
  const int w = img.cols;
  const int h = img.rows;

  float ratio = 1.f;
  const int max_wh = w >= h ? w : h;
  if (max_wh > max_size_len) {
    ratio = static_cast<float>(max_size_len) / static_cast<float>(max_wh);
  }

  // Round down to a multiple of 32, never below 32. The lower bound also
  // covers a side that truncates to 0 (very elongated inputs): previously 0
  // passed the "% 32 == 0" check unchanged and cv::resize was handed an empty
  // target size.
  const auto snap_to_32 = [](int len) {
    const int floored = (len / 32) * 32;
    return floored < 32 ? 32 : floored;
  };

  const int resize_h = snap_to_32(static_cast<int>(static_cast<float>(h) * ratio));
  const int resize_w = snap_to_32(static_cast<int>(static_cast<float>(w) * ratio));

  if (!use_tensorrt) {
    cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
    ratio_h = static_cast<float>(resize_h) / static_cast<float>(h);
    ratio_w = static_cast<float>(resize_w) / static_cast<float>(w);
  } else {
    // Fixed input size used on the TensorRT path.
    cv::resize(img, resize_img, cv::Size(640, 640));
    ratio_h = static_cast<float>(640) / static_cast<float>(h);
    ratio_w = static_cast<float>(640) / static_cast<float>(w);
  }
}

// Resizes `img` into `resize_img` and pads it on the right with gray
// (127, 127, 127) up to a target width.
//
// Non-TensorRT path:
//   - target height is rec_image_shape[1];
//   - target width is int(32 * wh_ratio);
//   - the image is resized to height x min(ceil(height * aspect), width) and
//     then right-padded to the target width.
// TensorRT path: the output is always 100x32. The image is resized to
//   k x 32 with k = cols * 32 / rows (integer division), capped at 100, and
//   right-padded to 100 when k < 100.
//
// Only rec_image_shape[1] is used; the other entries are not read.
// NOTE(review): the target width uses a literal 32 rather than
// rec_image_shape[1]; the two differ when the configured height is not 32 --
// confirm this is intended.
void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
                        bool use_tensorrt,
                        const std::vector<int> &rec_image_shape) {
  if (!use_tensorrt) {
    const int imgH = rec_image_shape[1];
    const int imgW = int(32 * wh_ratio);

    const float ratio = float(img.cols) / float(img.rows);
    const float scaled_w = ceilf(imgH * ratio);
    const int resize_w = scaled_w > imgW ? imgW : int(scaled_w);

    cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
               cv::INTER_LINEAR);
    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
                       int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
                       {127, 127, 127});
  } else {
    int k = int(img.cols * 32 / img.rows);
    // Integer division yields 0 for crops more than 32x taller than wide,
    // which would hand cv::resize an empty target size; keep at least one
    // column (the non-TensorRT path gets the same floor from ceilf).
    if (k < 1) {
      k = 1;
    }
    if (k >= 100) {
      cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f,
                 cv::INTER_LINEAR);
    } else {
      cv::resize(img, resize_img, cv::Size(k, 32), 0.f, 0.f, cv::INTER_LINEAR);
      cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(100 - k),
                         cv::BORDER_CONSTANT, {127, 127, 127});
    }
  }
}

Z
zhoujun 已提交
139
// Resizes `img` into `resize_img`.
//
// TensorRT path: the image is resized to a fixed 100x32, ignoring its aspect
//   ratio.
// Non-TensorRT path: with H = rec_image_shape[1] and W = rec_image_shape[2],
//   the image is resized to height H and width min(ceil(H * aspect), W); if
//   that width is smaller than W the result is right-padded with black
//   (0, 0, 0) up to width W.
//
// rec_image_shape[0] is not used.
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                       bool use_tensorrt,
                       const std::vector<int> &rec_image_shape) {
  if (use_tensorrt) {
    cv::resize(img, resize_img, cv::Size(100, 32), 0.f, 0.f, cv::INTER_LINEAR);
    return;
  }

  const int imgH = rec_image_shape[1];
  const int imgW = rec_image_shape[2];

  const float ratio = float(img.cols) / float(img.rows);
  const float scaled_w = ceilf(imgH * ratio);
  const int resize_w = scaled_w > imgW ? imgW : int(scaled_w);

  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
             cv::INTER_LINEAR);
  if (resize_w < imgW) {
    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
                       cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
  }
}

} // namespace PaddleOCR