postprocess_op.cpp 10.9 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/postprocess_op.h>

namespace PaddleOCR {

littletomatodonkey's avatar
littletomatodonkey 已提交
19 20
// Computes the offset distance used to expand a quadrilateral.
//
// The first four points of `box` are treated as a closed polygon. Its area
// (shoelace formula) and perimeter are accumulated edge by edge, and the
// result written to `distance` is `area * unclip_ratio / perimeter`.
//
// @param box           at least four points, each holding {x, y}
// @param unclip_ratio  scale factor applied to the area/perimeter quotient
// @param distance      output: the computed offset distance; set to 0 when the
//                      perimeter is 0 (all four points coincide)
void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
                                   float unclip_ratio, float &distance) {
  const int pts_num = 4;
  float area = 0.0f;
  float dist = 0.0f;
  for (int i = 0; i < pts_num; i++) {
    const std::vector<float> &cur = box[i];
    const std::vector<float> &next = box[(i + 1) % pts_num];
    const float dx = cur[0] - next[0];
    const float dy = cur[1] - next[1];

    area += cur[0] * next[1] - cur[1] * next[0];
    dist += sqrtf(dx * dx + dy * dy);
  }
  area = fabs(area * 0.5f);

  // A zero perimeter would give 0/0 = NaN, which would then be handed to the
  // polygon offsetter; report "no expansion" instead.
  if (dist <= 0.0f) {
    distance = 0.0f;
    return;
  }

  distance = area * unclip_ratio / dist;
}

littletomatodonkey's avatar
littletomatodonkey 已提交
37 38
// Expands a quadrilateral outwards and returns the minimum-area rotated
// rectangle around the expanded shape.
//
// The offset distance comes from GetContourArea(). The four corners are
// truncated to integers, offset with ClipperLib (round joins, closed polygon),
// and every vertex of every resulting path is fed to cv::minAreaRect.
//
// @param box           four points, each holding {x, y}
// @param unclip_ratio  forwarded to GetContourArea()
// @return the enclosing rotated rectangle, or a 1x1 rectangle at the origin
//         when the offset operation produced no vertices
cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
                                      const float &unclip_ratio) {
  float distance = 1.0;
  GetContourArea(box, unclip_ratio, distance);

  ClipperLib::Path p;
  for (int i = 0; i < 4; ++i) {
    p << ClipperLib::IntPoint(int(box[i][0]), int(box[i][1]));
  }

  ClipperLib::ClipperOffset offset;
  offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);

  ClipperLib::Paths soln;
  offset.Execute(soln, distance);

  // Each path is walked with its own length. The previous code bounded the
  // inner loop by the size of the *last* path, which over-reads shorter paths
  // and drops vertices of longer ones when several paths are returned.
  std::vector<cv::Point2f> points;
  for (const ClipperLib::Path &path : soln) {
    for (const ClipperLib::IntPoint &pt : path) {
      points.emplace_back(static_cast<float>(pt.X), static_cast<float>(pt.Y));
    }
  }

  if (points.empty()) {
    // Sentinel result: BoxesFromBitmap skips rectangles this small.
    return cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
  }
  return cv::minAreaRect(points);
}

// Copies a matrix into a newly allocated array of row pointers.
//
// Elements are read with mat.at<float>(), so `mat` is expected to hold
// single-channel float data. The outer array has mat.rows entries and each
// row has mat.cols floats. This function never frees the memory, so the
// caller has to delete[] every row and then the outer array.
float **PostProcessor::Mat2Vec(cv::Mat mat) {
  const int n_rows = mat.rows;
  const int n_cols = mat.cols;

  float **rows = new float *[n_rows];
  for (int r = 0; r < n_rows; ++r) {
    rows[r] = new float[n_cols];
  }

  for (int r = 0; r < n_rows; ++r) {
    float *dst = rows[r];
    for (int c = 0; c < n_cols; ++c) {
      dst[c] = mat.at<float>(r, c);
    }
  }
  return rows;
}

std::vector<std::vector<int>>
littletomatodonkey's avatar
littletomatodonkey 已提交
83
PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
littletomatodonkey's avatar
littletomatodonkey 已提交
84
  std::vector<std::vector<int>> box = pts;
littletomatodonkey's avatar
littletomatodonkey 已提交
85 86
  std::sort(box.begin(), box.end(), XsortInt);

littletomatodonkey's avatar
littletomatodonkey 已提交
87 88 89 90 91 92 93 94 95 96 97 98 99 100
  std::vector<std::vector<int>> leftmost = {box[0], box[1]};
  std::vector<std::vector<int>> rightmost = {box[2], box[3]};

  if (leftmost[0][1] > leftmost[1][1])
    std::swap(leftmost[0], leftmost[1]);

  if (rightmost[0][1] > rightmost[1][1])
    std::swap(rightmost[0], rightmost[1]);

  std::vector<std::vector<int>> rect = {leftmost[0], rightmost[0], rightmost[1],
                                        leftmost[1]};
  return rect;
}

