提交 3867c8cc 编写于 作者: 文幕地方's avatar 文幕地方

add table cpp infer

上级 97f7f748
......@@ -30,7 +30,8 @@ DECLARE_string(image_dir);
DECLARE_string(type);
// detection related
DECLARE_string(det_model_dir);
DECLARE_int32(max_side_len);
DECLARE_string(limit_type);
DECLARE_int32(limit_side_len);
DECLARE_double(det_db_thresh);
DECLARE_double(det_db_box_thresh);
DECLARE_double(det_db_unclip_ratio);
......@@ -48,7 +49,13 @@ DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path);
DECLARE_int32(rec_img_h);
DECLARE_int32(rec_img_w);
// structure model related
DECLARE_string(table_model_dir);
DECLARE_int32(table_max_len);
DECLARE_int32(table_batch_num);
DECLARE_string(table_char_dict_path);
// forward related
DECLARE_bool(det);
DECLARE_bool(rec);
DECLARE_bool(cls);
DECLARE_bool(table);
\ No newline at end of file
......@@ -41,8 +41,8 @@ public:
explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const int &max_side_len,
const double &det_db_thresh,
const bool &use_mkldnn, const string &limit_type,
const int &limit_side_len, const double &det_db_thresh,
const double &det_db_box_thresh,
const double &det_db_unclip_ratio,
const std::string &det_db_score_mode,
......@@ -54,7 +54,8 @@ public:
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->max_side_len_ = max_side_len;
this->limit_type_ = limit_type;
this->limit_side_len_ = limit_side_len;
this->det_db_thresh_ = det_db_thresh;
this->det_db_box_thresh_ = det_db_box_thresh;
......@@ -84,7 +85,8 @@ private:
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
int max_side_len_ = 960;
string limit_type_ = "max";
int limit_side_len_ = 960;
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.5;
......@@ -106,7 +108,7 @@ private:
Permute permute_op_;
// post-process
PostProcessor post_processor_;
DBPostProcessor post_processor_;
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -47,11 +47,7 @@ public:
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
bool rec = true, bool cls = true);
private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;
protected:
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void rec(std::vector<cv::Mat> img_list,
......@@ -62,6 +58,11 @@ private:
std::vector<double> &times);
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num);
private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;
};
} // namespace PaddleOCR
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/paddleocr.h>
#include <include/preprocess_op.h>
#include <include/structure_table.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR {
class PaddleStructure : public PPOCR {
public:
explicit PaddleStructure();
~PaddleStructure();
std::vector<std::vector<StructurePredictResult>>
structure(std::vector<cv::String> cv_all_img_names, bool layout = false,
bool table = true);
private:
StructureTableRecognizer *recognizer_ = nullptr;
void table(cv::Mat img, StructurePredictResult &structure_result,
std::vector<double> &time_info_table,
std::vector<double> &time_info_det,
std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls);
std::string
rebuild_table(std::vector<std::string> rec_html_tags,
std::vector<std::vector<std::vector<int>>> rec_boxes,
std::vector<OCRPredictResult> &ocr_result);
float iou(std::vector<std::vector<int>> &box1,
std::vector<std::vector<int>> &box2);
float dis(std::vector<std::vector<int>> &box1,
std::vector<std::vector<int>> &box2);
static bool comparison_dis(const std::vector<float> &dis1,
const std::vector<float> &dis2) {
if (dis1[1] < dis2[1]) {
return true;
} else if (dis1[1] == dis2[1]) {
return dis1[0] < dis2[0];
} else {
return false;
}
}
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -34,7 +34,7 @@ using namespace std;
namespace PaddleOCR {
class PostProcessor {
class DBPostProcessor {
public:
void GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance);
......@@ -90,4 +90,21 @@ private:
}
};
class TablePostProcessor {
public:
void init(std::string label_path);
void
Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs,
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
std::vector<int> &structure_probs_shape,
std::vector<std::vector<std::string>> &rec_html_tag_batch,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
std::vector<int> &width_list, std::vector<int> &height_list);
private:
std::vector<std::string> label_list_;
std::string end = "eos";
std::string beg = "sos";
};
} // namespace PaddleOCR
......@@ -51,8 +51,9 @@ public:
class ResizeImgType0 {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
float &ratio_h, float &ratio_w, bool use_tensorrt);
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type,
int limit_side_len, float &ratio_h, float &ratio_w,
bool use_tensorrt);
};
class CrnnResizeImg {
......@@ -69,4 +70,16 @@ public:
const std::vector<int> &rec_image_shape = {3, 48, 192});
};
class TableResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len = 488);
};
class TablePadImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len = 488);
};
} // namespace PaddleOCR
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR {
class StructureTableRecognizer {
public:
explicit StructureTableRecognizer(
const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
const int &gpu_mem, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt, const std::string &precision,
const int &table_batch_num, const int &table_max_len) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
this->table_batch_num_ = table_batch_num;
this->table_max_len_ = table_max_len;
this->post_processor_.init(label_path);
LoadModel(model_dir);
}
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
void Run(std::vector<cv::Mat> img_list,
std::vector<std::vector<std::string>> &rec_html_tags,
std::vector<float> &rec_scores,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes,
std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
int table_max_len_ = 488;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;
bool use_tensorrt_ = false;
std::string precision_ = "fp32";
int table_batch_num_ = 1;
// pre-process
TableResizeImg resize_op_;
Normalize normalize_op_;
PermuteBatch permute_op_;
TablePadImg pad_op_;
// post-process
TablePostProcessor post_processor_;
}; // class StructureTableRecognizer
} // namespace PaddleOCR
\ No newline at end of file
......@@ -40,6 +40,14 @@ struct OCRPredictResult {
int cls_label = -1;
};
struct StructurePredictResult {
std::vector<int> box;
std::string type;
std::vector<OCRPredictResult> text_res;
std::string html;
float html_score = -1;
};
class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
......@@ -68,6 +76,22 @@ public:
static void CreateDir(const std::string &path);
static void print_result(const std::vector<OCRPredictResult> &ocr_result);
static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area);
static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);
private:
static bool comparison_box(const OCRPredictResult &result1,
const OCRPredictResult &result2) {
if (result1.box[0][1] < result2.box[0][1]) {
return true;
} else if (result1.box[0][1] == result2.box[0][1]) {
return result1.box[0][0] < result2.box[0][0];
} else {
return false;
}
}
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -30,7 +30,8 @@ DEFINE_string(
"Perform ocr or structure, the value is selected in ['ocr','structure'].");
// detection related
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
DEFINE_string(limit_type, "max", "limit_type of input image.");
DEFINE_int32(limit_side_len, 960, "limit_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
......@@ -50,7 +51,16 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
DEFINE_int32(rec_img_h, 48, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width");
// structure model related
DEFINE_string(table_model_dir, "", "Path of table struture inference model.");
DEFINE_int32(table_max_len, 488, "max len size of input image.");
DEFINE_int32(table_batch_num, 1, "table_batch_num.");
DEFINE_string(table_char_dict_path,
"../../ppocr/utils/dict/table_structure_dict.txt",
"Path of dictionary.");
// ocr forward related
DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
DEFINE_bool(table, false, "Whether use table structure in forward.");
\ No newline at end of file
......@@ -19,6 +19,7 @@
#include <include/args.h>
#include <include/paddleocr.h>
#include <include/paddlestructure.h>
using namespace PaddleOCR;
......@@ -32,6 +33,12 @@ void check_params() {
}
}
if (FLAGS_rec) {
std::cout
<< "In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320',"
"if you are using recognition model with PP-OCRv2 or an older "
"version, "
"please set --rec_image_shape='3,32,320"
<< std::endl;
if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[rec]: ./ppocr "
"--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
......@@ -47,6 +54,17 @@ void check_params() {
exit(1);
}
}
if (FLAGS_table) {
if (FLAGS_table_model_dir.empty() || FLAGS_det_model_dir.empty() ||
FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[table]: ./ppocr "
<< "--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
<< "--table_model_dir=/PATH/TO/TABLE_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
exit(1);
}
}
if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
FLAGS_precision != "int8") {
cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl;
......@@ -54,21 +72,7 @@ void check_params() {
}
}
int main(int argc, char **argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
check_params();
if (!Utility::PathExists(FLAGS_image_dir)) {
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
<< endl;
exit(1);
}
std::vector<cv::String> cv_all_img_names;
cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl;
void ocr(std::vector<cv::String> &cv_all_img_names) {
PPOCR ocr = PPOCR();
std::vector<std::vector<OCRPredictResult>> ocr_results =
......@@ -109,3 +113,49 @@ int main(int argc, char **argv) {
}
}
}
void structure(std::vector<cv::String> &cv_all_img_names) {
PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure();
std::vector<std::vector<StructurePredictResult>> structure_results =
engine.structure(cv_all_img_names, false, FLAGS_table);
for (int i = 0; i < cv_all_img_names.size(); i++) {
cout << cv_all_img_names[i] << "\n";
for (int j = 0; j < structure_results[i].size(); j++) {
std::cout << j << "\ttype: " << structure_results[i][j].type
<< ", region: [";
std::cout << structure_results[i][j].box[0] << ","
<< structure_results[i][j].box[1] << ","
<< structure_results[i][j].box[2] << ","
<< structure_results[i][j].box[3] << "], res: ";
if (structure_results[i][j].type == "Table") {
std::cout << structure_results[i][j].html << std::endl;
} else {
Utility::print_result(structure_results[i][j].text_res);
}
}
}
}
int main(int argc, char **argv) {
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
check_params();
if (!Utility::PathExists(FLAGS_image_dir)) {
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
<< endl;
exit(1);
}
std::vector<cv::String> cv_all_img_names;
cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl;
if (FLAGS_type == "ocr") {
ocr(cv_all_img_names);
} else if (FLAGS_type == "structure") {
structure(cv_all_img_names);
} else {
std::cout << "only value in ['ocr','structure'] is supported" << endl;
}
}
......@@ -47,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
{"elementwise_add_7", {1, 56, 2, 2}},
{"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {1, 3, this->max_side_len_, this->max_side_len_}},
{"x", {1, 3, 1536, 1536}},
{"conv2d_92.tmp_0", {1, 120, 400, 400}},
{"conv2d_91.tmp_0", {1, 24, 200, 200}},
{"conv2d_59.tmp_0", {1, 96, 400, 400}},
......@@ -109,7 +109,8 @@ void DBDetector::Run(cv::Mat &img,
img.copyTo(srcimg);
auto preprocess_start = std::chrono::steady_clock::now();
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
this->resize_op_.Run(img, resize_img, this->limit_type_,
this->limit_side_len_, ratio_h, ratio_w,
this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
......
......@@ -23,10 +23,10 @@ PPOCR::PPOCR() {
if (FLAGS_det) {
this->detector_ = new DBDetector(
FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len,
FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt,
FLAGS_precision);
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_limit_type,
FLAGS_limit_side_len, FLAGS_det_db_thresh, FLAGS_det_db_box_thresh,
FLAGS_det_db_unclip_ratio, FLAGS_det_db_score_mode, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
}
if (FLAGS_cls && FLAGS_use_angle_cls) {
......@@ -56,7 +56,8 @@ void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
res.box = boxes[i];
ocr_results.push_back(res);
}
// sort boex from top to bottom, from left to right
Utility::sorted_boxes(ocr_results);
times[0] += det_times[0];
times[1] += det_times[1];
times[2] += det_times[2];
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/args.h>
#include <include/paddlestructure.h>
#include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>
namespace PaddleOCR {
PaddleStructure::PaddleStructure() {
if (FLAGS_table) {
this->recognizer_ = new StructureTableRecognizer(
FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
FLAGS_table_max_len);
}
};
std::vector<std::vector<StructurePredictResult>>
PaddleStructure::structure(std::vector<cv::String> cv_all_img_names,
bool layout, bool table) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
std::vector<double> time_info_table = {0, 0, 0};
std::vector<std::vector<StructurePredictResult>> structure_results;
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
mkdir(FLAGS_output.c_str(), 0777);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<StructurePredictResult> structure_result;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
if (layout) {
} else {
StructurePredictResult res;
res.type = "Table";
res.box = std::vector<int>(4, 0);
res.box[2] = srcimg.cols;
res.box[3] = srcimg.rows;
structure_result.push_back(res);
}
cv::Mat roi_img;
for (int i = 0; i < structure_result.size(); i++) {
// crop image
roi_img = Utility::crop_image(srcimg, structure_result[i].box);
if (structure_result[i].type == "Table") {
this->table(roi_img, structure_result[i], time_info_table,
time_info_det, time_info_rec, time_info_cls);
}
}
structure_results.push_back(structure_result);
}
return structure_results;
};
void PaddleStructure::table(cv::Mat img,
StructurePredictResult &structure_result,
std::vector<double> &time_info_table,
std::vector<double> &time_info_det,
std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls) {
// predict structure
std::vector<std::vector<std::string>> structure_html_tags;
std::vector<float> structure_scores(1, 0);
std::vector<std::vector<std::vector<std::vector<int>>>> structure_boxes;
std::vector<double> structure_imes;
std::vector<cv::Mat> img_list;
img_list.push_back(img);
this->recognizer_->Run(img_list, structure_html_tags, structure_scores,
structure_boxes, structure_imes);
time_info_table[0] += structure_imes[0];
time_info_table[1] += structure_imes[1];
time_info_table[2] += structure_imes[2];
std::vector<OCRPredictResult> ocr_result;
std::string html;
int expand_pixel = 3;
for (int i = 0; i < img_list.size(); i++) {
// det
this->det(img_list[i], ocr_result, time_info_det);
// crop image
std::vector<cv::Mat> rec_img_list;
for (int j = 0; j < ocr_result.size(); j++) {
int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0],
ocr_result[j].box[2][0], ocr_result[j].box[3][0]};
int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1],
ocr_result[j].box[2][1], ocr_result[j].box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
std::vector<int> box{max(0, left - expand_pixel),
max(0, top - expand_pixel),
min(img_list[i].cols, right + expand_pixel),
min(img_list[i].rows, bottom + expand_pixel)};
cv::Mat crop_img = Utility::crop_image(img_list[i], box);
rec_img_list.push_back(crop_img);
}
// rec
this->rec(rec_img_list, ocr_result, time_info_rec);
// rebuild table
html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
ocr_result);
structure_result.html = html;
structure_result.html_score = structure_scores[i];
}
};
std::string PaddleStructure::rebuild_table(
std::vector<std::string> structure_html_tags,
std::vector<std::vector<std::vector<int>>> structure_boxes,
std::vector<OCRPredictResult> &ocr_result) {
// match text in same cell
std::vector<std::vector<string>> matched(structure_boxes.size(),
std::vector<std::string>());
for (int i = 0; i < ocr_result.size(); i++) {
std::vector<std::vector<float>> dis_list(structure_boxes.size(),
std::vector<float>(3, 100000.0));
for (int j = 0; j < structure_boxes.size(); j++) {
int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0],
ocr_result[i].box[2][0], ocr_result[i].box[3][0]};
int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1],
ocr_result[i].box[2][1], ocr_result[i].box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
std::vector<std::vector<int>> box(2, std::vector<int>(2, 0));
box[0][0] = left - 1;
box[0][1] = top - 1;
box[1][0] = right + 1;
box[1][1] = bottom + 1;
dis_list[j][0] = this->dis(box, structure_boxes[j]);
dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]);
dis_list[j][2] = j;
}
// find min dis idx
std::sort(dis_list.begin(), dis_list.end(),
PaddleStructure::comparison_dis);
matched[dis_list[0][2]].push_back(ocr_result[i].text);
}
// get pred html
std::string html_str = "";
int td_tag_idx = 0;
for (int i = 0; i < structure_html_tags.size(); i++) {
if (structure_html_tags[i].find("</td>") != std::string::npos) {
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
html_str += "<td>";
}
if (matched[td_tag_idx].size() > 0) {
bool b_with = false;
if (matched[td_tag_idx][0].find("<b>") != std::string::npos &&
matched[td_tag_idx].size() > 1) {
b_with = true;
html_str += "<b>";
}
for (int j = 0; j < matched[td_tag_idx].size(); j++) {
std::string content = matched[td_tag_idx][j];
if (matched[td_tag_idx].size() > 1) {
// remove blank, <b> and </b>
if (content.length() > 0 && content.at(0) == ' ') {
content = content.substr(0);
}
if (content.length() > 2 && content.substr(0, 3) == "<b>") {
content = content.substr(3);
}
if (content.length() > 4 &&
content.substr(content.length() - 4) == "</b>") {
content = content.substr(0, content.length() - 4);
}
if (content.empty()) {
continue;
}
// add blank
if (j != matched[td_tag_idx].size() - 1 &&
content.at(content.length() - 1) != ' ') {
content += ' ';
}
}
html_str += content;
}
if (b_with) {
html_str += "</b>";
}
}
if (structure_html_tags[i].find("<td></td>") != std::string::npos) {
html_str += "</td>";
} else {
html_str += structure_html_tags[i];
}
td_tag_idx += 1;
} else {
html_str += structure_html_tags[i];
}
}
return html_str;
}
float PaddleStructure::iou(std::vector<std::vector<int>> &box1,
std::vector<std::vector<int>> &box2) {
int area1 = max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]);
int area2 = max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]);
// computing the sum_area
int sum_area = area1 + area2;
// find the each point of intersect rectangle
int x1 = max(box1[0][0], box2[0][0]);
int y1 = max(box1[0][1], box2[0][1]);
int x2 = min(box1[1][0], box2[1][0]);
int y2 = min(box1[1][1], box2[1][1]);
// judge if there is an intersect
if (y1 >= y2 || x1 >= x2) {
return 0.0;
} else {
int intersect = (x2 - x1) * (y2 - y1);
return intersect / (sum_area - intersect + 0.00000001);
}
}
float PaddleStructure::dis(std::vector<std::vector<int>> &box1,
std::vector<std::vector<int>> &box2) {
int x1_1 = box1[0][0];
int y1_1 = box1[0][1];
int x2_1 = box1[1][0];
int y2_1 = box1[1][1];
int x1_2 = box2[0][0];
int y1_2 = box2[0][1];
int x2_2 = box2[1][0];
int y2_2 = box2[1][1];
float dis =
abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1);
float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
return dis + min(dis_2, dis_3);
}
PaddleStructure::~PaddleStructure() {
if (this->recognizer_ != nullptr) {
delete this->recognizer_;
}
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -17,7 +17,7 @@
namespace PaddleOCR {
void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
void DBPostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance) {
int pts_num = 4;
float area = 0.0f;
......@@ -35,7 +35,7 @@ void PostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
distance = area * unclip_ratio / dist;
}
cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
cv::RotatedRect DBPostProcessor::UnClip(std::vector<std::vector<float>> box,
const float &unclip_ratio) {
float distance = 1.0;
......@@ -67,7 +67,7 @@ cv::RotatedRect PostProcessor::UnClip(std::vector<std::vector<float>> box,
return res;
}
float **PostProcessor::Mat2Vec(cv::Mat mat) {
float **DBPostProcessor::Mat2Vec(cv::Mat mat) {
auto **array = new float *[mat.rows];
for (int i = 0; i < mat.rows; ++i)
array[i] = new float[mat.cols];
......@@ -81,7 +81,7 @@ float **PostProcessor::Mat2Vec(cv::Mat mat) {
}
std::vector<std::vector<int>>
PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
DBPostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
std::vector<std::vector<int>> box = pts;
std::sort(box.begin(), box.end(), XsortInt);
......@@ -99,7 +99,7 @@ PostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
return rect;
}
std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
std::vector<std::vector<float>> DBPostProcessor::Mat2Vector(cv::Mat mat) {
std::vector<std::vector<float>> img_vec;
std::vector<float> tmp;
......@@ -113,20 +113,20 @@ std::vector<std::vector<float>> PostProcessor::Mat2Vector(cv::Mat mat) {
return img_vec;
}
bool PostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
bool DBPostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
bool PostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
bool DBPostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
float &ssid) {
std::vector<std::vector<float>>
DBPostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) {
ssid = std::max(box.size.width, box.size.height);
cv::Mat points;
......@@ -160,7 +160,7 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
return array;
}
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
float DBPostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
cv::Mat pred) {
int width = pred.cols;
int height = pred.rows;
......@@ -206,7 +206,7 @@ float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
return score;
}
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
float DBPostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
cv::Mat pred) {
auto array = box_array;
int width = pred.cols;
......@@ -244,7 +244,7 @@ float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
return score;
}
std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
const int min_size = 3;
......@@ -321,9 +321,9 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
return boxes;
}
std::vector<std::vector<std::vector<int>>>
PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
float ratio_h, float ratio_w, cv::Mat srcimg) {
std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes(
std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
float ratio_w, cv::Mat srcimg) {
int oriimg_h = srcimg.rows;
int oriimg_w = srcimg.cols;
......@@ -352,4 +352,77 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
return root_points;
}
void TablePostProcessor::init(std::string label_path) {
this->label_list_ = Utility::ReadDict(label_path);
this->label_list_.insert(this->label_list_.begin(), this->beg);
this->label_list_.push_back(this->end);
}
void TablePostProcessor::Run(
std::vector<float> &loc_preds, std::vector<float> &structure_probs,
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
std::vector<int> &structure_probs_shape,
std::vector<std::vector<std::string>> &rec_html_tag_batch,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
std::vector<int> &width_list, std::vector<int> &height_list) {
for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) {
// image tags and boxs
std::vector<std::string> rec_html_tags;
std::vector<std::vector<std::vector<int>>> rec_boxes;
float score = 0.f;
int count = 0;
float char_score = 0.f;
int char_idx = 0;
// step
for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) {
std::string html_tag;
std::vector<std::vector<int>> rec_box;
// html tag
int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
structure_probs_shape[2];
char_idx = int(Utility::argmax(
&structure_probs[step_start_idx],
&structure_probs[step_start_idx + structure_probs_shape[2]]));
char_score = float(*std::max_element(
&structure_probs[step_start_idx],
&structure_probs[step_start_idx + structure_probs_shape[2]]));
html_tag = this->label_list_[char_idx];
if (step_idx > 0 && html_tag == this->end) {
break;
}
if (html_tag == this->beg) {
continue;
}
count += 1;
score += char_score;
rec_html_tags.push_back(html_tag);
// box
if (html_tag == "<td>" || html_tag == "<td") {
for (int point_idx = 0; point_idx < loc_preds_shape[2];
point_idx += 2) {
std::vector<int> point(2, 0);
step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
loc_preds_shape[2] +
point_idx;
point[0] = int(loc_preds[step_start_idx] * width_list[batch_idx]);
point[1] =
int(loc_preds[step_start_idx + 1] * height_list[batch_idx]);
rec_box.push_back(point);
}
rec_boxes.push_back(rec_box);
}
}
score /= count;
if (isnan(score) || rec_boxes.size() == 0 || rec_html_tags.size() == 0) {
score = -1;
}
rec_scores.push_back(score);
rec_boxes_batch.push_back(rec_boxes);
rec_html_tag_batch.push_back(rec_html_tags);
}
}
} // namespace PaddleOCR
......@@ -69,18 +69,28 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
}
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
int max_size_len, float &ratio_h, float &ratio_w,
bool use_tensorrt) {
string limit_type, int limit_side_len, float &ratio_h,
float &ratio_w, bool use_tensorrt) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (limit_type == "min") {
int min_wh = min(h, w);
if (min_wh < limit_side_len) {
if (h < w) {
ratio = float(limit_side_len) / float(h);
} else {
ratio = float(limit_side_len) / float(w);
}
}
} else {
int max_wh = max(h, w);
if (max_wh > limit_side_len) {
if (h > w) {
ratio = float(max_size_len) / float(h);
ratio = float(limit_side_len) / float(h);
} else {
ratio = float(max_size_len) / float(w);
ratio = float(limit_side_len) / float(w);
}
}
}
......@@ -143,4 +153,26 @@ void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
}
}
void TableResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len) {
int w = img.cols;
int h = img.rows;
int max_wh = w >= h ? w : h;
float ratio = w >= h ? float(max_len) / float(w) : float(max_len) / float(h);
int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
}
void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len) {
int w = img.cols;
int h = img.rows;
cv::copyMakeBorder(img, resize_img, 0, max_len - h, 0, max_len - w,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/structure_table.h>
namespace PaddleOCR {
void StructureTableRecognizer::Run(
std::vector<cv::Mat> img_list,
std::vector<std::vector<std::string>> &structure_html_tags,
std::vector<float> &structure_scores,
std::vector<std::vector<std::vector<std::vector<int>>>> &structure_boxes,
std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> inference_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> postprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
int img_num = img_list.size();
for (int beg_img_no = 0; beg_img_no < img_num;
beg_img_no += this->table_batch_num_) {
// preprocess
auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->table_batch_num_);
int batch_num = end_img_no - beg_img_no;
std::vector<cv::Mat> norm_img_batch;
std::vector<int> width_list;
std::vector<int> height_list;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
cv::Mat srcimg;
img_list[ino].copyTo(srcimg);
cv::Mat resize_img;
cv::Mat pad_img;
this->resize_op_.Run(srcimg, resize_img, this->table_max_len_);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
this->pad_op_.Run(resize_img, pad_img, this->table_max_len_);
norm_img_batch.push_back(pad_img);
width_list.push_back(srcimg.cols);
height_list.push_back(srcimg.rows);
}
std::vector<float> input(
batch_num * 3 * this->table_max_len_ * this->table_max_len_, 0.0f);
this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape(
{batch_num, 3, this->table_max_len_, this->table_max_len_});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
auto output_names = this->predictor_->GetOutputNames();
auto output_tensor0 = this->predictor_->GetOutputHandle(output_names[0]);
auto output_tensor1 = this->predictor_->GetOutputHandle(output_names[1]);
std::vector<int> predict_shape0 = output_tensor0->shape();
std::vector<int> predict_shape1 = output_tensor1->shape();
int out_num0 = std::accumulate(predict_shape0.begin(), predict_shape0.end(),
1, std::multiplies<int>());
int out_num1 = std::accumulate(predict_shape1.begin(), predict_shape1.end(),
1, std::multiplies<int>());
std::vector<float> loc_preds;
std::vector<float> structure_probs;
loc_preds.resize(out_num0);
structure_probs.resize(out_num1);
output_tensor0->CopyToCpu(loc_preds.data());
output_tensor1->CopyToCpu(structure_probs.data());
auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// postprocess
auto postprocess_start = std::chrono::steady_clock::now();
std::vector<std::vector<std::string>> structure_html_tag_batch;
std::vector<float> structure_score_batch;
std::vector<std::vector<std::vector<std::vector<int>>>>
structure_boxes_batch;
this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch,
predict_shape0, predict_shape1,
structure_html_tag_batch, structure_boxes_batch,
width_list, height_list);
for (int m = 0; m < predict_shape0[0]; m++) {
structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
"<table>");
structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
"<body>");
structure_html_tag_batch[m].insert(structure_html_tag_batch[m].begin(),
"<html>");
structure_html_tag_batch[m].push_back("</table>");
structure_html_tag_batch[m].push_back("</body>");
structure_html_tag_batch[m].push_back("</html>");
structure_html_tags.push_back(structure_html_tag_batch[m]);
structure_scores.push_back(structure_score_batch[m]);
structure_boxes.push_back(structure_boxes_batch[m]);
}
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
times.push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
}
void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
AnalysisConfig config;
config.SetModel(model_dir + "/inference.pdmodel",
model_dir + "/inference.pdiparams");
if (this->use_gpu_) {
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
if (this->use_tensorrt_) {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (this->precision_ == "fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
if (this->precision_ == "int8") {
precision = paddle_infer::Config::Precision::kInt8;
}
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
}
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
// false for zero copy tensor
config.SwitchUseFeedFetchOps(false);
// true for multiple input
config.SwitchSpecifyInputNames(true);
config.SwitchIrOptim(true);
config.EnableMemoryOptim();
config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config);
}
} // namespace PaddleOCR
......@@ -248,4 +248,33 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
std::cout << std::endl;
}
}
cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) {
cv::Mat crop_im;
int crop_x1 = std::max(0, area[0]);
int crop_y1 = std::max(0, area[1]);
int crop_x2 = std::min(img.cols - 1, area[2] - 1);
int crop_y2 = std::min(img.rows - 1, area[3] - 1);
crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16);
cv::Mat crop_im_window =
crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]),
cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0]));
cv::Mat roi_img =
img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
crop_im_window += roi_img;
return crop_im;
}
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
for (int i = 0; i < ocr_result.size() - 1; i++) {
if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 &&
(ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) {
std::swap(ocr_result[i], ocr_result[i + 1]);
}
}
}
} // namespace PaddleOCR
\ No newline at end of file
......@@ -62,7 +62,7 @@ class TableMatch:
def __call__(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
if self.filter_ocr_result:
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes, dt_boxes,
dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
if self.use_master:
......@@ -179,7 +179,7 @@ class TableMatch:
html = deal_bb(html)
return html, end_html
def filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
y1 = pred_bboxes[:, 1::2].min()
new_dt_boxes = []
new_rec_res = []
......
......@@ -70,7 +70,7 @@ class TableSystem(object):
if args.table_algorithm in ['TableMaster']:
self.match = TableMasterMatcher()
else:
self.match = TableMatch()
self.match = TableMatch(filter_ocr_result=True)
self.benchmark = args.benchmark
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册