提交 603f234b 编写于 作者: 文幕地方's avatar 文幕地方

add layout

上级 bb97ad18
......@@ -49,6 +49,11 @@ DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path);
DECLARE_int32(rec_img_h);
DECLARE_int32(rec_img_w);
// layout model related
DECLARE_string(layout_model_dir);
DECLARE_string(layout_dict_path);
DECLARE_double(layout_score_threshold);
DECLARE_double(layout_nms_threshold);
// structure model related
DECLARE_string(table_model_dir);
DECLARE_int32(table_max_len);
......@@ -59,4 +64,5 @@ DECLARE_bool(merge_no_span_structure);
DECLARE_bool(det);
DECLARE_bool(rec);
DECLARE_bool(cls);
DECLARE_bool(table);
\ No newline at end of file
DECLARE_bool(table);
DECLARE_bool(layout);
\ No newline at end of file
......@@ -43,21 +43,26 @@ class PPOCR {
public:
explicit PPOCR();
~PPOCR();
std::vector<std::vector<OCRPredictResult>>
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
bool rec = true, bool cls = true);
std::vector<std::vector<OCRPredictResult>> ocr(std::vector<cv::Mat> img_list,
bool det = true,
bool rec = true,
bool cls = true);
std::vector<OCRPredictResult> ocr(cv::Mat img, bool det, bool rec, bool cls);
void reset_timer();
void benchmark_log(int img_num);
protected:
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results);
void rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
std::vector<OCRPredictResult> &ocr_results);
void cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num);
std::vector<OCRPredictResult> &ocr_results);
private:
DBDetector *detector_ = nullptr;
......
......@@ -31,6 +31,7 @@
#include <include/paddleocr.h>
#include <include/preprocess_op.h>
#include <include/structure_layout.h>
#include <include/structure_table.h>
#include <include/utility.h>
......@@ -42,23 +43,31 @@ class PaddleStructure : public PPOCR {
public:
explicit PaddleStructure();
~PaddleStructure();
std::vector<std::vector<StructurePredictResult>>
structure(std::vector<cv::String> cv_all_img_names, bool layout = false,
bool table = true);
std::vector<StructurePredictResult> structure(cv::Mat img,
bool layout = false,
bool table = true,
bool ocr = false);
void reset_timer();
void benchmark_log(int img_num);
private:
StructureTableRecognizer *recognizer_ = nullptr;
std::vector<double> time_info_table = {0, 0, 0};
std::vector<double> time_info_layout = {0, 0, 0};
StructureTableRecognizer *table_model_ = nullptr;
StructureLayoutRecognizer *layout_model_ = nullptr;
void layout(cv::Mat img,
std::vector<StructurePredictResult> &structure_result);
void table(cv::Mat img, StructurePredictResult &structure_result);
void table(cv::Mat img, StructurePredictResult &structure_result,
std::vector<double> &time_info_table,
std::vector<double> &time_info_det,
std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls);
std::string rebuild_table(std::vector<std::string> rec_html_tags,
std::vector<std::vector<int>> rec_boxes,
std::vector<OCRPredictResult> &ocr_result);
float iou(std::vector<int> &box1, std::vector<int> &box2);
float dis(std::vector<int> &box1, std::vector<int> &box2);
static bool comparison_dis(const std::vector<float> &dis1,
......
......@@ -92,7 +92,23 @@ private:
class TablePostProcessor {
public:
void init(std::string label_path, bool merge_no_span_structure = true);
void init(std::string label_path, bool merge_no_span_structure = true) {
this->label_list_ = Utility::ReadDict(label_path);
if (merge_no_span_structure) {
this->label_list_.push_back("<td></td>");
std::vector<std::string>::iterator it;
for (it = this->label_list_.begin(); it != this->label_list_.end();) {
if (*it == "<td>") {
it = this->label_list_.erase(it);
} else {
++it;
}
}
}
// add_special_char
this->label_list_.insert(this->label_list_.begin(), this->beg);
this->label_list_.push_back(this->end);
}
void Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs,
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
std::vector<int> &structure_probs_shape,
......@@ -106,4 +122,33 @@ private:
std::string beg = "sos";
};
class PicodetPostProcessor {
public:
void init(std::string label_path, const double score_threshold = 0.4,
const double nms_threshold = 0.5,
const std::vector<int> &fpn_stride = {8, 16, 32, 64}) {
this->label_list_ = Utility::ReadDict(label_path);
this->score_threshold_ = score_threshold;
this->nms_threshold_ = nms_threshold;
this->num_class_ = label_list_.size();
this->fpn_stride_ = fpn_stride;
}
void Run(std::vector<StructurePredictResult> &results,
std::vector<std::vector<float>> outs, std::vector<int> ori_shape,
std::vector<int> resize_shape, int eg_max);
std::vector<int> fpn_stride_ = {8, 16, 32, 64};
private:
StructurePredictResult disPred2Bbox(std::vector<float> bbox_pred, int label,
float score, int x, int y, int stride,
std::vector<int> im_shape, int reg_max);
void nms(std::vector<StructurePredictResult> &input_boxes,
float nms_threshold);
std::vector<std::string> label_list_;
double score_threshold_ = 0.4;
double nms_threshold_ = 0.5;
int num_class_ = 5;
};
} // namespace PaddleOCR
......@@ -82,4 +82,10 @@ public:
const int max_len = 488);
};
class Resize {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, const int h,
const int w);
};
} // namespace PaddleOCR
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR {
class StructureLayoutRecognizer {
public:
explicit StructureLayoutRecognizer(
const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
const int &gpu_mem, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt, const std::string &precision,
const double &layout_score_threshold,
const double &layout_nms_threshold) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
this->post_processor_.init(label_path, layout_score_threshold,
layout_nms_threshold);
LoadModel(model_dir);
}
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
void Run(cv::Mat img, std::vector<StructurePredictResult> &result,
std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;
bool use_tensorrt_ = false;
std::string precision_ = "fp32";
// pre-process
Resize resize_op_;
Normalize normalize_op_;
Permute permute_op_;
// post-process
PicodetPostProcessor post_processor_;
}; // class StructureTableRecognizer
} // namespace PaddleOCR
\ No newline at end of file
......@@ -42,11 +42,13 @@ struct OCRPredictResult {
struct StructurePredictResult {
std::vector<int> box;
std::vector<float> box_float;
std::vector<std::vector<int>> cell_box;
std::string type;
std::vector<OCRPredictResult> text_res;
std::string html;
float html_score = -1;
float confidence;
};
class Utility {
......@@ -58,7 +60,7 @@ public:
const std::string &save_path);
static void VisualizeBboxes(const cv::Mat &srcimg,
const StructurePredictResult &structure_result,
StructurePredictResult &structure_result,
const std::string &save_path);
template <class ForwardIterator>
......@@ -89,6 +91,12 @@ public:
static std::vector<int> xyxyxyxy2xyxy(std::vector<std::vector<int>> &box);
static std::vector<int> xyxyxyxy2xyxy(std::vector<int> &box);
static float fast_exp(float x);
static std::vector<float>
activation_function_softmax(std::vector<float> &src);
static float iou(std::vector<int> &box1, std::vector<int> &box2);
static float iou(std::vector<float> &box1, std::vector<float> &box2);
private:
static bool comparison_box(const OCRPredictResult &result1,
const OCRPredictResult &result2) {
......
......@@ -51,6 +51,13 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
DEFINE_int32(rec_img_h, 48, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width");
// layout model related
DEFINE_string(layout_model_dir, "", "Path of table layout inference model.");
DEFINE_string(layout_dict_path,
"../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt",
"Path of dictionary.");
DEFINE_double(layout_score_threshold, 0.5, "Threshold of score.");
DEFINE_double(layout_nms_threshold, 0.5, "Threshold of nms.");
// structure model related
DEFINE_string(table_model_dir, "", "Path of table struture inference model.");
DEFINE_int32(table_max_len, 488, "max len size of input image.");
......@@ -65,4 +72,5 @@ DEFINE_string(table_char_dict_path,
DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
DEFINE_bool(table, false, "Whether use table structure in forward.");
\ No newline at end of file
DEFINE_bool(table, false, "Whether use table structure in forward.");
DEFINE_bool(layout, false, "Whether use layout analysis in forward.");
\ No newline at end of file
......@@ -65,6 +65,14 @@ void check_params() {
exit(1);
}
}
if (FLAGS_layout) {
if (FLAGS_layout_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[layout]: ./ppocr "
<< "--layout_model_dir=/PATH/TO/LAYOUT_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
exit(1);
}
}
if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
FLAGS_precision != "int8") {
cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl;
......@@ -75,71 +83,94 @@ void check_params() {
void ocr(std::vector<cv::String> &cv_all_img_names) {
PPOCR ocr = PPOCR();
std::vector<std::vector<OCRPredictResult>> ocr_results =
ocr.ocr(cv_all_img_names, FLAGS_det, FLAGS_rec, FLAGS_cls);
if (FLAGS_benchmark) {
ocr.reset_timer();
}
std::vector<cv::Mat> img_list;
std::vector<cv::String> img_names;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
if (FLAGS_benchmark) {
cout << cv_all_img_names[i] << '\t';
if (FLAGS_rec && FLAGS_det) {
Utility::print_result(ocr_results[i]);
} else if (FLAGS_det) {
for (int n = 0; n < ocr_results[i].size(); n++) {
for (int m = 0; m < ocr_results[i][n].box.size(); m++) {
cout << ocr_results[i][n].box[m][0] << ' '
<< ocr_results[i][n].box[m][1] << ' ';
}
}
cout << endl;
} else {
Utility::print_result(ocr_results[i]);
}
} else {
cout << cv_all_img_names[i] << "\n";
Utility::print_result(ocr_results[i]);
if (FLAGS_visualize && FLAGS_det) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
std::string file_name = Utility::basename(cv_all_img_names[i]);
cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!img.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
continue;
}
img_list.push_back(img);
img_names.push_back(cv_all_img_names[i]);
}
Utility::VisualizeBboxes(srcimg, ocr_results[i],
FLAGS_output + "/" + file_name);
}
cout << "***************************" << endl;
std::vector<std::vector<OCRPredictResult>> ocr_results =
ocr.ocr(img_list, FLAGS_det, FLAGS_rec, FLAGS_cls);
for (int i = 0; i < img_names.size(); ++i) {
cout << "predict img: " << cv_all_img_names[i] << endl;
Utility::print_result(ocr_results[i]);
if (FLAGS_visualize && FLAGS_det) {
std::string file_name = Utility::basename(img_names[i]);
cv::Mat srcimg = img_list[i];
Utility::VisualizeBboxes(srcimg, ocr_results[i],
FLAGS_output + "/" + file_name);
}
}
if (FLAGS_benchmark) {
ocr.benchmark_log(cv_all_img_names.size());
}
}
void structure(std::vector<cv::String> &cv_all_img_names) {
PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure();
std::vector<std::vector<StructurePredictResult>> structure_results =
engine.structure(cv_all_img_names, false, FLAGS_table);
if (FLAGS_benchmark) {
engine.reset_timer();
}
for (int i = 0; i < cv_all_img_names.size(); i++) {
cout << "predict img: " << cv_all_img_names[i] << endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
for (int j = 0; j < structure_results[i].size(); j++) {
std::cout << j << "\ttype: " << structure_results[i][j].type
cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!img.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
continue;
}
std::vector<StructurePredictResult> structure_results = engine.structure(
img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec);
for (int j = 0; j < structure_results.size(); j++) {
std::cout << j << "\ttype: " << structure_results[j].type
<< ", region: [";
std::cout << structure_results[i][j].box[0] << ","
<< structure_results[i][j].box[1] << ","
<< structure_results[i][j].box[2] << ","
<< structure_results[i][j].box[3] << "], res: ";
if (structure_results[i][j].type == "table") {
std::cout << structure_results[i][j].html << std::endl;
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, structure_results[i][j],
FLAGS_output + "/" + std::to_string(j) + "_" +
file_name);
std::cout << structure_results[j].box[0] << ","
<< structure_results[j].box[1] << ","
<< structure_results[j].box[2] << ","
<< structure_results[j].box[3] << "], score: ";
std::cout << structure_results[j].confidence << ", res: ";
if (structure_results[j].type == "table") {
std::cout << structure_results[j].html << std::endl;
if (structure_results[j].cell_box.size() > 0 && FLAGS_visualize) {
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(img, structure_results[j],
FLAGS_output + "/" + std::to_string(j) +
"_" + file_name);
}
} else {
Utility::print_result(structure_results[i][j].text_res);
cout << "count of ocr result is : "
<< structure_results[j].text_res.size() << endl;
if (structure_results[j].text_res.size() > 0) {
cout << "********** print ocr result "
<< "**********" << endl;
Utility::print_result(structure_results[j].text_res);
cout << "********** end print ocr result "
<< "**********" << endl;
}
}
}
}
if (FLAGS_benchmark) {
engine.benchmark_log(cv_all_img_names.size());
}
}
int main(int argc, char **argv) {
......@@ -157,6 +188,9 @@ int main(int argc, char **argv) {
cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl;
if (!Utility::PathExists(FLAGS_output)) {
Utility::CreateDir(FLAGS_output);
}
if (FLAGS_type == "ocr") {
ocr(cv_all_img_names);
} else if (FLAGS_type == "structure") {
......
......@@ -44,8 +44,71 @@ PPOCR::PPOCR() {
}
};
void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<std::vector<OCRPredictResult>>
PPOCR::ocr(std::vector<cv::Mat> img_list, bool det, bool rec, bool cls) {
std::vector<std::vector<OCRPredictResult>> ocr_results;
if (!det) {
std::vector<OCRPredictResult> ocr_result;
ocr_result.resize(img_list.size());
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
if (rec) {
this->rec(img_list, ocr_result);
}
for (int i = 0; i < ocr_result.size(); ++i) {
std::vector<OCRPredictResult> ocr_result_tmp;
ocr_result_tmp.push_back(ocr_result[i]);
ocr_results.push_back(ocr_result_tmp);
}
} else {
for (int i = 0; i < img_list.size(); ++i) {
std::vector<OCRPredictResult> ocr_result =
this->ocr(img_list[i], true, rec, cls);
ocr_results.push_back(ocr_result);
}
}
return ocr_results;
}
std::vector<OCRPredictResult> PPOCR::ocr(cv::Mat img, bool det, bool rec,
bool cls) {
std::vector<OCRPredictResult> ocr_result;
// det
this->det(img, ocr_result);
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < ocr_result.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box);
img_list.push_back(crop_img);
}
// cls
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
// rec
if (rec) {
this->rec(img_list, ocr_result);
}
return ocr_result;
}
void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results) {
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
......@@ -58,14 +121,13 @@ void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
}
// sort boex from top to bottom, from left to right
Utility::sorted_boxes(ocr_results);
times[0] += det_times[0];
times[1] += det_times[1];
times[2] += det_times[2];
this->time_info_det[0] += det_times[0];
this->time_info_det[1] += det_times[1];
this->time_info_det[2] += det_times[2];
}
void PPOCR::rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<OCRPredictResult> &ocr_results) {
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times;
......@@ -75,14 +137,13 @@ void PPOCR::rec(std::vector<cv::Mat> img_list,
ocr_results[i].text = rec_texts[i];
ocr_results[i].score = rec_text_scores[i];
}
times[0] += rec_times[0];
times[1] += rec_times[1];
times[2] += rec_times[2];
this->time_info_rec[0] += rec_times[0];
this->time_info_rec[1] += rec_times[1];
this->time_info_rec[2] += rec_times[2];
}
void PPOCR::cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<OCRPredictResult> &ocr_results) {
std::vector<int> cls_labels(img_list.size(), 0);
std::vector<float> cls_scores(img_list.size(), 0);
std::vector<double> cls_times;
......@@ -92,125 +153,43 @@ void PPOCR::cls(std::vector<cv::Mat> img_list,
ocr_results[i].cls_label = cls_labels[i];
ocr_results[i].cls_score = cls_scores[i];
}
times[0] += cls_times[0];
times[1] += cls_times[1];
times[2] += cls_times[2];
this->time_info_cls[0] += cls_times[0];
this->time_info_cls[1] += cls_times[1];
this->time_info_cls[2] += cls_times[2];
}
std::vector<std::vector<OCRPredictResult>>
PPOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
bool cls) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
std::vector<std::vector<OCRPredictResult>> ocr_results;
if (!det) {
std::vector<OCRPredictResult> ocr_result;
// read image
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
img_list.push_back(srcimg);
OCRPredictResult res;
ocr_result.push_back(res);
}
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result_tmp;
ocr_result_tmp.push_back(ocr_result[i]);
ocr_results.push_back(ocr_result_tmp);
}
} else {
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
Utility::CreateDir(FLAGS_output);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result;
if (!FLAGS_benchmark) {
cout << "predict img: " << cv_all_img_names[i] << endl;
}
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
// det
this->det(srcimg, ocr_result, time_info_det);
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < ocr_result.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(srcimg, ocr_result[j].box);
img_list.push_back(crop_img);
}
// cls
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
// rec
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
ocr_results.push_back(ocr_result);
}
}
if (FLAGS_benchmark) {
this->log(time_info_det, time_info_rec, time_info_cls,
cv_all_img_names.size());
}
return ocr_results;
} // namespace PaddleOCR
void PPOCR::reset_timer() {
this->time_info_det = {0, 0, 0};
this->time_info_rec = {0, 0, 0};
this->time_info_cls = {0, 0, 0};
}
void PPOCR::log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num) {
if (det_times[0] + det_times[1] + det_times[2] > 0) {
void PPOCR::benchmark_log(int img_num) {
if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] >
0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, det_times, img_num);
FLAGS_precision, this->time_info_det, img_num);
autolog_det.report();
}
if (rec_times[0] + rec_times[1] + rec_times[2] > 0) {
if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] >
0) {
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
rec_times, img_num);
this->time_info_rec, img_num);
autolog_rec.report();
}
if (cls_times[0] + cls_times[1] + cls_times[2] > 0) {
if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] >
0) {
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
cls_times, img_num);
this->time_info_cls, img_num);
autolog_cls.report();
}
}
PPOCR::~PPOCR() {
if (this->detector_ != nullptr) {
delete this->detector_;
......
......@@ -22,8 +22,15 @@
namespace PaddleOCR {
PaddleStructure::PaddleStructure() {
if (FLAGS_layout) {
this->layout_model_ = new StructureLayoutRecognizer(
FLAGS_layout_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_layout_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_layout_score_threshold,
FLAGS_layout_nms_threshold);
}
if (FLAGS_table) {
this->recognizer_ = new StructureTableRecognizer(
this->table_model_ = new StructureTableRecognizer(
FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
......@@ -31,68 +38,63 @@ PaddleStructure::PaddleStructure() {
}
};
std::vector<std::vector<StructurePredictResult>>
PaddleStructure::structure(std::vector<cv::String> cv_all_img_names,
bool layout, bool table) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
std::vector<double> time_info_table = {0, 0, 0};
std::vector<StructurePredictResult>
PaddleStructure::structure(cv::Mat srcimg, bool layout, bool table, bool ocr) {
cv::Mat img;
srcimg.copyTo(img);
std::vector<std::vector<StructurePredictResult>> structure_results;
std::vector<StructurePredictResult> structure_results;
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
Utility::CreateDir(FLAGS_output);
if (layout) {
this->layout(img, structure_results);
} else {
StructurePredictResult res;
res.type = "table";
res.box = std::vector<int>(4, 0);
res.box[2] = img.cols;
res.box[3] = img.rows;
structure_results.push_back(res);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<StructurePredictResult> structure_result;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
if (layout) {
} else {
StructurePredictResult res;
res.type = "table";
res.box = std::vector<int>(4, 0);
res.box[2] = srcimg.cols;
res.box[3] = srcimg.rows;
structure_result.push_back(res);
}
cv::Mat roi_img;
for (int i = 0; i < structure_result.size(); i++) {
// crop image
roi_img = Utility::crop_image(srcimg, structure_result[i].box);
if (structure_result[i].type == "table") {
this->table(roi_img, structure_result[i], time_info_table,
time_info_det, time_info_rec, time_info_cls);
}
cv::Mat roi_img;
for (int i = 0; i < structure_results.size(); i++) {
// crop image
roi_img = Utility::crop_image(img, structure_results[i].box);
if (structure_results[i].type == "table" && table) {
this->table(roi_img, structure_results[i]);
} else if (ocr) {
structure_results[i].text_res = this->ocr(roi_img, true, true, false);
}
structure_results.push_back(structure_result);
}
return structure_results;
};
void PaddleStructure::layout(
cv::Mat img, std::vector<StructurePredictResult> &structure_result) {
std::vector<double> layout_times;
this->layout_model_->Run(img, structure_result, layout_times);
this->time_info_layout[0] += layout_times[0];
this->time_info_layout[1] += layout_times[1];
this->time_info_layout[2] += layout_times[2];
}
void PaddleStructure::table(cv::Mat img,
StructurePredictResult &structure_result,
std::vector<double> &time_info_table,
std::vector<double> &time_info_det,
std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls) {
StructurePredictResult &structure_result) {
// predict structure
std::vector<std::vector<std::string>> structure_html_tags;
std::vector<float> structure_scores(1, 0);
std::vector<std::vector<std::vector<int>>> structure_boxes;
std::vector<double> structure_imes;
std::vector<double> structure_times;
std::vector<cv::Mat> img_list;
img_list.push_back(img);
this->recognizer_->Run(img_list, structure_html_tags, structure_scores,
structure_boxes, structure_imes);
time_info_table[0] += structure_imes[0];
time_info_table[1] += structure_imes[1];
time_info_table[2] += structure_imes[2];
this->table_model_->Run(img_list, structure_html_tags, structure_scores,
structure_boxes, structure_times);
this->time_info_table[0] += structure_times[0];
this->time_info_table[1] += structure_times[1];
this->time_info_table[2] += structure_times[2];
std::vector<OCRPredictResult> ocr_result;
std::string html;
......@@ -100,7 +102,7 @@ void PaddleStructure::table(cv::Mat img,
for (int i = 0; i < img_list.size(); i++) {
// det
this->det(img_list[i], ocr_result, time_info_det);
this->det(img_list[i], ocr_result);
// crop image
std::vector<cv::Mat> rec_img_list;
std::vector<int> ocr_box;
......@@ -115,7 +117,7 @@ void PaddleStructure::table(cv::Mat img,
rec_img_list.push_back(crop_img);
}
// rec
this->rec(rec_img_list, ocr_result, time_info_rec);
this->rec(rec_img_list, ocr_result);
// rebuild table
html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
ocr_result);
......@@ -150,7 +152,7 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
structure_box = structure_boxes[j];
}
dis_list[j][0] = this->dis(ocr_box, structure_box);
dis_list[j][1] = 1 - this->iou(ocr_box, structure_box);
dis_list[j][1] = 1 - Utility::iou(ocr_box, structure_box);
dis_list[j][2] = j;
}
// find min dis idx
......@@ -216,28 +218,6 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
return html_str;
}
float PaddleStructure::iou(std::vector<int> &box1, std::vector<int> &box2) {
int area1 = max(0, box1[2] - box1[0]) * max(0, box1[3] - box1[1]);
int area2 = max(0, box2[2] - box2[0]) * max(0, box2[3] - box2[1]);
// computing the sum_area
int sum_area = area1 + area2;
// find the each point of intersect rectangle
int x1 = max(box1[0], box2[0]);
int y1 = max(box1[1], box2[1]);
int x2 = min(box1[2], box2[2]);
int y2 = min(box1[3], box2[3]);
// judge if there is an intersect
if (y1 >= y2 || x1 >= x2) {
return 0.0;
} else {
int intersect = (x2 - x1) * (y2 - y1);
return intersect / (sum_area - intersect + 0.00000001);
}
}
float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
int x1_1 = box1[0];
int y1_1 = box1[1];
......@@ -256,9 +236,61 @@ float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
return dis + min(dis_2, dis_3);
}
void PaddleStructure::reset_timer() {
this->time_info_det = {0, 0, 0};
this->time_info_rec = {0, 0, 0};
this->time_info_cls = {0, 0, 0};
this->time_info_table = {0, 0, 0};
this->time_info_layout = {0, 0, 0};
}
void PaddleStructure::benchmark_log(int img_num) {
if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] >
0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, this->time_info_det, img_num);
autolog_det.report();
}
if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] >
0) {
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
this->time_info_rec, img_num);
autolog_rec.report();
}
if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] >
0) {
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_cls, img_num);
autolog_cls.report();
}
if (this->time_info_table[0] + this->time_info_table[1] +
this->time_info_table[2] >
0) {
AutoLogger autolog_table("table", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_table, img_num);
autolog_table.report();
}
if (this->time_info_layout[0] + this->time_info_layout[1] +
this->time_info_layout[2] >
0) {
AutoLogger autolog_layout("layout", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_layout, img_num);
autolog_layout.report();
}
}
PaddleStructure::~PaddleStructure() {
if (this->recognizer_ != nullptr) {
delete this->recognizer_;
if (this->table_model_ != nullptr) {
delete this->table_model_;
}
};
......
......@@ -352,25 +352,6 @@ std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes(
return root_points;
}
void TablePostProcessor::init(std::string label_path,
bool merge_no_span_structure) {
this->label_list_ = Utility::ReadDict(label_path);
if (merge_no_span_structure) {
this->label_list_.push_back("<td></td>");
std::vector<std::string>::iterator it;
for (it = this->label_list_.begin(); it != this->label_list_.end();) {
if (*it == "<td>") {
it = this->label_list_.erase(it);
} else {
++it;
}
}
}
// add_special_char
this->label_list_.insert(this->label_list_.begin(), this->beg);
this->label_list_.push_back(this->end);
}
void TablePostProcessor::Run(
std::vector<float> &loc_preds, std::vector<float> &structure_probs,
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
......@@ -440,4 +421,129 @@ void TablePostProcessor::Run(
}
}
void PicodetPostProcessor::Run(std::vector<StructurePredictResult> &results,
std::vector<std::vector<float>> outs,
std::vector<int> ori_shape,
std::vector<int> resize_shape, int reg_max) {
int in_h = resize_shape[0];
int in_w = resize_shape[1];
float scale_factor_h = resize_shape[0] / float(ori_shape[0]);
float scale_factor_w = resize_shape[1] / float(ori_shape[1]);
std::vector<std::vector<StructurePredictResult>> bbox_results;
bbox_results.resize(this->num_class_);
for (int i = 0; i < this->fpn_stride_.size(); ++i) {
int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]);
int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]);
for (int idx = 0; idx < feature_h * feature_w; idx++) {
// score and label
float score = 0;
int cur_label = 0;
for (int label = 0; label < this->num_class_; label++) {
if (outs[i][idx * this->num_class_ + label] > score) {
score = outs[i][idx * this->num_class_ + label];
cur_label = label;
}
}
// bbox
if (score > this->score_threshold_) {
int row = idx / feature_w;
int col = idx % feature_w;
std::vector<float> bbox_pred(
outs[i + this->fpn_stride_.size()].begin() + idx * 4 * reg_max,
outs[i + this->fpn_stride_.size()].begin() +
(idx + 1) * 4 * reg_max);
bbox_results[cur_label].push_back(
this->disPred2Bbox(bbox_pred, cur_label, score, col, row,
this->fpn_stride_[i], resize_shape, reg_max));
}
}
}
for (int i = 0; i < bbox_results.size(); i++) {
bool flag = bbox_results[i].size() <= 0;
}
for (int i = 0; i < bbox_results.size(); i++) {
bool flag = bbox_results[i].size() <= 0;
if (bbox_results[i].size() <= 0) {
continue;
}
this->nms(bbox_results[i], this->nms_threshold_);
for (auto box : bbox_results[i]) {
box.box_float[0] = box.box_float[0] / scale_factor_w;
box.box_float[2] = box.box_float[2] / scale_factor_w;
box.box_float[1] = box.box_float[1] / scale_factor_h;
box.box_float[3] = box.box_float[3] / scale_factor_h;
box.box = {(int)box.box_float[0], (int)box.box_float[1],
(int)box.box_float[2], (int)box.box_float[3]};
results.push_back(box);
}
}
}
StructurePredictResult
PicodetPostProcessor::disPred2Bbox(std::vector<float> bbox_pred, int label,
float score, int x, int y, int stride,
std::vector<int> im_shape, int reg_max) {
float ct_x = (x + 0.5) * stride;
float ct_y = (y + 0.5) * stride;
std::vector<float> dis_pred;
dis_pred.resize(4);
for (int i = 0; i < 4; i++) {
float dis = 0;
std::vector<float> bbox_pred_i(bbox_pred.begin() + i * reg_max,
bbox_pred.begin() + (i + 1) * reg_max);
std::vector<float> dis_after_sm =
Utility::activation_function_softmax(bbox_pred_i);
for (int j = 0; j < reg_max; j++) {
dis += j * dis_after_sm[j];
}
dis *= stride;
dis_pred[i] = dis;
}
float xmin_float = (std::max)(ct_x - dis_pred[0], .0f);
float ymin_float = (std::max)(ct_y - dis_pred[1], .0f);
float xmax_float = (std::min)(ct_x + dis_pred[2], (float)im_shape[1]);
float ymax_float = (std::min)(ct_y + dis_pred[3], (float)im_shape[0]);
StructurePredictResult result_item;
result_item.box_float = {xmin_float, ymin_float, xmax_float, ymax_float};
result_item.type = this->label_list_[label];
result_item.confidence = score;
return result_item;
}
void PicodetPostProcessor::nms(std::vector<StructurePredictResult> &input_boxes,
float nms_threshold) {
std::sort(input_boxes.begin(), input_boxes.end(),
[](StructurePredictResult a, StructurePredictResult b) {
return a.confidence > b.confidence;
});
std::vector<int> picked(input_boxes.size(), 1);
for (int i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 0) {
continue;
}
for (int j = i + 1; j < input_boxes.size(); ++j) {
if (picked[j] == 0) {
continue;
}
float iou =
Utility::iou(input_boxes[i].box_float, input_boxes[j].box_float);
if (iou > nms_threshold) {
picked[j] = 0;
}
}
}
std::vector<StructurePredictResult> input_boxes_nms;
for (int i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 1) {
input_boxes_nms.push_back(input_boxes[i]);
}
}
input_boxes = input_boxes_nms;
}
} // namespace PaddleOCR
......@@ -175,4 +175,9 @@ void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
void Resize::Run(const cv::Mat &img, cv::Mat &resize_img, const int h,
const int w) {
cv::resize(img, resize_img, cv::Size(w, h));
}
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/structure_layout.h>
namespace PaddleOCR {
void StructureLayoutRecognizer::Run(cv::Mat img,
std::vector<StructurePredictResult> &result,
std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> inference_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> postprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
// preprocess
auto preprocess_start = std::chrono::steady_clock::now();
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat resize_img;
this->resize_op_.Run(srcimg, resize_img, 800, 608);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
// Get output tensor
std::vector<std::vector<float>> out_tensor_list;
std::vector<std::vector<int>> output_shape_list;
auto output_names = this->predictor_->GetOutputNames();
for (int j = 0; j < output_names.size(); j++) {
auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]);
std::vector<int> output_shape = output_tensor->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
output_shape_list.push_back(output_shape);
std::vector<float> out_data;
out_data.resize(out_num);
output_tensor->CopyToCpu(out_data.data());
out_tensor_list.push_back(out_data);
}
auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// postprocess
auto postprocess_start = std::chrono::steady_clock::now();
std::vector<int> bbox_num;
int reg_max = 0;
for (int i = 0; i < out_tensor_list.size(); i++) {
if (i == this->post_processor_.fpn_stride_.size()) {
reg_max = output_shape_list[i][2] / 4;
break;
}
}
std::vector<int> ori_shape = {srcimg.rows, srcimg.cols};
std::vector<int> resize_shape = {resize_img.rows, resize_img.cols};
this->post_processor_.Run(result, out_tensor_list, ori_shape, resize_shape,
reg_max);
bbox_num.push_back(result.size());
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
times.push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
void StructureLayoutRecognizer::LoadModel(const std::string &model_dir) {
AnalysisConfig config;
if (Utility::PathExists(model_dir + "/inference.pdmodel") &&
Utility::PathExists(model_dir + "/inference.pdiparams")) {
config.SetModel(model_dir + "/inference.pdmodel",
model_dir + "/inference.pdiparams");
} else if (Utility::PathExists(model_dir + "/model.pdmodel") &&
Utility::PathExists(model_dir + "/model.pdiparams")) {
config.SetModel(model_dir + "/model.pdmodel",
model_dir + "/model.pdiparams");
} else {
std::cerr << "[ERROR] not find model.pdiparams or inference.pdiparams in "
<< model_dir << endl;
exit(1);
}
if (this->use_gpu_) {
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
if (this->use_tensorrt_) {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (this->precision_ == "fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
if (this->precision_ == "int8") {
precision = paddle_infer::Config::Precision::kInt8;
}
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
}
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
// false for zero copy tensor
config.SwitchUseFeedFetchOps(false);
// true for multiple input
config.SwitchSpecifyInputNames(true);
config.SwitchIrOptim(true);
config.EnableMemoryOptim();
config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config);
}
} // namespace PaddleOCR
......@@ -66,10 +66,11 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
}
void Utility::VisualizeBboxes(const cv::Mat &srcimg,
const StructurePredictResult &structure_result,
StructurePredictResult &structure_result,
const std::string &save_path) {
cv::Mat img_vis;
srcimg.copyTo(img_vis);
img_vis = crop_image(img_vis, structure_result.box);
for (int n = 0; n < structure_result.cell_box.size(); n++) {
if (structure_result.cell_box[n].size() == 8) {
cv::Point rook_points[4];
......@@ -280,17 +281,17 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
}
}
cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) {
cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &box) {
cv::Mat crop_im;
int crop_x1 = std::max(0, area[0]);
int crop_y1 = std::max(0, area[1]);
int crop_x2 = std::min(img.cols - 1, area[2] - 1);
int crop_y2 = std::min(img.rows - 1, area[3] - 1);
int crop_x1 = std::max(0, box[0]);
int crop_y1 = std::max(0, box[1]);
int crop_x2 = std::min(img.cols - 1, box[2] - 1);
int crop_y2 = std::min(img.rows - 1, box[3] - 1);
crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16);
crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16);
cv::Mat crop_im_window =
crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]),
cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0]));
crop_im(cv::Range(crop_y1 - box[1], crop_y2 + 1 - box[1]),
cv::Range(crop_x1 - box[0], crop_x2 + 1 - box[0]));
cv::Mat roi_img =
img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
crop_im_window += roi_img;
......@@ -341,4 +342,78 @@ std::vector<int> Utility::xyxyxyxy2xyxy(std::vector<int> &box) {
return box1;
}
float Utility::fast_exp(float x) {
union {
uint32_t i;
float f;
} v{};
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
return v.f;
}
std::vector<float>
Utility::activation_function_softmax(std::vector<float> &src) {
int length = src.size();
std::vector<float> dst;
dst.resize(length);
const float alpha = float(*std::max_element(&src[0], &src[0 + length]));
float denominator{0};
for (int i = 0; i < length; ++i) {
dst[i] = fast_exp(src[i] - alpha);
denominator += dst[i];
}
for (int i = 0; i < length; ++i) {
dst[i] /= denominator;
}
return dst;
}
float Utility::iou(std::vector<int> &box1, std::vector<int> &box2) {
int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]);
int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]);
// computing the sum_area
int sum_area = area1 + area2;
// find the each point of intersect rectangle
int x1 = std::max(box1[0], box2[0]);
int y1 = std::max(box1[1], box2[1]);
int x2 = std::min(box1[2], box2[2]);
int y2 = std::min(box1[3], box2[3]);
// judge if there is an intersect
if (y1 >= y2 || x1 >= x2) {
return 0.0;
} else {
int intersect = (x2 - x1) * (y2 - y1);
return intersect / (sum_area - intersect + 0.00000001);
}
}
float Utility::iou(std::vector<float> &box1, std::vector<float> &box2) {
float area1 = std::max((float)0.0, box1[2] - box1[0]) *
std::max((float)0.0, box1[3] - box1[1]);
float area2 = std::max((float)0.0, box2[2] - box2[0]) *
std::max((float)0.0, box2[3] - box2[1]);
// computing the sum_area
float sum_area = area1 + area2;
// find the each point of intersect rectangle
float x1 = std::max(box1[0], box2[0]);
float y1 = std::max(box1[1], box2[1]);
float x2 = std::min(box1[2], box2[2]);
float y2 = std::min(box1[3], box2[3]);
// judge if there is an intersect
if (y1 >= y2 || x1 >= x2) {
return 0.0;
} else {
float intersect = (x2 - x1) * (y2 - y1);
return intersect / (sum_area - intersect + 0.00000001);
}
}
} // namespace PaddleOCR
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册