littletomatodonkey's avatar
littletomatodonkey 已提交
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
// Converts a matrix into a vector of rows, one std::vector<float> per matrix
// row. Elements are read with mat.at<float>(), so `mat` is expected to hold
// single-channel float data.
std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
  std::vector<std::vector<float>> rows;

  for (int r = 0; r < mat.rows; ++r) {
    std::vector<float> row;
    for (int c = 0; c < mat.cols; ++c) {
      row.push_back(mat.at<float>(r, c));
    }
    rows.push_back(row);
  }
  return rows;
}

// Sort predicate: orders float points by their first component (x) only.
// Points with equal x compare as equivalent.
bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
  return a[0] < b[0];
}

// Sort predicate: orders integer points by their first component (x) only.
// Points with equal x compare as equivalent.
bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
  return a[0] < b[0];
}

// Returns the four corners of a rotated rectangle in a fixed order and
// reports the rectangle's longer side.
//
// The corners from cv::boxPoints are sorted by x (XsortFp32). Of the two
// smallest-x corners the one with the smaller y goes first and the other goes
// last; of the two largest-x corners the one with the smaller y goes second
// and the other third. On a y tie the corner that sorted later is treated as
// the smaller-y one.
//
// @param box   rectangle whose corners are extracted
// @param ssid  output: max(box.size.width, box.size.height)
// @return the four corners as {x, y} pairs in the order described above
std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
                                                            float &ssid) {
  ssid = std::max(box.size.width, box.size.height);

  cv::Mat corner_mat;
  cv::boxPoints(box, corner_mat);

  auto corners = Mat2Vector(corner_mat);
  std::sort(corners.begin(), corners.end(), XsortFp32);

  // Left pair lives at indices 0/1, right pair at indices 2/3.
  const bool left_flipped = corners[1][1] <= corners[0][1];
  const bool right_flipped = corners[3][1] <= corners[2][1];

  const std::vector<float> first = left_flipped ? corners[1] : corners[0];
  const std::vector<float> fourth = left_flipped ? corners[0] : corners[1];
  const std::vector<float> second = right_flipped ? corners[3] : corners[2];
  const std::vector<float> third = right_flipped ? corners[2] : corners[3];

  corners[0] = first;
  corners[1] = second;
  corners[2] = third;
  corners[3] = fourth;
  return corners;
}

162
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
163
                                     cv::Mat pred) {
164 165 166 167
  int width = pred.cols;
  int height = pred.rows;
  std::vector<float> box_x;
  std::vector<float> box_y;
168
  for (int i = 0; i < contour.size(); ++i) {
169 170 171 172
    box_x.push_back(contour[i].x);
    box_y.push_back(contour[i].y);
  }

173 174 175 176 177 178 179 180 181 182 183 184
  int xmin =
      clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0,
            width - 1);
  int xmax =
      clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0,
            width - 1);
  int ymin =
      clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0,
            height - 1);
  int ymax =
      clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0,
            height - 1);
185 186 187 188 189

  cv::Mat mask;
  mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);

  cv::Point rook_point[contour.size()];
190
  for (int i = 0; i < contour.size(); ++i) {
191 192 193 194 195 196 197
    rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
  }
  const cv::Point *ppt[1] = {rook_point};
  int npt[] = {int(contour.size())};
  cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));

  cv::Mat croppedImg;
198 199
  pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
      .copyTo(croppedImg);
200 201 202 203
  float score = cv::mean(croppedImg, mask)[0];
  return score;
}

littletomatodonkey's avatar
littletomatodonkey 已提交
204 205
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
                                  cv::Mat pred) {
littletomatodonkey's avatar
littletomatodonkey 已提交
206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
  auto array = box_array;
  int width = pred.cols;
  int height = pred.rows;

  float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
  float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};

  int xmin = clamp(int(std::floor(*(std::min_element(box_x, box_x + 4)))), 0,
                   width - 1);
  int xmax = clamp(int(std::ceil(*(std::max_element(box_x, box_x + 4)))), 0,
                   width - 1);
  int ymin = clamp(int(std::floor(*(std::min_element(box_y, box_y + 4)))), 0,
                   height - 1);
  int ymax = clamp(int(std::ceil(*(std::max_element(box_y, box_y + 4)))), 0,
                   height - 1);

  cv::Mat mask;
  mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);

  cv::Point root_point[4];
  root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin);
  root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin);
  root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin);
  root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin);
  const cv::Point *ppt[1] = {root_point};
  int npt[] = {4};
  cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));

  cv::Mat croppedImg;
  pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
      .copyTo(croppedImg);

  auto score = cv::mean(croppedImg, mask)[0];
  return score;
}

