preprocess_op.cpp 5.2 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/preprocess_op.h>

namespace PaddleOCR {

// Copies every channel of `*im` into a planar (channel-major) float buffer.
//
// `data` must have room for rows * cols * channels floats. Channel k is
// written to data[k * rows * cols .. (k + 1) * rows * cols). Each destination
// cv::Mat is only a header wrapping `data` (no copy), so extractChannel writes
// straight into the caller's buffer. The header is CV_32FC1, so `*im` is
// presumably already CV_32F (e.g. after Normalize::Run) — TODO confirm at
// callers.
void Permute::Run(const cv::Mat *im, float *data) {
  const int height = im->rows;
  const int width = im->cols;
  const int plane = height * width;
  float *dst = data;
  for (int ch = 0, n_ch = im->channels(); ch < n_ch; ++ch, dst += plane) {
    cv::Mat channel_view(height, width, CV_32FC1, dst);
    cv::extractChannel(*im, channel_view, ch);
  }
}

M
MissPenguin 已提交
28
// Packs a batch of images into one contiguous planar float buffer: image by
// image, and within each image channel by channel (N x C x H x W when all
// images share a shape).
//
// Offsets accumulate per image instead of being derived as j * C * H * W from
// image j's own dimensions. For same-shape batches the layout is identical.
// For mixed shapes the images now stay contiguous and non-overlapping. The
// offset is size_t so large batches cannot overflow int arithmetic.
//
// `data` must hold the sum of rows * cols * channels over all images. As in
// Permute::Run, each channel is extracted through a CV_32FC1 header that wraps
// `data`, so the images are presumably already CV_32F — TODO confirm at
// callers. `imgs` stays a by-value parameter so the signature keeps matching
// its existing declaration.
void PermuteBatch::Run(const std::vector<cv::Mat> imgs, float *data) {
  std::size_t offset = 0;
  for (const cv::Mat &img : imgs) {
    const int rh = img.rows;
    const int rw = img.cols;
    const int rc = img.channels();
    const std::size_t plane =
        static_cast<std::size_t>(rh) * static_cast<std::size_t>(rw);
    for (int i = 0; i < rc; ++i) {
      cv::extractChannel(img, cv::Mat(rh, rw, CV_32FC1, data + offset), i);
      offset += plane;
    }
  }
}
A
andyjpaddle 已提交
39

littletomatodonkey's avatar
littletomatodonkey 已提交
40 41 42 43 44 45 46
// Converts `*im` to 32-bit float in place and normalizes each channel c:
//   out[c] = (in[c] * e - mean[c]) * scale[c]
// where e = 1/255 when is_scale is true and 1 otherwise.
// `mean` and `scale` must provide at least one entry per channel of `*im`;
// otherwise a cv::Exception is raised (this used to be an out-of-bounds read).
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
                    const std::vector<float> &scale, const bool is_scale) {
  const double e = is_scale ? 1.0 / 255.0 : 1.0;
  // convertTo keeps the input's channel count and only changes the depth, so
  // CV_32F is exactly what the former CV_32FC3 argument requested.
  im->convertTo(*im, CV_32F, e);

  std::vector<cv::Mat> channels;  // split sizes this to im->channels()
  cv::split(*im, channels);
  CV_Assert(mean.size() >= channels.size() && scale.size() >= channels.size());
  for (std::size_t i = 0; i < channels.size(); ++i) {
    // v * scale + (-mean * scale) == (v - mean) * scale, done in one pass.
    channels[i].convertTo(channels[i], CV_32FC1, 1.0 * scale[i],
                          (0.0 - mean[i]) * scale[i]);
  }
  cv::merge(channels, *im);
}

