未验证 提交 1c0e4965 编写于 作者: Z zhoujun 提交者: GitHub

cp of 7550 (#7566)

* fix table cpp infer bug

* update vis

* update doc

* change default table dict to ch
上级 274c216c
...@@ -54,6 +54,7 @@ DECLARE_string(table_model_dir); ...@@ -54,6 +54,7 @@ DECLARE_string(table_model_dir);
DECLARE_int32(table_max_len); DECLARE_int32(table_max_len);
DECLARE_int32(table_batch_num); DECLARE_int32(table_batch_num);
DECLARE_string(table_char_dict_path); DECLARE_string(table_char_dict_path);
DECLARE_bool(merge_no_span_structure);
// forward related // forward related
DECLARE_bool(det); DECLARE_bool(det);
DECLARE_bool(rec); DECLARE_bool(rec);
......
...@@ -54,15 +54,12 @@ private: ...@@ -54,15 +54,12 @@ private:
std::vector<double> &time_info_det, std::vector<double> &time_info_det,
std::vector<double> &time_info_rec, std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls); std::vector<double> &time_info_cls);
std::string std::string rebuild_table(std::vector<std::string> rec_html_tags,
rebuild_table(std::vector<std::string> rec_html_tags, std::vector<std::vector<int>> rec_boxes,
std::vector<std::vector<std::vector<int>>> rec_boxes, std::vector<OCRPredictResult> &ocr_result);
std::vector<OCRPredictResult> &ocr_result);
float iou(std::vector<std::vector<int>> &box1, float iou(std::vector<int> &box1, std::vector<int> &box2);
std::vector<std::vector<int>> &box2); float dis(std::vector<int> &box1, std::vector<int> &box2);
float dis(std::vector<std::vector<int>> &box1,
std::vector<std::vector<int>> &box2);
static bool comparison_dis(const std::vector<float> &dis1, static bool comparison_dis(const std::vector<float> &dis1,
const std::vector<float> &dis2) { const std::vector<float> &dis2) {
......
...@@ -92,14 +92,13 @@ private: ...@@ -92,14 +92,13 @@ private:
class TablePostProcessor { class TablePostProcessor {
public: public:
void init(std::string label_path); void init(std::string label_path, bool merge_no_span_structure = true);
void void Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs,
Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs, std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape, std::vector<int> &structure_probs_shape,
std::vector<int> &structure_probs_shape, std::vector<std::vector<std::string>> &rec_html_tag_batch,
std::vector<std::vector<std::string>> &rec_html_tag_batch, std::vector<std::vector<std::vector<int>>> &rec_boxes_batch,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch, std::vector<int> &width_list, std::vector<int> &height_list);
std::vector<int> &width_list, std::vector<int> &height_list);
private: private:
std::vector<std::string> label_list_; std::vector<std::string> label_list_;
......
...@@ -44,7 +44,8 @@ public: ...@@ -44,7 +44,8 @@ public:
const int &gpu_mem, const int &cpu_math_library_num_threads, const int &gpu_mem, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path, const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt, const std::string &precision, const bool &use_tensorrt, const std::string &precision,
const int &table_batch_num, const int &table_max_len) { const int &table_batch_num, const int &table_max_len,
const bool &merge_no_span_structure) {
this->use_gpu_ = use_gpu; this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id; this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem; this->gpu_mem_ = gpu_mem;
...@@ -55,7 +56,7 @@ public: ...@@ -55,7 +56,7 @@ public:
this->table_batch_num_ = table_batch_num; this->table_batch_num_ = table_batch_num;
this->table_max_len_ = table_max_len; this->table_max_len_ = table_max_len;
this->post_processor_.init(label_path); this->post_processor_.init(label_path, merge_no_span_structure);
LoadModel(model_dir); LoadModel(model_dir);
} }
...@@ -65,7 +66,7 @@ public: ...@@ -65,7 +66,7 @@ public:
void Run(std::vector<cv::Mat> img_list, void Run(std::vector<cv::Mat> img_list,
std::vector<std::vector<std::string>> &rec_html_tags, std::vector<std::vector<std::string>> &rec_html_tags,
std::vector<float> &rec_scores, std::vector<float> &rec_scores,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes, std::vector<std::vector<std::vector<int>>> &rec_boxes,
std::vector<double> &times); std::vector<double> &times);
private: private:
......
...@@ -42,6 +42,7 @@ struct OCRPredictResult { ...@@ -42,6 +42,7 @@ struct OCRPredictResult {
struct StructurePredictResult { struct StructurePredictResult {
std::vector<int> box; std::vector<int> box;
std::vector<std::vector<int>> cell_box;
std::string type; std::string type;
std::vector<OCRPredictResult> text_res; std::vector<OCRPredictResult> text_res;
std::string html; std::string html;
...@@ -56,6 +57,10 @@ public: ...@@ -56,6 +57,10 @@ public:
const std::vector<OCRPredictResult> &ocr_result, const std::vector<OCRPredictResult> &ocr_result,
const std::string &save_path); const std::string &save_path);
static void VisualizeBboxes(const cv::Mat &srcimg,
const StructurePredictResult &structure_result,
const std::string &save_path);
template <class ForwardIterator> template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) { inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
return std::distance(first, std::max_element(first, last)); return std::distance(first, std::max_element(first, last));
...@@ -81,6 +86,9 @@ public: ...@@ -81,6 +86,9 @@ public:
static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result); static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);
static std::vector<int> xyxyxyxy2xyxy(std::vector<std::vector<int>> &box);
static std::vector<int> xyxyxyxy2xyxy(std::vector<int> &box);
private: private:
static bool comparison_box(const OCRPredictResult &result1, static bool comparison_box(const OCRPredictResult &result1,
const OCRPredictResult &result2) { const OCRPredictResult &result2) {
......
...@@ -350,6 +350,7 @@ More parameters are as follows, ...@@ -350,6 +350,7 @@ More parameters are as follows,
|table_model_dir|string|-|Address of table recognition inference model| |table_model_dir|string|-|Address of table recognition inference model|
|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|dictionary file| |table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|dictionary file|
|table_max_len|int|488|The size of the long side of the input image of the table recognition model, the final input image size of the network is(table_max_len,table_max_len)| |table_max_len|int|488|The size of the long side of the input image of the table recognition model, the final input image size of the network is(table_max_len,table_max_len)|
|merge_no_span_structure|bool|true|Whether to merge `<td>` and `</td>` into `<td></td>`|
* Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `rec_char_dict_path` and `rec_model_dir`. * Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `rec_char_dict_path` and `rec_model_dir`.
......
...@@ -359,6 +359,7 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir ...@@ -359,6 +359,7 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|table_model_dir|string|-|表格识别模型inference model地址| |table_model_dir|string|-|表格识别模型inference model地址|
|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|字典文件| |table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|字典文件|
|table_max_len|int|488|表格识别模型输入图像长边大小,最终网络输入图像大小为(table_max_len,table_max_len)| |table_max_len|int|488|表格识别模型输入图像长边大小,最终网络输入图像大小为(table_max_len,table_max_len)|
|merge_no_span_structure|bool|true|是否将 `<td>` 和 `</td>` 合并为 `<td></td>`|
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。 * PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
......
...@@ -55,8 +55,10 @@ DEFINE_int32(rec_img_w, 320, "rec image width"); ...@@ -55,8 +55,10 @@ DEFINE_int32(rec_img_w, 320, "rec image width");
DEFINE_string(table_model_dir, "", "Path of table struture inference model."); DEFINE_string(table_model_dir, "", "Path of table struture inference model.");
DEFINE_int32(table_max_len, 488, "max len size of input image."); DEFINE_int32(table_max_len, 488, "max len size of input image.");
DEFINE_int32(table_batch_num, 1, "table_batch_num."); DEFINE_int32(table_batch_num, 1, "table_batch_num.");
DEFINE_bool(merge_no_span_structure, true,
"Whether merge <td> and </td> to <td></td>");
DEFINE_string(table_char_dict_path, DEFINE_string(table_char_dict_path,
"../../ppocr/utils/dict/table_structure_dict.txt", "../../ppocr/utils/dict/table_structure_dict_ch.txt",
"Path of dictionary."); "Path of dictionary.");
// ocr forward related // ocr forward related
......
...@@ -120,6 +120,7 @@ void structure(std::vector<cv::String> &cv_all_img_names) { ...@@ -120,6 +120,7 @@ void structure(std::vector<cv::String> &cv_all_img_names) {
engine.structure(cv_all_img_names, false, FLAGS_table); engine.structure(cv_all_img_names, false, FLAGS_table);
for (int i = 0; i < cv_all_img_names.size(); i++) { for (int i = 0; i < cv_all_img_names.size(); i++) {
cout << "predict img: " << cv_all_img_names[i] << endl; cout << "predict img: " << cv_all_img_names[i] << endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
for (int j = 0; j < structure_results[i].size(); j++) { for (int j = 0; j < structure_results[i].size(); j++) {
std::cout << j << "\ttype: " << structure_results[i][j].type std::cout << j << "\ttype: " << structure_results[i][j].type
<< ", region: ["; << ", region: [";
...@@ -129,6 +130,11 @@ void structure(std::vector<cv::String> &cv_all_img_names) { ...@@ -129,6 +130,11 @@ void structure(std::vector<cv::String> &cv_all_img_names) {
<< structure_results[i][j].box[3] << "], res: "; << structure_results[i][j].box[3] << "], res: ";
if (structure_results[i][j].type == "table") { if (structure_results[i][j].type == "table") {
std::cout << structure_results[i][j].html << std::endl; std::cout << structure_results[i][j].html << std::endl;
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, structure_results[i][j],
FLAGS_output + "/" + std::to_string(j) + "_" +
file_name);
} else { } else {
Utility::print_result(structure_results[i][j].text_res); Utility::print_result(structure_results[i][j].text_res);
} }
......
...@@ -27,7 +27,7 @@ PaddleStructure::PaddleStructure() { ...@@ -27,7 +27,7 @@ PaddleStructure::PaddleStructure() {
FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path, FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num, FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
FLAGS_table_max_len); FLAGS_table_max_len, FLAGS_merge_no_span_structure);
} }
}; };
...@@ -42,7 +42,7 @@ PaddleStructure::structure(std::vector<cv::String> cv_all_img_names, ...@@ -42,7 +42,7 @@ PaddleStructure::structure(std::vector<cv::String> cv_all_img_names,
std::vector<std::vector<StructurePredictResult>> structure_results; std::vector<std::vector<StructurePredictResult>> structure_results;
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
mkdir(FLAGS_output.c_str(), 0777); Utility::CreateDir(FLAGS_output);
} }
for (int i = 0; i < cv_all_img_names.size(); ++i) { for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<StructurePredictResult> structure_result; std::vector<StructurePredictResult> structure_result;
...@@ -84,7 +84,7 @@ void PaddleStructure::table(cv::Mat img, ...@@ -84,7 +84,7 @@ void PaddleStructure::table(cv::Mat img,
// predict structure // predict structure
std::vector<std::vector<std::string>> structure_html_tags; std::vector<std::vector<std::string>> structure_html_tags;
std::vector<float> structure_scores(1, 0); std::vector<float> structure_scores(1, 0);
std::vector<std::vector<std::vector<std::vector<int>>>> structure_boxes; std::vector<std::vector<std::vector<int>>> structure_boxes;
std::vector<double> structure_imes; std::vector<double> structure_imes;
std::vector<cv::Mat> img_list; std::vector<cv::Mat> img_list;
img_list.push_back(img); img_list.push_back(img);
...@@ -103,20 +103,15 @@ void PaddleStructure::table(cv::Mat img, ...@@ -103,20 +103,15 @@ void PaddleStructure::table(cv::Mat img,
this->det(img_list[i], ocr_result, time_info_det); this->det(img_list[i], ocr_result, time_info_det);
// crop image // crop image
std::vector<cv::Mat> rec_img_list; std::vector<cv::Mat> rec_img_list;
std::vector<int> ocr_box;
for (int j = 0; j < ocr_result.size(); j++) { for (int j = 0; j < ocr_result.size(); j++) {
int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0], ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[j].box);
ocr_result[j].box[2][0], ocr_result[j].box[3][0]}; ocr_box[0] = max(0, ocr_box[0] - expand_pixel);
int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1], ocr_box[1] = max(0, ocr_box[1] - expand_pixel),
ocr_result[j].box[2][1], ocr_result[j].box[3][1]}; ocr_box[2] = min(img_list[i].cols, ocr_box[2] + expand_pixel);
int left = int(*std::min_element(x_collect, x_collect + 4)); ocr_box[3] = min(img_list[i].rows, ocr_box[3] + expand_pixel);
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4)); cv::Mat crop_img = Utility::crop_image(img_list[i], ocr_box);
int bottom = int(*std::max_element(y_collect, y_collect + 4));
std::vector<int> box{max(0, left - expand_pixel),
max(0, top - expand_pixel),
min(img_list[i].cols, right + expand_pixel),
min(img_list[i].rows, bottom + expand_pixel)};
cv::Mat crop_img = Utility::crop_image(img_list[i], box);
rec_img_list.push_back(crop_img); rec_img_list.push_back(crop_img);
} }
// rec // rec
...@@ -125,38 +120,37 @@ void PaddleStructure::table(cv::Mat img, ...@@ -125,38 +120,37 @@ void PaddleStructure::table(cv::Mat img,
html = this->rebuild_table(structure_html_tags[i], structure_boxes[i], html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
ocr_result); ocr_result);
structure_result.html = html; structure_result.html = html;
structure_result.cell_box = structure_boxes[i];
structure_result.html_score = structure_scores[i]; structure_result.html_score = structure_scores[i];
} }
}; };
std::string PaddleStructure::rebuild_table( std::string
std::vector<std::string> structure_html_tags, PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
std::vector<std::vector<std::vector<int>>> structure_boxes, std::vector<std::vector<int>> structure_boxes,
std::vector<OCRPredictResult> &ocr_result) { std::vector<OCRPredictResult> &ocr_result) {
// match text in same cell // match text in same cell
std::vector<std::vector<string>> matched(structure_boxes.size(), std::vector<std::vector<string>> matched(structure_boxes.size(),
std::vector<std::string>()); std::vector<std::string>());
std::vector<int> ocr_box;
std::vector<int> structure_box;
for (int i = 0; i < ocr_result.size(); i++) { for (int i = 0; i < ocr_result.size(); i++) {
ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[i].box);
ocr_box[0] -= 1;
ocr_box[1] -= 1;
ocr_box[2] += 1;
ocr_box[3] += 1;
std::vector<std::vector<float>> dis_list(structure_boxes.size(), std::vector<std::vector<float>> dis_list(structure_boxes.size(),
std::vector<float>(3, 100000.0)); std::vector<float>(3, 100000.0));
for (int j = 0; j < structure_boxes.size(); j++) { for (int j = 0; j < structure_boxes.size(); j++) {
int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0], if (structure_boxes[i].size() == 8) {
ocr_result[i].box[2][0], ocr_result[i].box[3][0]}; structure_box = Utility::xyxyxyxy2xyxy(structure_boxes[j]);
int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1], } else {
ocr_result[i].box[2][1], ocr_result[i].box[3][1]}; structure_box = structure_boxes[j];
int left = int(*std::min_element(x_collect, x_collect + 4)); }
int right = int(*std::max_element(x_collect, x_collect + 4)); dis_list[j][0] = this->dis(ocr_box, structure_box);
int top = int(*std::min_element(y_collect, y_collect + 4)); dis_list[j][1] = 1 - this->iou(ocr_box, structure_box);
int bottom = int(*std::max_element(y_collect, y_collect + 4));
std::vector<std::vector<int>> box(2, std::vector<int>(2, 0));
box[0][0] = left - 1;
box[0][1] = top - 1;
box[1][0] = right + 1;
box[1][1] = bottom + 1;
dis_list[j][0] = this->dis(box, structure_boxes[j]);
dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]);
dis_list[j][2] = j; dis_list[j][2] = j;
} }
// find min dis idx // find min dis idx
...@@ -164,6 +158,7 @@ std::string PaddleStructure::rebuild_table( ...@@ -164,6 +158,7 @@ std::string PaddleStructure::rebuild_table(
PaddleStructure::comparison_dis); PaddleStructure::comparison_dis);
matched[dis_list[0][2]].push_back(ocr_result[i].text); matched[dis_list[0][2]].push_back(ocr_result[i].text);
} }
// get pred html // get pred html
std::string html_str = ""; std::string html_str = "";
int td_tag_idx = 0; int td_tag_idx = 0;
...@@ -221,19 +216,18 @@ std::string PaddleStructure::rebuild_table( ...@@ -221,19 +216,18 @@ std::string PaddleStructure::rebuild_table(
return html_str; return html_str;
} }
float PaddleStructure::iou(std::vector<std::vector<int>> &box1, float PaddleStructure::iou(std::vector<int> &box1, std::vector<int> &box2) {
std::vector<std::vector<int>> &box2) { int area1 = max(0, box1[2] - box1[0]) * max(0, box1[3] - box1[1]);
int area1 = max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]); int area2 = max(0, box2[2] - box2[0]) * max(0, box2[3] - box2[1]);
int area2 = max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]);
// computing the sum_area // computing the sum_area
int sum_area = area1 + area2; int sum_area = area1 + area2;
// find the each point of intersect rectangle // find the each point of intersect rectangle
int x1 = max(box1[0][0], box2[0][0]); int x1 = max(box1[0], box2[0]);
int y1 = max(box1[0][1], box2[0][1]); int y1 = max(box1[1], box2[1]);
int x2 = min(box1[1][0], box2[1][0]); int x2 = min(box1[2], box2[2]);
int y2 = min(box1[1][1], box2[1][1]); int y2 = min(box1[3], box2[3]);
// judge if there is an intersect // judge if there is an intersect
if (y1 >= y2 || x1 >= x2) { if (y1 >= y2 || x1 >= x2) {
...@@ -244,17 +238,16 @@ float PaddleStructure::iou(std::vector<std::vector<int>> &box1, ...@@ -244,17 +238,16 @@ float PaddleStructure::iou(std::vector<std::vector<int>> &box1,
} }
} }
float PaddleStructure::dis(std::vector<std::vector<int>> &box1, float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
std::vector<std::vector<int>> &box2) { int x1_1 = box1[0];
int x1_1 = box1[0][0]; int y1_1 = box1[1];
int y1_1 = box1[0][1]; int x2_1 = box1[2];
int x2_1 = box1[1][0]; int y2_1 = box1[3];
int y2_1 = box1[1][1];
int x1_2 = box2[0][0]; int x1_2 = box2[0];
int y1_2 = box2[0][1]; int y1_2 = box2[1];
int x2_2 = box2[1][0]; int x2_2 = box2[2];
int y2_2 = box2[1][1]; int y2_2 = box2[3];
float dis = float dis =
abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1); abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
......
...@@ -352,8 +352,21 @@ std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes( ...@@ -352,8 +352,21 @@ std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes(
return root_points; return root_points;
} }
void TablePostProcessor::init(std::string label_path) { void TablePostProcessor::init(std::string label_path,
bool merge_no_span_structure) {
this->label_list_ = Utility::ReadDict(label_path); this->label_list_ = Utility::ReadDict(label_path);
if (merge_no_span_structure) {
this->label_list_.push_back("<td></td>");
std::vector<std::string>::iterator it;
for (it = this->label_list_.begin(); it != this->label_list_.end();) {
if (*it == "<td>") {
it = this->label_list_.erase(it);
} else {
++it;
}
}
}
// add_special_char
this->label_list_.insert(this->label_list_.begin(), this->beg); this->label_list_.insert(this->label_list_.begin(), this->beg);
this->label_list_.push_back(this->end); this->label_list_.push_back(this->end);
} }
...@@ -363,12 +376,12 @@ void TablePostProcessor::Run( ...@@ -363,12 +376,12 @@ void TablePostProcessor::Run(
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape, std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
std::vector<int> &structure_probs_shape, std::vector<int> &structure_probs_shape,
std::vector<std::vector<std::string>> &rec_html_tag_batch, std::vector<std::vector<std::string>> &rec_html_tag_batch,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch, std::vector<std::vector<std::vector<int>>> &rec_boxes_batch,
std::vector<int> &width_list, std::vector<int> &height_list) { std::vector<int> &width_list, std::vector<int> &height_list) {
for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) { for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) {
// image tags and boxs // image tags and boxs
std::vector<std::string> rec_html_tags; std::vector<std::string> rec_html_tags;
std::vector<std::vector<std::vector<int>>> rec_boxes; std::vector<std::vector<int>> rec_boxes;
float score = 0.f; float score = 0.f;
int count = 0; int count = 0;
...@@ -378,7 +391,7 @@ void TablePostProcessor::Run( ...@@ -378,7 +391,7 @@ void TablePostProcessor::Run(
// step // step
for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) { for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) {
std::string html_tag; std::string html_tag;
std::vector<std::vector<int>> rec_box; std::vector<int> rec_box;
// html tag // html tag
int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
structure_probs_shape[2]; structure_probs_shape[2];
...@@ -399,17 +412,19 @@ void TablePostProcessor::Run( ...@@ -399,17 +412,19 @@ void TablePostProcessor::Run(
count += 1; count += 1;
score += char_score; score += char_score;
rec_html_tags.push_back(html_tag); rec_html_tags.push_back(html_tag);
// box // box
if (html_tag == "<td>" || html_tag == "<td" || html_tag == "<td></td>") { if (html_tag == "<td>" || html_tag == "<td" || html_tag == "<td></td>") {
for (int point_idx = 0; point_idx < loc_preds_shape[2]; for (int point_idx = 0; point_idx < loc_preds_shape[2]; point_idx++) {
point_idx += 2) {
std::vector<int> point(2, 0);
step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) *
loc_preds_shape[2] + loc_preds_shape[2] +
point_idx; point_idx;
point[0] = int(loc_preds[step_start_idx] * width_list[batch_idx]); float point = loc_preds[step_start_idx];
point[1] = if (point_idx % 2 == 0) {
int(loc_preds[step_start_idx + 1] * height_list[batch_idx]); point = int(point * width_list[batch_idx]);
} else {
point = int(point * height_list[batch_idx]);
}
rec_box.push_back(point); rec_box.push_back(point);
} }
rec_boxes.push_back(rec_box); rec_boxes.push_back(rec_box);
......
...@@ -20,7 +20,7 @@ void StructureTableRecognizer::Run( ...@@ -20,7 +20,7 @@ void StructureTableRecognizer::Run(
std::vector<cv::Mat> img_list, std::vector<cv::Mat> img_list,
std::vector<std::vector<std::string>> &structure_html_tags, std::vector<std::vector<std::string>> &structure_html_tags,
std::vector<float> &structure_scores, std::vector<float> &structure_scores,
std::vector<std::vector<std::vector<std::vector<int>>>> &structure_boxes, std::vector<std::vector<std::vector<int>>> &structure_boxes,
std::vector<double> &times) { std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff = std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
...@@ -89,8 +89,7 @@ void StructureTableRecognizer::Run( ...@@ -89,8 +89,7 @@ void StructureTableRecognizer::Run(
auto postprocess_start = std::chrono::steady_clock::now(); auto postprocess_start = std::chrono::steady_clock::now();
std::vector<std::vector<std::string>> structure_html_tag_batch; std::vector<std::vector<std::string>> structure_html_tag_batch;
std::vector<float> structure_score_batch; std::vector<float> structure_score_batch;
std::vector<std::vector<std::vector<std::vector<int>>>> std::vector<std::vector<std::vector<int>>> structure_boxes_batch;
structure_boxes_batch;
this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch, this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch,
predict_shape0, predict_shape1, predict_shape0, predict_shape1,
structure_html_tag_batch, structure_boxes_batch, structure_html_tag_batch, structure_boxes_batch,
......
...@@ -65,6 +65,37 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg, ...@@ -65,6 +65,37 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
<< std::endl; << std::endl;
} }
// Draws every cell box stored in structure_result.cell_box onto a copy of
// srcimg and writes the annotated image to save_path.
//
// A cell box with 8 values is treated as four consecutive (x, y) pairs and
// drawn as a closed polyline; a box with 4 values is treated as two corner
// points and drawn as a rectangle. Boxes of any other length are skipped.
void Utility::VisualizeBboxes(const cv::Mat &srcimg,
                              const StructurePredictResult &structure_result,
                              const std::string &save_path) {
  // Draw on a private copy so the caller's image is left untouched.
  cv::Mat canvas;
  srcimg.copyTo(canvas);

  const cv::Scalar green = CV_RGB(0, 255, 0);
  for (const std::vector<int> &cell : structure_result.cell_box) {
    if (cell.size() == 8) {
      // Un-flatten the coordinate list into four vertices.
      cv::Point corners[4];
      for (int k = 0; k < 4; ++k) {
        corners[k] = cv::Point(cell[2 * k], cell[2 * k + 1]);
      }
      const cv::Point *polygons[1] = {corners};
      int corner_counts[] = {4};
      cv::polylines(canvas, polygons, corner_counts, 1, 1, green, 2, 8, 0);
    } else if (cell.size() == 4) {
      const cv::Point first_corner(cell[0], cell[1]);
      const cv::Point opposite_corner(cell[2], cell[3]);
      cv::rectangle(canvas, first_corner, opposite_corner, green, 2, 8, 0);
    }
  }

  cv::imwrite(save_path, canvas);
  std::cout << "The table visualized image saved in " << save_path
            << std::endl;
}
// list all files under a directory // list all files under a directory
void Utility::GetAllFiles(const char *dir_name, void Utility::GetAllFiles(const char *dir_name,
std::vector<std::string> &all_inputs) { std::vector<std::string> &all_inputs) {
...@@ -268,13 +299,46 @@ cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) { ...@@ -268,13 +299,46 @@ cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) {
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) { void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box); std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
if (ocr_result.size() > 0) {
for (int i = 0; i < ocr_result.size() - 1; i++) { for (int i = 0; i < ocr_result.size() - 1; i++) {
if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 && for (int j = i; j > 0; j--) {
(ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) { if (abs(ocr_result[j + 1].box[0][1] - ocr_result[j].box[0][1]) < 10 &&
std::swap(ocr_result[i], ocr_result[i + 1]); (ocr_result[j + 1].box[0][0] < ocr_result[j].box[0][0])) {
std::swap(ocr_result[i], ocr_result[i + 1]);
}
}
} }
} }
} }
// Converts a box given as four (x, y) points into {min_x, min_y, max_x, max_y}.
// Reads box[0..3][0..1]; the caller must supply at least four points with two
// coordinates each.
std::vector<int> Utility::xyxyxyxy2xyxy(std::vector<std::vector<int>> &box) {
  // Seed the running extrema with the first point, then fold in the rest.
  std::vector<int> bounds{box[0][0], box[0][1], box[0][0], box[0][1]};
  for (int k = 1; k < 4; ++k) {
    bounds[0] = std::min(bounds[0], box[k][0]);
    bounds[1] = std::min(bounds[1], box[k][1]);
    bounds[2] = std::max(bounds[2], box[k][0]);
    bounds[3] = std::max(bounds[3], box[k][1]);
  }
  return bounds;
}
// Converts a flat box of eight values, read as x at even indices and y at odd
// indices, into {min_x, min_y, max_x, max_y}. Reads box[0..7]; the caller must
// supply at least eight values.
std::vector<int> Utility::xyxyxyxy2xyxy(std::vector<int> &box) {
  // Seed the running extrema with the first (x, y) pair, then fold in the rest.
  std::vector<int> bounds{box[0], box[1], box[0], box[1]};
  for (int k = 2; k < 8; k += 2) {
    bounds[0] = std::min(bounds[0], box[k]);
    bounds[1] = std::min(bounds[1], box[k + 1]);
    bounds[2] = std::max(bounds[2], box[k]);
    bounds[3] = std::max(bounds[3], box[k + 1]);
  }
  return bounds;
}
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
...@@ -32,7 +32,7 @@ def init_args(): ...@@ -32,7 +32,7 @@ def init_args():
parser.add_argument( parser.add_argument(
"--table_char_dict_path", "--table_char_dict_path",
type=str, type=str,
default="../ppocr/utils/dict/table_structure_dict.txt") default="../ppocr/utils/dict/table_structure_dict_ch.txt")
# params for layout # params for layout
parser.add_argument("--layout_model_dir", type=str) parser.add_argument("--layout_model_dir", type=str)
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册