242 243 244
// Extracts quadrilateral boxes from a binary map.
//
// Contours are found in `bitmap` (at most 1000 are examined). For each one:
//   1. contours with two or fewer points are skipped;
//   2. its minimum-area rectangle is taken, and skipped if the longer side is
//      below 3;
//   3. it is scored against `pred` with PolygonScoreAcc (on the raw contour)
//      or BoxScoreFast (on the rectangle), and skipped if below `box_thresh`;
//   4. the rectangle is expanded with UnClip; the 1x1 result is skipped, as is
//      any expanded rectangle whose longer side is below 5;
//   5. the corners are rescaled from bitmap size to `pred` size, rounded, and
//      clamped to [0, pred.cols] / [0, pred.rows].
//
// @param pred                 prediction map used for scoring and output size
// @param bitmap               binary map in which contours are searched
// @param box_thresh           minimum score for a box to be kept
// @param det_db_unclip_ratio  expansion ratio forwarded to UnClip
// @param use_polygon_score    selects PolygonScoreAcc over BoxScoreFast
// @return one entry per kept box, each with four {x, y} integer points
std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
    const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
    const float &det_db_unclip_ratio, const bool &use_polygon_score) {
  const int min_size = 3;
  const int max_candidates = 1000;

  const int width = bitmap.cols;
  const int height = bitmap.rows;
  const int dest_width = pred.cols;
  const int dest_height = pred.rows;

  std::vector<std::vector<cv::Point>> contours;
  std::vector<cv::Vec4i> hierarchy;
  cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
                   cv::CHAIN_APPROX_SIMPLE);

  int candidate_count = max_candidates;
  if (contours.size() < max_candidates) {
    candidate_count = contours.size();
  }

  std::vector<std::vector<std::vector<int>>> boxes;

  for (int idx = 0; idx < candidate_count; ++idx) {
    const std::vector<cv::Point> &contour = contours[idx];
    if (contour.size() <= 2) {
      continue;
    }

    float longer_side;
    cv::RotatedRect min_rect = cv::minAreaRect(contour);
    auto mini_box = GetMiniBoxes(min_rect, longer_side);
    if (longer_side < min_size) {
      continue;
    }

    const float score = use_polygon_score
                            ? PolygonScoreAcc(contour, pred)
                            : BoxScoreFast(mini_box, pred);
    if (score < box_thresh) {
      continue;
    }

    // Expand the box; UnClip's 1x1 result is rejected here.
    cv::RotatedRect expanded = UnClip(mini_box, det_db_unclip_ratio);
    if (expanded.size.height < 1.001 && expanded.size.width < 1.001) {
      continue;
    }

    auto expanded_box = GetMiniBoxes(expanded, longer_side);
    if (longer_side < min_size + 2) {
      continue;
    }

    // Rescale from bitmap coordinates to pred coordinates.
    std::vector<std::vector<int>> int_box;
    for (int k = 0; k < 4; ++k) {
      const float scaled_x =
          roundf(expanded_box[k][0] / float(width) * float(dest_width));
      const float scaled_y =
          roundf(expanded_box[k][1] / float(height) * float(dest_height));
      std::vector<int> pt{int(clampf(scaled_x, 0, float(dest_width))),
                          int(clampf(scaled_y, 0, float(dest_height)))};
      int_box.push_back(pt);
    }
    boxes.push_back(int_box);
  }
  return boxes;
}

littletomatodonkey's avatar
littletomatodonkey 已提交
319 320 321
std::vector<std::vector<std::vector<int>>>
PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
                               float ratio_h, float ratio_w, cv::Mat srcimg) {
littletomatodonkey's avatar
littletomatodonkey 已提交
322 323 324 325 326
  int oriimg_h = srcimg.rows;
  int oriimg_w = srcimg.cols;

  std::vector<std::vector<std::vector<int>>> root_points;
  for (int n = 0; n < boxes.size(); n++) {
littletomatodonkey's avatar
littletomatodonkey 已提交
327
    boxes[n] = OrderPointsClockwise(boxes[n]);
littletomatodonkey's avatar
littletomatodonkey 已提交
328 329 330 331 332 333 334 335 336 337 338 339 340 341 342
    for (int m = 0; m < boxes[0].size(); m++) {
      boxes[n][m][0] /= ratio_w;
      boxes[n][m][1] /= ratio_h;

      boxes[n][m][0] = int(_min(_max(boxes[n][m][0], 0), oriimg_w - 1));
      boxes[n][m][1] = int(_min(_max(boxes[n][m][1], 0), oriimg_h - 1));
    }
  }

  for (int n = 0; n < boxes.size(); n++) {
    int rect_width, rect_height;
    rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) +
                          pow(boxes[n][0][1] - boxes[n][1][1], 2)));
    rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
                           pow(boxes[n][0][1] - boxes[n][3][1], 2)));
Z
zhoujun 已提交
343
    if (rect_width <= 4 || rect_height <= 4)
littletomatodonkey's avatar
littletomatodonkey 已提交
344 345 346 347 348 349 350
      continue;
    root_points.push_back(boxes[n]);
  }
  return root_points;
}

} // namespace PaddleOCR