// Resizes `img` for the detection model and reports the scale actually used.
//
// limit_type == "min": upscale so the shorter side reaches limit_side_len.
//   If the shorter side is already >= limit_side_len, the scale stays 1.
// Any other limit_type: downscale so the longer side is at most
//   limit_side_len.
// Both target sides are then rounded to the nearest multiple of 32, with a
// floor of 32. ratio_h / ratio_w receive the resized/original size per axis.
// use_tensorrt is part of the interface but is not read here.
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
                         std::string limit_type, int limit_side_len,
                         float &ratio_h, float &ratio_w, bool use_tensorrt) {
  const int src_w = img.cols;
  const int src_h = img.rows;

  float scale = 1.f;
  if (limit_type == "min") {
    const int shorter = std::min(src_h, src_w);
    if (shorter < limit_side_len) {
      const int pivot = (src_h < src_w) ? src_h : src_w;
      scale = float(limit_side_len) / float(pivot);
    }
  } else {
    const int longer = std::max(src_h, src_w);
    if (longer > limit_side_len) {
      const int pivot = (src_h > src_w) ? src_h : src_w;
      scale = float(limit_side_len) / float(pivot);
    }
  }

  // Snap a side length to the nearest multiple of 32, never below 32.
  auto snap32 = [](int side) {
    return std::max(int(round(float(side) / 32) * 32), 32);
  };
  const int dst_h = snap32(int(float(src_h) * scale));
  const int dst_w = snap32(int(float(src_w) * scale));

  cv::resize(img, resize_img, cv::Size(dst_w, dst_h));
  ratio_h = float(dst_h) / float(src_h);
  ratio_w = float(dst_w) / float(src_w);
}

// Resizes `img` to height rec_image_shape[1] for the recognition model,
// keeping its aspect ratio, then right-pads with gray (127) to a width of
// int(rec_image_shape[1] * wh_ratio). Images wider than that are squeezed to
// the target width.
//
// Only rec_image_shape[1] (the target height) is read. The target width comes
// from wh_ratio, so the former reads of elements 0 and 2 (imgC and the
// immediately overwritten imgW) and the unused resize_h were dead.
// use_tensorrt is part of the interface but is not read here.
void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
                        bool use_tensorrt,
                        const std::vector<int> &rec_image_shape) {
  const int imgH = rec_image_shape[1];
  const int imgW = int(imgH * wh_ratio);

  // Width the image would have at height imgH, capped at the target width.
  const float ratio = float(img.cols) / float(img.rows);
  int resize_w;
  if (ceilf(imgH * ratio) > imgW) {
    resize_w = imgW;
  } else {
    resize_w = int(ceilf(imgH * ratio));
  }

  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
             cv::INTER_LINEAR);
  cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
                     int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
                     cv::Scalar(127, 127, 127));
}

Z
zhoujun 已提交
118
// Resizes `img` to height rec_image_shape[1] for the direction classifier,
// keeping its aspect ratio with the width capped at rec_image_shape[2]. It
// then right-pads with black up to width rec_image_shape[2] when the resized
// image is narrower. rec_image_shape[0] (channels) is not used here, and
// neither is use_tensorrt.
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                       bool use_tensorrt,
                       const std::vector<int> &rec_image_shape) {
  const int imgH = rec_image_shape[1];
  const int imgW = rec_image_shape[2];

  // Width the image would have at height imgH, capped at imgW.
  const float ratio = float(img.cols) / float(img.rows);
  int resize_w;
  if (ceilf(imgH * ratio) > imgW) {
    resize_w = imgW;
  } else {
    resize_w = int(ceilf(imgH * ratio));
  }

  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
             cv::INTER_LINEAR);
  if (resize_w < imgW) {
    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
                       cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
  }
}

文幕地方's avatar
文幕地方 已提交
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
// Scales `img` uniformly so that its longer side becomes max_len, upscaling
// or downscaling as needed. The resulting side lengths are truncated to int.
void TableResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                         const int max_len) {
  const int w = img.cols;
  const int h = img.rows;

  // std::max returns w when w >= h, matching the original tie-breaking.
  const float ratio = float(max_len) / float(std::max(w, h));

  const int resize_h = int(float(h) * ratio);
  const int resize_w = int(float(w) * ratio);

  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}

// Pads `img` with black on the bottom and right so the result is
// max_len x max_len. The image is expected to be no larger than max_len in
// either dimension (e.g. after TableResizeImg::Run). Otherwise a negative
// padding amount is passed to cv::copyMakeBorder.
void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                      const int max_len) {
  const int pad_bottom = max_len - img.rows;
  const int pad_right = max_len - img.cols;
  const cv::Scalar black(0, 0, 0);
  cv::copyMakeBorder(img, resize_img, /*top=*/0, pad_bottom, /*left=*/0,
                     pad_right, cv::BORDER_CONSTANT, black);
}

文幕地方's avatar
文幕地方 已提交
163 164 165 166 167
// Resizes `img` to exactly h rows by w columns using cv::resize's default
// interpolation.
void Resize::Run(const cv::Mat &img, cv::Mat &resize_img, const int h,
                 const int w) {
  const cv::Size target_size(w, h);  // cv::Size is (width, height)
  cv::resize(img, resize_img, target_size);
}

Z
zhoujun 已提交
168
} // namespace PaddleOCR