提交 344c6d2d 编写于 作者: 文幕地方's avatar 文幕地方

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into re1

...@@ -88,6 +88,7 @@ Train: ...@@ -88,6 +88,7 @@ Train:
prob: 0.5 prob: 0.5
ext_data_num: 2 ext_data_num: 2
image_shape: [48, 320, 3] image_shape: [48, 320, 3]
max_text_length: *max_text_length
- RecAug: - RecAug:
- MultiLabelEncode: - MultiLabelEncode:
- RecResizeImg: - RecResizeImg:
......
...@@ -162,6 +162,7 @@ Train: ...@@ -162,6 +162,7 @@ Train:
prob: 0.5 prob: 0.5
ext_data_num: 2 ext_data_num: 2
image_shape: [48, 320, 3] image_shape: [48, 320, 3]
max_text_length: *max_text_length
- RecAug: - RecAug:
- MultiLabelEncode: - MultiLabelEncode:
- RecResizeImg: - RecResizeImg:
......
...@@ -88,6 +88,7 @@ Train: ...@@ -88,6 +88,7 @@ Train:
prob: 0.5 prob: 0.5
ext_data_num: 2 ext_data_num: 2
image_shape: [48, 320, 3] image_shape: [48, 320, 3]
max_text_length: *max_text_length
- RecAug: - RecAug:
- MultiLabelEncode: - MultiLabelEncode:
- RecResizeImg: - RecResizeImg:
......
...@@ -49,6 +49,11 @@ DECLARE_int32(rec_batch_num); ...@@ -49,6 +49,11 @@ DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path); DECLARE_string(rec_char_dict_path);
DECLARE_int32(rec_img_h); DECLARE_int32(rec_img_h);
DECLARE_int32(rec_img_w); DECLARE_int32(rec_img_w);
// layout model related
DECLARE_string(layout_model_dir);
DECLARE_string(layout_dict_path);
DECLARE_double(layout_score_threshold);
DECLARE_double(layout_nms_threshold);
// structure model related // structure model related
DECLARE_string(table_model_dir); DECLARE_string(table_model_dir);
DECLARE_int32(table_max_len); DECLARE_int32(table_max_len);
...@@ -59,4 +64,5 @@ DECLARE_bool(merge_no_span_structure); ...@@ -59,4 +64,5 @@ DECLARE_bool(merge_no_span_structure);
DECLARE_bool(det); DECLARE_bool(det);
DECLARE_bool(rec); DECLARE_bool(rec);
DECLARE_bool(cls); DECLARE_bool(cls);
DECLARE_bool(table); DECLARE_bool(table);
\ No newline at end of file DECLARE_bool(layout);
\ No newline at end of file
...@@ -14,26 +14,12 @@ ...@@ -14,26 +14,12 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h" #include "paddle_api.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/preprocess_op.h> #include <include/preprocess_op.h>
#include <include/utility.h> #include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR { namespace PaddleOCR {
class Classifier { class Classifier {
...@@ -66,7 +52,7 @@ public: ...@@ -66,7 +52,7 @@ public:
std::vector<float> &cls_scores, std::vector<double> &times); std::vector<float> &cls_scores, std::vector<double> &times);
private: private:
std::shared_ptr<Predictor> predictor_; std::shared_ptr<paddle_infer::Predictor> predictor_;
bool use_gpu_ = false; bool use_gpu_ = false;
int gpu_id_ = 0; int gpu_id_ = 0;
......
...@@ -14,26 +14,12 @@ ...@@ -14,26 +14,12 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h" #include "paddle_api.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/postprocess_op.h> #include <include/postprocess_op.h>
#include <include/preprocess_op.h> #include <include/preprocess_op.h>
using namespace paddle_infer;
namespace PaddleOCR { namespace PaddleOCR {
class DBDetector { class DBDetector {
...@@ -41,7 +27,7 @@ public: ...@@ -41,7 +27,7 @@ public:
explicit DBDetector(const std::string &model_dir, const bool &use_gpu, explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem, const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &limit_type, const bool &use_mkldnn, const std::string &limit_type,
const int &limit_side_len, const double &det_db_thresh, const int &limit_side_len, const double &det_db_thresh,
const double &det_db_box_thresh, const double &det_db_box_thresh,
const double &det_db_unclip_ratio, const double &det_db_unclip_ratio,
...@@ -77,7 +63,7 @@ public: ...@@ -77,7 +63,7 @@ public:
std::vector<double> &times); std::vector<double> &times);
private: private:
std::shared_ptr<Predictor> predictor_; std::shared_ptr<paddle_infer::Predictor> predictor_;
bool use_gpu_ = false; bool use_gpu_ = false;
int gpu_id_ = 0; int gpu_id_ = 0;
...@@ -85,7 +71,7 @@ private: ...@@ -85,7 +71,7 @@ private:
int cpu_math_library_num_threads_ = 4; int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false; bool use_mkldnn_ = false;
string limit_type_ = "max"; std::string limit_type_ = "max";
int limit_side_len_ = 960; int limit_side_len_ = 960;
double det_db_thresh_ = 0.3; double det_db_thresh_ = 0.3;
......
...@@ -14,27 +14,12 @@ ...@@ -14,27 +14,12 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h" #include "paddle_api.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/ocr_cls.h> #include <include/ocr_cls.h>
#include <include/preprocess_op.h>
#include <include/utility.h> #include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR { namespace PaddleOCR {
class CRNNRecognizer { class CRNNRecognizer {
...@@ -42,7 +27,7 @@ public: ...@@ -42,7 +27,7 @@ public:
explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu, explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem, const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path, const bool &use_mkldnn, const std::string &label_path,
const bool &use_tensorrt, const bool &use_tensorrt,
const std::string &precision, const std::string &precision,
const int &rec_batch_num, const int &rec_img_h, const int &rec_batch_num, const int &rec_img_h,
...@@ -75,7 +60,7 @@ public: ...@@ -75,7 +60,7 @@ public:
std::vector<float> &rec_text_scores, std::vector<double> &times); std::vector<float> &rec_text_scores, std::vector<double> &times);
private: private:
std::shared_ptr<Predictor> predictor_; std::shared_ptr<paddle_infer::Predictor> predictor_;
bool use_gpu_ = false; bool use_gpu_ = false;
int gpu_id_ = 0; int gpu_id_ = 0;
......
...@@ -14,28 +14,9 @@ ...@@ -14,28 +14,9 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/ocr_cls.h> #include <include/ocr_cls.h>
#include <include/ocr_det.h> #include <include/ocr_det.h>
#include <include/ocr_rec.h> #include <include/ocr_rec.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR { namespace PaddleOCR {
...@@ -43,21 +24,27 @@ class PPOCR { ...@@ -43,21 +24,27 @@ class PPOCR {
public: public:
explicit PPOCR(); explicit PPOCR();
~PPOCR(); ~PPOCR();
std::vector<std::vector<OCRPredictResult>>
ocr(std::vector<cv::String> cv_all_img_names, bool det = true, std::vector<std::vector<OCRPredictResult>> ocr(std::vector<cv::Mat> img_list,
bool rec = true, bool cls = true); bool det = true,
bool rec = true,
bool cls = true);
std::vector<OCRPredictResult> ocr(cv::Mat img, bool det = true,
bool rec = true, bool cls = true);
void reset_timer();
void benchmark_log(int img_num);
protected: protected:
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results, std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> &times); std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results);
void rec(std::vector<cv::Mat> img_list, void rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results, std::vector<OCRPredictResult> &ocr_results);
std::vector<double> &times);
void cls(std::vector<cv::Mat> img_list, void cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results, std::vector<OCRPredictResult> &ocr_results);
std::vector<double> &times);
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num);
private: private:
DBDetector *detector_ = nullptr; DBDetector *detector_ = nullptr;
......
...@@ -14,27 +14,9 @@ ...@@ -14,27 +14,9 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/paddleocr.h> #include <include/paddleocr.h>
#include <include/preprocess_op.h> #include <include/structure_layout.h>
#include <include/structure_table.h> #include <include/structure_table.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR { namespace PaddleOCR {
...@@ -42,23 +24,31 @@ class PaddleStructure : public PPOCR { ...@@ -42,23 +24,31 @@ class PaddleStructure : public PPOCR {
public: public:
explicit PaddleStructure(); explicit PaddleStructure();
~PaddleStructure(); ~PaddleStructure();
std::vector<std::vector<StructurePredictResult>>
structure(std::vector<cv::String> cv_all_img_names, bool layout = false, std::vector<StructurePredictResult> structure(cv::Mat img,
bool table = true); bool layout = false,
bool table = true,
bool ocr = false);
void reset_timer();
void benchmark_log(int img_num);
private: private:
StructureTableRecognizer *recognizer_ = nullptr; std::vector<double> time_info_table = {0, 0, 0};
std::vector<double> time_info_layout = {0, 0, 0};
StructureTableRecognizer *table_model_ = nullptr;
StructureLayoutRecognizer *layout_model_ = nullptr;
void layout(cv::Mat img,
std::vector<StructurePredictResult> &structure_result);
void table(cv::Mat img, StructurePredictResult &structure_result);
void table(cv::Mat img, StructurePredictResult &structure_result,
std::vector<double> &time_info_table,
std::vector<double> &time_info_det,
std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls);
std::string rebuild_table(std::vector<std::string> rec_html_tags, std::string rebuild_table(std::vector<std::string> rec_html_tags,
std::vector<std::vector<int>> rec_boxes, std::vector<std::vector<int>> rec_boxes,
std::vector<OCRPredictResult> &ocr_result); std::vector<OCRPredictResult> &ocr_result);
float iou(std::vector<int> &box1, std::vector<int> &box2);
float dis(std::vector<int> &box1, std::vector<int> &box2); float dis(std::vector<int> &box1, std::vector<int> &box2);
static bool comparison_dis(const std::vector<float> &dis1, static bool comparison_dis(const std::vector<float> &dis1,
......
...@@ -14,24 +14,9 @@ ...@@ -14,24 +14,9 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include "include/clipper.h" #include "include/clipper.h"
#include "include/utility.h" #include "include/utility.h"
using namespace std;
namespace PaddleOCR { namespace PaddleOCR {
class DBPostProcessor { class DBPostProcessor {
...@@ -106,4 +91,27 @@ private: ...@@ -106,4 +91,27 @@ private:
std::string beg = "sos"; std::string beg = "sos";
}; };
class PicodetPostProcessor {
public:
void init(std::string label_path, const double score_threshold = 0.4,
const double nms_threshold = 0.5,
const std::vector<int> &fpn_stride = {8, 16, 32, 64});
void Run(std::vector<StructurePredictResult> &results,
std::vector<std::vector<float>> outs, std::vector<int> ori_shape,
std::vector<int> resize_shape, int eg_max);
std::vector<int> fpn_stride_ = {8, 16, 32, 64};
private:
StructurePredictResult disPred2Bbox(std::vector<float> bbox_pred, int label,
float score, int x, int y, int stride,
std::vector<int> im_shape, int reg_max);
void nms(std::vector<StructurePredictResult> &input_boxes,
float nms_threshold);
std::vector<std::string> label_list_;
double score_threshold_ = 0.4;
double nms_threshold_ = 0.5;
int num_class_ = 5;
};
} // namespace PaddleOCR } // namespace PaddleOCR
...@@ -14,21 +14,12 @@ ...@@ -14,21 +14,12 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream> #include <iostream>
#include <ostream>
#include <vector> #include <vector>
#include <cstring> #include "opencv2/core.hpp"
#include <fstream> #include "opencv2/imgcodecs.hpp"
#include <numeric> #include "opencv2/imgproc.hpp"
using namespace std;
using namespace paddle;
namespace PaddleOCR { namespace PaddleOCR {
...@@ -51,9 +42,9 @@ public: ...@@ -51,9 +42,9 @@ public:
class ResizeImgType0 { class ResizeImgType0 {
public: public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type, virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
int limit_side_len, float &ratio_h, float &ratio_w, std::string limit_type, int limit_side_len, float &ratio_h,
bool use_tensorrt); float &ratio_w, bool use_tensorrt);
}; };
class CrnnResizeImg { class CrnnResizeImg {
...@@ -82,4 +73,10 @@ public: ...@@ -82,4 +73,10 @@ public:
const int max_len = 488); const int max_len = 488);
}; };
class Resize {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, const int h,
const int w);
};
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
namespace PaddleOCR {
class StructureLayoutRecognizer {
public:
explicit StructureLayoutRecognizer(
const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
const int &gpu_mem, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const std::string &label_path,
const bool &use_tensorrt, const std::string &precision,
const double &layout_score_threshold,
const double &layout_nms_threshold) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
this->post_processor_.init(label_path, layout_score_threshold,
layout_nms_threshold);
LoadModel(model_dir);
}
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
void Run(cv::Mat img, std::vector<StructurePredictResult> &result,
std::vector<double> &times);
private:
std::shared_ptr<paddle_infer::Predictor> predictor_;
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;
bool use_tensorrt_ = false;
std::string precision_ = "fp32";
// pre-process
Resize resize_op_;
Normalize normalize_op_;
Permute permute_op_;
// post-process
PicodetPostProcessor post_processor_;
};
} // namespace PaddleOCR
\ No newline at end of file
...@@ -14,26 +14,11 @@ ...@@ -14,26 +14,11 @@
#pragma once #pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h" #include "paddle_api.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/postprocess_op.h> #include <include/postprocess_op.h>
#include <include/preprocess_op.h> #include <include/preprocess_op.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR { namespace PaddleOCR {
...@@ -42,7 +27,7 @@ public: ...@@ -42,7 +27,7 @@ public:
explicit StructureTableRecognizer( explicit StructureTableRecognizer(
const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
const int &gpu_mem, const int &cpu_math_library_num_threads, const int &gpu_mem, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path, const bool &use_mkldnn, const std::string &label_path,
const bool &use_tensorrt, const std::string &precision, const bool &use_tensorrt, const std::string &precision,
const int &table_batch_num, const int &table_max_len, const int &table_batch_num, const int &table_max_len,
const bool &merge_no_span_structure) { const bool &merge_no_span_structure) {
...@@ -70,7 +55,7 @@ public: ...@@ -70,7 +55,7 @@ public:
std::vector<double> &times); std::vector<double> &times);
private: private:
std::shared_ptr<Predictor> predictor_; std::shared_ptr<paddle_infer::Predictor> predictor_;
bool use_gpu_ = false; bool use_gpu_ = false;
int gpu_id_ = 0; int gpu_id_ = 0;
......
...@@ -41,12 +41,13 @@ struct OCRPredictResult { ...@@ -41,12 +41,13 @@ struct OCRPredictResult {
}; };
struct StructurePredictResult { struct StructurePredictResult {
std::vector<int> box; std::vector<float> box;
std::vector<std::vector<int>> cell_box; std::vector<std::vector<int>> cell_box;
std::string type; std::string type;
std::vector<OCRPredictResult> text_res; std::vector<OCRPredictResult> text_res;
std::string html; std::string html;
float html_score = -1; float html_score = -1;
float confidence;
}; };
class Utility { class Utility {
...@@ -82,13 +83,20 @@ public: ...@@ -82,13 +83,20 @@ public:
static void print_result(const std::vector<OCRPredictResult> &ocr_result); static void print_result(const std::vector<OCRPredictResult> &ocr_result);
static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area); static cv::Mat crop_image(cv::Mat &img, const std::vector<int> &area);
static cv::Mat crop_image(cv::Mat &img, const std::vector<float> &area);
static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result); static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);
static std::vector<int> xyxyxyxy2xyxy(std::vector<std::vector<int>> &box); static std::vector<int> xyxyxyxy2xyxy(std::vector<std::vector<int>> &box);
static std::vector<int> xyxyxyxy2xyxy(std::vector<int> &box); static std::vector<int> xyxyxyxy2xyxy(std::vector<int> &box);
static float fast_exp(float x);
static std::vector<float>
activation_function_softmax(std::vector<float> &src);
static float iou(std::vector<int> &box1, std::vector<int> &box2);
static float iou(std::vector<float> &box1, std::vector<float> &box2);
private: private:
static bool comparison_box(const OCRPredictResult &result1, static bool comparison_box(const OCRPredictResult &result1,
const OCRPredictResult &result2) { const OCRPredictResult &result2) {
......
...@@ -174,6 +174,9 @@ inference/ ...@@ -174,6 +174,9 @@ inference/
|-- table |-- table
| |--inference.pdiparams | |--inference.pdiparams
| |--inference.pdmodel | |--inference.pdmodel
|-- layout
| |--inference.pdiparams
| |--inference.pdmodel
``` ```
...@@ -278,8 +281,30 @@ Specifically, ...@@ -278,8 +281,30 @@ Specifically,
--cls=true \ --cls=true \
``` ```
##### 7. layout+table
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--table_model_dir=inference/table \
--image_dir=../../ppstructure/docs/table/table.jpg \
--layout_model_dir=inference/layout \
--type=structure \
--table=true \
--layout=true
```
##### 8. layout
```shell
./build/ppocr --layout_model_dir=inference/layout \
--image_dir=../../ppstructure/docs/table/1.png \
--type=structure \
--table=false \
--layout=true \
--det=false \
--rec=false
```
##### 7. table ##### 9. table
```shell ```shell
./build/ppocr --det_model_dir=inference/det_db \ ./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \ --rec_model_dir=inference/rec_rcnn \
...@@ -343,6 +368,16 @@ More parameters are as follows, ...@@ -343,6 +368,16 @@ More parameters are as follows,
|rec_img_h|int|48|image height of recognition| |rec_img_h|int|48|image height of recognition|
|rec_img_w|int|320|image width of recognition| |rec_img_w|int|320|image width of recognition|
- Layout related parameters
|parameter|data type|default|meaning|
| :---: | :---: | :---: | :---: |
|layout_model_dir|string|-| Address of layout inference model|
|layout_dict_path|string|../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt|dictionary file|
|layout_score_threshold|float|0.5|Threshold of score.|
|layout_nms_threshold|float|0.5|Threshold of nms.|
- Table recognition related parameters - Table recognition related parameters
|parameter|data type|default|meaning| |parameter|data type|default|meaning|
...@@ -368,11 +403,51 @@ predict img: ../../doc/imgs/12.jpg ...@@ -368,11 +403,51 @@ predict img: ../../doc/imgs/12.jpg
The detection visualized image saved in ./output//12.jpg The detection visualized image saved in ./output//12.jpg
``` ```
- table - layout+table
```bash ```bash
predict img: ../../ppstructure/docs/table/table.jpg predict img: ../../ppstructure/docs/table/1.png
0 type: table, region: [0,0,371,293], res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html> 0 type: text, region: [12,729,410,848], score: 0.781044, res: count of ocr result is : 7
********** print ocr result **********
0 det boxes: [[4,1],[79,1],[79,12],[4,12]] rec text: CTW1500. rec score: 0.769472
...
6 det boxes: [[4,99],[391,99],[391,112],[4,112]] rec text: sate-of-the-artmethods[12.34.36l.ourapproachachieves rec score: 0.90414
********** end print ocr result **********
1 type: text, region: [69,342,342,359], score: 0.703666, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[8,2],[269,2],[269,13],[8,13]] rec text: Table6.Experimentalresults on CTW-1500 rec score: 0.890454
********** end print ocr result **********
2 type: text, region: [70,316,706,332], score: 0.659738, res: count of ocr result is : 2
********** print ocr result **********
0 det boxes: [[373,2],[630,2],[630,11],[373,11]] rec text: oroposals.andthegreencontoursarefinal rec score: 0.919729
1 det boxes: [[8,3],[357,3],[357,11],[8,11]] rec text: Visualexperimentalresultshebluecontoursareboundar rec score: 0.915963
********** end print ocr result **********
3 type: text, region: [489,342,789,359], score: 0.630538, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[8,2],[294,2],[294,14],[8,14]] rec text: Table7.Experimentalresults onMSRA-TD500 rec score: 0.942251
********** end print ocr result **********
4 type: text, region: [444,751,841,848], score: 0.607345, res: count of ocr result is : 5
********** print ocr result **********
0 det boxes: [[19,3],[389,3],[389,17],[19,17]] rec text: Inthispaper,weproposeanovel adaptivebound rec score: 0.941031
1 det boxes: [[4,22],[390,22],[390,36],[4,36]] rec text: aryproposalnetworkforarbitraryshapetextdetection rec score: 0.960172
2 det boxes: [[4,42],[392,42],[392,56],[4,56]] rec text: whichadoptanboundaryproposalmodeltogeneratecoarse rec score: 0.934647
3 det boxes: [[4,61],[389,61],[389,75],[4,75]] rec text: ooundaryproposals,andthenadoptanadaptiveboundary rec score: 0.946296
4 det boxes: [[5,80],[387,80],[387,93],[5,93]] rec text: leformationmodelcombinedwithGCNandRNNtoper rec score: 0.952401
********** end print ocr result **********
5 type: title, region: [444,705,564,724], score: 0.785429, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[6,2],[113,2],[113,14],[6,14]] rec text: 5.Conclusion rec score: 0.856903
********** end print ocr result **********
6 type: table, region: [14,360,402,711], score: 0.963643, res: <html><body><table><thead><tr><td>Methods</td><td>Ext</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>TextSnake [18]</td><td>Syn</td><td>85.3</td><td>67.9</td><td>75.6</td><td></td></tr><tr><td>CSE [17]</td><td>MiLT</td><td>76.1</td><td>78.7</td><td>77.4</td><td>0.38</td></tr><tr><td>LOMO[40]</td><td>Syn</td><td>76.5</td><td>85.7</td><td>80.8</td><td>4.4</td></tr><tr><td>ATRR[35]</td><td>Sy-</td><td>80.2</td><td>80.1</td><td>80.1</td><td>-</td></tr><tr><td>SegLink++ [28]</td><td>Syn</td><td>79.8</td><td>82.8</td><td>81.3</td><td>-</td></tr><tr><td>TextField [37]</td><td>Syn</td><td>79.8</td><td>83.0</td><td>81.4</td><td>6.0</td></tr><tr><td>MSR[38]</td><td>Syn</td><td>79.0</td><td>84.1</td><td>81.5</td><td>4.3</td></tr><tr><td>PSENet-1s [33]</td><td>MLT</td><td>79.7</td><td>84.8</td><td>82.2</td><td>3.9</td></tr><tr><td>DB [12]</td><td>Syn</td><td>80.2</td><td>86.9</td><td>83.4</td><td>22.0</td></tr><tr><td>CRAFT [2]</td><td>Syn</td><td>81.1</td><td>86.0</td><td>83.5</td><td>-</td></tr><tr><td>TextDragon [5]</td><td>MLT+</td><td>82.8</td><td>84.5</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>Syn</td><td>81.2</td><td>86.4</td><td>83.7</td><td>39.8</td></tr><tr><td>ContourNet [36]</td><td></td><td>84.1</td><td>83.7</td><td>83.9</td><td>4.5</td></tr><tr><td>DRRG [41]</td><td>MLT</td><td>83.02</td><td>85.93</td><td>84.45</td><td>-</td></tr><tr><td>TextPerception[23]</td><td>Syn</td><td>81.9</td><td>87.5</td><td>84.6</td><td></td></tr><tr><td>Ours</td><td> Syn</td><td>80.57</td><td>87.66</td><td>83.97</td><td>12.08</td></tr><tr><td>Ours</td><td></td><td>81.45</td><td>87.81</td><td>84.51</td><td>12.15</td></tr><tr><td>Ours</td><td>MLT</td><td>83.60</td><td>86.45</td><td>85.00</td><td>12.21</td></tr></tbody></table></body></html>
The table visualized image saved in ./output//6_1.png
7 type: table, region: [462,359,820,657], score: 0.953917, res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN[3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>:</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
The table visualized image saved in ./output//7_1.png
8 type: figure, region: [14,3,836,310], score: 0.969443, res: count of ocr result is : 26
********** print ocr result **********
0 det boxes: [[506,14],[539,15],[539,22],[506,21]] rec text: E rec score: 0.318073
...
25 det boxes: [[680,290],[759,288],[759,303],[680,305]] rec text: (d) CTW1500 rec score: 0.95911
********** end print ocr result **********
``` ```
<a name="3"></a> <a name="3"></a>
......
...@@ -184,6 +184,9 @@ inference/ ...@@ -184,6 +184,9 @@ inference/
|-- table |-- table
| |--inference.pdiparams | |--inference.pdiparams
| |--inference.pdmodel | |--inference.pdmodel
|-- layout
| |--inference.pdiparams
| |--inference.pdmodel
``` ```
<a name="22"></a> <a name="22"></a>
...@@ -288,7 +291,30 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir ...@@ -288,7 +291,30 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
--cls=true \ --cls=true \
``` ```
##### 7. 表格识别 ##### 7. 版面分析+表格识别
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--table_model_dir=inference/table \
--image_dir=../../ppstructure/docs/table/table.jpg \
--layout_model_dir=inference/layout \
--type=structure \
--table=true \
--layout=true
```
##### 8. 版面分析
```shell
./build/ppocr --layout_model_dir=inference/layout \
--image_dir=../../ppstructure/docs/table/1.png \
--type=structure \
--table=false \
--layout=true \
--det=false \
--rec=false
```
##### 9. 表格识别
```shell ```shell
./build/ppocr --det_model_dir=inference/det_db \ ./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \ --rec_model_dir=inference/rec_rcnn \
...@@ -352,12 +378,22 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir ...@@ -352,12 +378,22 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|rec_img_w|int|320|文字识别模型输入图像宽度| |rec_img_w|int|320|文字识别模型输入图像宽度|
- 版面分析模型相关
|参数名称|类型|默认参数|意义|
| :---: | :---: | :---: | :---: |
|layout_model_dir|string|-|版面分析模型inference model地址|
|layout_dict_path|string|../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt|字典文件|
|layout_score_threshold|float|0.5|检测框的分数阈值|
|layout_nms_threshold|float|0.5|nms的阈值|
- 表格识别模型相关 - 表格识别模型相关
|参数名称|类型|默认参数|意义| |参数名称|类型|默认参数|意义|
| :---: | :---: | :---: | :---: | | :---: | :---: | :---: | :---: |
|table_model_dir|string|-|表格识别模型inference model地址| |table_model_dir|string|-|表格识别模型inference model地址|
|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|字典文件| |table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict_ch.txt|字典文件|
|table_max_len|int|488|表格识别模型输入图像长边大小,最终网络输入图像大小为(table_max_len,table_max_len)| |table_max_len|int|488|表格识别模型输入图像长边大小,最终网络输入图像大小为(table_max_len,table_max_len)|
|merge_no_span_structure|bool|true|是否合并<td></td><td></td>| |merge_no_span_structure|bool|true|是否合并<td></td><td></td>|
...@@ -378,11 +414,51 @@ predict img: ../../doc/imgs/12.jpg ...@@ -378,11 +414,51 @@ predict img: ../../doc/imgs/12.jpg
The detection visualized image saved in ./output//12.jpg The detection visualized image saved in ./output//12.jpg
``` ```
- table - layout+table
```bash ```bash
predict img: ../../ppstructure/docs/table/table.jpg predict img: ../../ppstructure/docs/table/1.png
0 type: table, region: [0,0,371,293], res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html> 0 type: text, region: [12,729,410,848], score: 0.781044, res: count of ocr result is : 7
********** print ocr result **********
0 det boxes: [[4,1],[79,1],[79,12],[4,12]] rec text: CTW1500. rec score: 0.769472
...
6 det boxes: [[4,99],[391,99],[391,112],[4,112]] rec text: sate-of-the-artmethods[12.34.36l.ourapproachachieves rec score: 0.90414
********** end print ocr result **********
1 type: text, region: [69,342,342,359], score: 0.703666, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[8,2],[269,2],[269,13],[8,13]] rec text: Table6.Experimentalresults on CTW-1500 rec score: 0.890454
********** end print ocr result **********
2 type: text, region: [70,316,706,332], score: 0.659738, res: count of ocr result is : 2
********** print ocr result **********
0 det boxes: [[373,2],[630,2],[630,11],[373,11]] rec text: oroposals.andthegreencontoursarefinal rec score: 0.919729
1 det boxes: [[8,3],[357,3],[357,11],[8,11]] rec text: Visualexperimentalresultshebluecontoursareboundar rec score: 0.915963
********** end print ocr result **********
3 type: text, region: [489,342,789,359], score: 0.630538, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[8,2],[294,2],[294,14],[8,14]] rec text: Table7.Experimentalresults onMSRA-TD500 rec score: 0.942251
********** end print ocr result **********
4 type: text, region: [444,751,841,848], score: 0.607345, res: count of ocr result is : 5
********** print ocr result **********
0 det boxes: [[19,3],[389,3],[389,17],[19,17]] rec text: Inthispaper,weproposeanovel adaptivebound rec score: 0.941031
1 det boxes: [[4,22],[390,22],[390,36],[4,36]] rec text: aryproposalnetworkforarbitraryshapetextdetection rec score: 0.960172
2 det boxes: [[4,42],[392,42],[392,56],[4,56]] rec text: whichadoptanboundaryproposalmodeltogeneratecoarse rec score: 0.934647
3 det boxes: [[4,61],[389,61],[389,75],[4,75]] rec text: ooundaryproposals,andthenadoptanadaptiveboundary rec score: 0.946296
4 det boxes: [[5,80],[387,80],[387,93],[5,93]] rec text: leformationmodelcombinedwithGCNandRNNtoper rec score: 0.952401
********** end print ocr result **********
5 type: title, region: [444,705,564,724], score: 0.785429, res: count of ocr result is : 1
********** print ocr result **********
0 det boxes: [[6,2],[113,2],[113,14],[6,14]] rec text: 5.Conclusion rec score: 0.856903
********** end print ocr result **********
6 type: table, region: [14,360,402,711], score: 0.963643, res: <html><body><table><thead><tr><td>Methods</td><td>Ext</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>TextSnake [18]</td><td>Syn</td><td>85.3</td><td>67.9</td><td>75.6</td><td></td></tr><tr><td>CSE [17]</td><td>MiLT</td><td>76.1</td><td>78.7</td><td>77.4</td><td>0.38</td></tr><tr><td>LOMO[40]</td><td>Syn</td><td>76.5</td><td>85.7</td><td>80.8</td><td>4.4</td></tr><tr><td>ATRR[35]</td><td>Sy-</td><td>80.2</td><td>80.1</td><td>80.1</td><td>-</td></tr><tr><td>SegLink++ [28]</td><td>Syn</td><td>79.8</td><td>82.8</td><td>81.3</td><td>-</td></tr><tr><td>TextField [37]</td><td>Syn</td><td>79.8</td><td>83.0</td><td>81.4</td><td>6.0</td></tr><tr><td>MSR[38]</td><td>Syn</td><td>79.0</td><td>84.1</td><td>81.5</td><td>4.3</td></tr><tr><td>PSENet-1s [33]</td><td>MLT</td><td>79.7</td><td>84.8</td><td>82.2</td><td>3.9</td></tr><tr><td>DB [12]</td><td>Syn</td><td>80.2</td><td>86.9</td><td>83.4</td><td>22.0</td></tr><tr><td>CRAFT [2]</td><td>Syn</td><td>81.1</td><td>86.0</td><td>83.5</td><td>-</td></tr><tr><td>TextDragon [5]</td><td>MLT+</td><td>82.8</td><td>84.5</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>Syn</td><td>81.2</td><td>86.4</td><td>83.7</td><td>39.8</td></tr><tr><td>ContourNet [36]</td><td></td><td>84.1</td><td>83.7</td><td>83.9</td><td>4.5</td></tr><tr><td>DRRG [41]</td><td>MLT</td><td>83.02</td><td>85.93</td><td>84.45</td><td>-</td></tr><tr><td>TextPerception[23]</td><td>Syn</td><td>81.9</td><td>87.5</td><td>84.6</td><td></td></tr><tr><td>Ours</td><td> Syn</td><td>80.57</td><td>87.66</td><td>83.97</td><td>12.08</td></tr><tr><td>Ours</td><td></td><td>81.45</td><td>87.81</td><td>84.51</td><td>12.15</td></tr><tr><td>Ours</td><td>MLT</td><td>83.60</td><td>86.45</td><td>85.00</td><td>12.21</td></tr></tbody></table></body></html>
The table visualized image saved in ./output//6_1.png
7 type: table, region: [462,359,820,657], score: 0.953917, res: <html><body><table><thead><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN[3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>:</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></body></html>
The table visualized image saved in ./output//7_1.png
8 type: figure, region: [14,3,836,310], score: 0.969443, res: count of ocr result is : 26
********** print ocr result **********
0 det boxes: [[506,14],[539,15],[539,22],[506,21]] rec text: E rec score: 0.318073
...
25 det boxes: [[680,290],[759,288],[759,303],[680,305]] rec text: (d) CTW1500 rec score: 0.95911
********** end print ocr result **********
``` ```
<a name="3"></a> <a name="3"></a>
......
...@@ -51,6 +51,13 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt", ...@@ -51,6 +51,13 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
DEFINE_int32(rec_img_h, 48, "rec image height"); DEFINE_int32(rec_img_h, 48, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width"); DEFINE_int32(rec_img_w, 320, "rec image width");
// layout model related
DEFINE_string(layout_model_dir, "", "Path of table layout inference model.");
DEFINE_string(layout_dict_path,
"../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt",
"Path of dictionary.");
DEFINE_double(layout_score_threshold, 0.5, "Threshold of score.");
DEFINE_double(layout_nms_threshold, 0.5, "Threshold of nms.");
// structure model related // structure model related
DEFINE_string(table_model_dir, "", "Path of table struture inference model."); DEFINE_string(table_model_dir, "", "Path of table struture inference model.");
DEFINE_int32(table_max_len, 488, "max len size of input image."); DEFINE_int32(table_max_len, 488, "max len size of input image.");
...@@ -65,4 +72,5 @@ DEFINE_string(table_char_dict_path, ...@@ -65,4 +72,5 @@ DEFINE_string(table_char_dict_path,
DEFINE_bool(det, true, "Whether use det in forward."); DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward."); DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward."); DEFINE_bool(cls, false, "Whether use cls in forward.");
DEFINE_bool(table, false, "Whether use table structure in forward."); DEFINE_bool(table, false, "Whether use table structure in forward.");
\ No newline at end of file DEFINE_bool(layout, false, "Whether use layout analysis in forward.");
\ No newline at end of file
...@@ -65,9 +65,18 @@ void check_params() { ...@@ -65,9 +65,18 @@ void check_params() {
exit(1); exit(1);
} }
} }
if (FLAGS_layout) {
if (FLAGS_layout_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[layout]: ./ppocr "
<< "--layout_model_dir=/PATH/TO/LAYOUT_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
exit(1);
}
}
if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" &&
FLAGS_precision != "int8") { FLAGS_precision != "int8") {
cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; std::cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. "
<< std::endl;
exit(1); exit(1);
} }
} }
...@@ -75,71 +84,94 @@ void check_params() { ...@@ -75,71 +84,94 @@ void check_params() {
void ocr(std::vector<cv::String> &cv_all_img_names) { void ocr(std::vector<cv::String> &cv_all_img_names) {
PPOCR ocr = PPOCR(); PPOCR ocr = PPOCR();
std::vector<std::vector<OCRPredictResult>> ocr_results = if (FLAGS_benchmark) {
ocr.ocr(cv_all_img_names, FLAGS_det, FLAGS_rec, FLAGS_cls); ocr.reset_timer();
}
std::vector<cv::Mat> img_list;
std::vector<cv::String> img_names;
for (int i = 0; i < cv_all_img_names.size(); ++i) { for (int i = 0; i < cv_all_img_names.size(); ++i) {
if (FLAGS_benchmark) { cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
cout << cv_all_img_names[i] << '\t'; if (!img.data) {
if (FLAGS_rec && FLAGS_det) { std::cerr << "[ERROR] image read failed! image path: "
Utility::print_result(ocr_results[i]); << cv_all_img_names[i] << std::endl;
} else if (FLAGS_det) { continue;
for (int n = 0; n < ocr_results[i].size(); n++) { }
for (int m = 0; m < ocr_results[i][n].box.size(); m++) { img_list.push_back(img);
cout << ocr_results[i][n].box[m][0] << ' ' img_names.push_back(cv_all_img_names[i]);
<< ocr_results[i][n].box[m][1] << ' '; }
}
}
cout << endl;
} else {
Utility::print_result(ocr_results[i]);
}
} else {
cout << cv_all_img_names[i] << "\n";
Utility::print_result(ocr_results[i]);
if (FLAGS_visualize && FLAGS_det) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, ocr_results[i], std::vector<std::vector<OCRPredictResult>> ocr_results =
FLAGS_output + "/" + file_name); ocr.ocr(img_list, FLAGS_det, FLAGS_rec, FLAGS_cls);
}
cout << "***************************" << endl; for (int i = 0; i < img_names.size(); ++i) {
std::cout << "predict img: " << cv_all_img_names[i] << std::endl;
Utility::print_result(ocr_results[i]);
if (FLAGS_visualize && FLAGS_det) {
std::string file_name = Utility::basename(img_names[i]);
cv::Mat srcimg = img_list[i];
Utility::VisualizeBboxes(srcimg, ocr_results[i],
FLAGS_output + "/" + file_name);
} }
} }
if (FLAGS_benchmark) {
ocr.benchmark_log(cv_all_img_names.size());
}
} }
void structure(std::vector<cv::String> &cv_all_img_names) { void structure(std::vector<cv::String> &cv_all_img_names) {
PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure(); PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure();
std::vector<std::vector<StructurePredictResult>> structure_results =
engine.structure(cv_all_img_names, false, FLAGS_table); if (FLAGS_benchmark) {
engine.reset_timer();
}
for (int i = 0; i < cv_all_img_names.size(); i++) { for (int i = 0; i < cv_all_img_names.size(); i++) {
cout << "predict img: " << cv_all_img_names[i] << endl; std::cout << "predict img: " << cv_all_img_names[i] << std::endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
for (int j = 0; j < structure_results[i].size(); j++) { if (!img.data) {
std::cout << j << "\ttype: " << structure_results[i][j].type std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << std::endl;
continue;
}
std::vector<StructurePredictResult> structure_results = engine.structure(
img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec);
for (int j = 0; j < structure_results.size(); j++) {
std::cout << j << "\ttype: " << structure_results[j].type
<< ", region: ["; << ", region: [";
std::cout << structure_results[i][j].box[0] << "," std::cout << structure_results[j].box[0] << ","
<< structure_results[i][j].box[1] << "," << structure_results[j].box[1] << ","
<< structure_results[i][j].box[2] << "," << structure_results[j].box[2] << ","
<< structure_results[i][j].box[3] << "], res: "; << structure_results[j].box[3] << "], score: ";
if (structure_results[i][j].type == "table") { std::cout << structure_results[j].confidence << ", res: ";
std::cout << structure_results[i][j].html << std::endl;
std::string file_name = Utility::basename(cv_all_img_names[i]); if (structure_results[j].type == "table") {
std::cout << structure_results[j].html << std::endl;
Utility::VisualizeBboxes(srcimg, structure_results[i][j], if (structure_results[j].cell_box.size() > 0 && FLAGS_visualize) {
FLAGS_output + "/" + std::to_string(j) + "_" + std::string file_name = Utility::basename(cv_all_img_names[i]);
file_name);
Utility::VisualizeBboxes(img, structure_results[j],
FLAGS_output + "/" + std::to_string(j) +
"_" + file_name);
}
} else { } else {
Utility::print_result(structure_results[i][j].text_res); std::cout << "count of ocr result is : "
<< structure_results[j].text_res.size() << std::endl;
if (structure_results[j].text_res.size() > 0) {
std::cout << "********** print ocr result "
<< "**********" << std::endl;
Utility::print_result(structure_results[j].text_res);
std::cout << "********** end print ocr result "
<< "**********" << std::endl;
}
} }
} }
} }
if (FLAGS_benchmark) {
engine.benchmark_log(cv_all_img_names.size());
}
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
...@@ -149,19 +181,22 @@ int main(int argc, char **argv) { ...@@ -149,19 +181,22 @@ int main(int argc, char **argv) {
if (!Utility::PathExists(FLAGS_image_dir)) { if (!Utility::PathExists(FLAGS_image_dir)) {
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
<< endl; << std::endl;
exit(1); exit(1);
} }
std::vector<cv::String> cv_all_img_names; std::vector<cv::String> cv_all_img_names;
cv::glob(FLAGS_image_dir, cv_all_img_names); cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl; std::cout << "total images num: " << cv_all_img_names.size() << std::endl;
if (!Utility::PathExists(FLAGS_output)) {
Utility::CreateDir(FLAGS_output);
}
if (FLAGS_type == "ocr") { if (FLAGS_type == "ocr") {
ocr(cv_all_img_names); ocr(cv_all_img_names);
} else if (FLAGS_type == "structure") { } else if (FLAGS_type == "structure") {
structure(cv_all_img_names); structure(cv_all_img_names);
} else { } else {
std::cout << "only value in ['ocr','structure'] is supported" << endl; std::cout << "only value in ['ocr','structure'] is supported" << std::endl;
} }
} }
...@@ -32,7 +32,7 @@ void Classifier::Run(std::vector<cv::Mat> img_list, ...@@ -32,7 +32,7 @@ void Classifier::Run(std::vector<cv::Mat> img_list,
for (int beg_img_no = 0; beg_img_no < img_num; for (int beg_img_no = 0; beg_img_no < img_num;
beg_img_no += this->cls_batch_num_) { beg_img_no += this->cls_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now(); auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->cls_batch_num_); int end_img_no = std::min(img_num, beg_img_no + this->cls_batch_num_);
int batch_num = end_img_no - beg_img_no; int batch_num = end_img_no - beg_img_no;
// preprocess // preprocess
std::vector<cv::Mat> norm_img_batch; std::vector<cv::Mat> norm_img_batch;
...@@ -97,7 +97,7 @@ void Classifier::Run(std::vector<cv::Mat> img_list, ...@@ -97,7 +97,7 @@ void Classifier::Run(std::vector<cv::Mat> img_list,
} }
void Classifier::LoadModel(const std::string &model_dir) { void Classifier::LoadModel(const std::string &model_dir) {
AnalysisConfig config; paddle_infer::Config config;
config.SetModel(model_dir + "/inference.pdmodel", config.SetModel(model_dir + "/inference.pdmodel",
model_dir + "/inference.pdiparams"); model_dir + "/inference.pdiparams");
...@@ -112,9 +112,9 @@ void Classifier::LoadModel(const std::string &model_dir) { ...@@ -112,9 +112,9 @@ void Classifier::LoadModel(const std::string &model_dir) {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
if (!Utility::PathExists("./trt_cls_shape.txt")){ if (!Utility::PathExists("./trt_cls_shape.txt")) {
config.CollectShapeRangeInfo("./trt_cls_shape.txt"); config.CollectShapeRangeInfo("./trt_cls_shape.txt");
} else { } else {
config.EnableTunedTensorRtDynamicShape("./trt_cls_shape.txt", true); config.EnableTunedTensorRtDynamicShape("./trt_cls_shape.txt", true);
} }
} }
...@@ -136,6 +136,6 @@ void Classifier::LoadModel(const std::string &model_dir) { ...@@ -136,6 +136,6 @@ void Classifier::LoadModel(const std::string &model_dir) {
config.EnableMemoryOptim(); config.EnableMemoryOptim();
config.DisableGlogInfo(); config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config); this->predictor_ = paddle_infer::CreatePredictor(config);
} }
} // namespace PaddleOCR } // namespace PaddleOCR
...@@ -33,12 +33,11 @@ void DBDetector::LoadModel(const std::string &model_dir) { ...@@ -33,12 +33,11 @@ void DBDetector::LoadModel(const std::string &model_dir) {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false); config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false);
if (!Utility::PathExists("./trt_det_shape.txt")){ if (!Utility::PathExists("./trt_det_shape.txt")) {
config.CollectShapeRangeInfo("./trt_det_shape.txt"); config.CollectShapeRangeInfo("./trt_det_shape.txt");
} else { } else {
config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true); config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true);
} }
} }
} else { } else {
config.DisableGpu(); config.DisableGpu();
...@@ -59,7 +58,7 @@ void DBDetector::LoadModel(const std::string &model_dir) { ...@@ -59,7 +58,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
config.EnableMemoryOptim(); config.EnableMemoryOptim();
// config.DisableGlogInfo(); // config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config); this->predictor_ = paddle_infer::CreatePredictor(config);
} }
void DBDetector::Run(cv::Mat &img, void DBDetector::Run(cv::Mat &img,
......
...@@ -37,7 +37,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, ...@@ -37,7 +37,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
for (int beg_img_no = 0; beg_img_no < img_num; for (int beg_img_no = 0; beg_img_no < img_num;
beg_img_no += this->rec_batch_num_) { beg_img_no += this->rec_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now(); auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_); int end_img_no = std::min(img_num, beg_img_no + this->rec_batch_num_);
int batch_num = end_img_no - beg_img_no; int batch_num = end_img_no - beg_img_no;
int imgH = this->rec_image_shape_[1]; int imgH = this->rec_image_shape_[1];
int imgW = this->rec_image_shape_[2]; int imgW = this->rec_image_shape_[2];
...@@ -46,7 +46,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, ...@@ -46,7 +46,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
int h = img_list[indices[ino]].rows; int h = img_list[indices[ino]].rows;
int w = img_list[indices[ino]].cols; int w = img_list[indices[ino]].cols;
float wh_ratio = w * 1.0 / h; float wh_ratio = w * 1.0 / h;
max_wh_ratio = max(max_wh_ratio, wh_ratio); max_wh_ratio = std::max(max_wh_ratio, wh_ratio);
} }
int batch_width = imgW; int batch_width = imgW;
...@@ -60,7 +60,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, ...@@ -60,7 +60,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_); this->is_scale_);
norm_img_batch.push_back(resize_img); norm_img_batch.push_back(resize_img);
batch_width = max(resize_img.cols, batch_width); batch_width = std::max(resize_img.cols, batch_width);
} }
std::vector<float> input(batch_num * 3 * imgH * batch_width, 0.0f); std::vector<float> input(batch_num * 3 * imgH * batch_width, 0.0f);
...@@ -115,7 +115,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, ...@@ -115,7 +115,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
last_index = argmax_idx; last_index = argmax_idx;
} }
score /= count; score /= count;
if (isnan(score)) { if (std::isnan(score)) {
continue; continue;
} }
rec_texts[indices[beg_img_no + m]] = str_res; rec_texts[indices[beg_img_no + m]] = str_res;
...@@ -130,7 +130,6 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, ...@@ -130,7 +130,6 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
} }
void CRNNRecognizer::LoadModel(const std::string &model_dir) { void CRNNRecognizer::LoadModel(const std::string &model_dir) {
// AnalysisConfig config;
paddle_infer::Config config; paddle_infer::Config config;
config.SetModel(model_dir + "/inference.pdmodel", config.SetModel(model_dir + "/inference.pdmodel",
model_dir + "/inference.pdiparams"); model_dir + "/inference.pdiparams");
...@@ -147,12 +146,11 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { ...@@ -147,12 +146,11 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
if (this->precision_ == "int8") { if (this->precision_ == "int8") {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
if (!Utility::PathExists("./trt_rec_shape.txt")){ if (!Utility::PathExists("./trt_rec_shape.txt")) {
config.CollectShapeRangeInfo("./trt_rec_shape.txt"); config.CollectShapeRangeInfo("./trt_rec_shape.txt");
} else { } else {
config.EnableTunedTensorRtDynamicShape("./trt_rec_shape.txt", true); config.EnableTunedTensorRtDynamicShape("./trt_rec_shape.txt", true);
} }
} }
} else { } else {
config.DisableGpu(); config.DisableGpu();
...@@ -177,7 +175,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { ...@@ -177,7 +175,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
config.EnableMemoryOptim(); config.EnableMemoryOptim();
// config.DisableGlogInfo(); // config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config); this->predictor_ = paddle_infer::CreatePredictor(config);
} }
} // namespace PaddleOCR } // namespace PaddleOCR
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <include/paddleocr.h> #include <include/paddleocr.h>
#include "auto_log/autolog.h" #include "auto_log/autolog.h"
#include <numeric>
namespace PaddleOCR { namespace PaddleOCR {
PPOCR::PPOCR() { PPOCR::PPOCR() {
...@@ -44,8 +44,71 @@ PPOCR::PPOCR() { ...@@ -44,8 +44,71 @@ PPOCR::PPOCR() {
} }
}; };
void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results, std::vector<std::vector<OCRPredictResult>>
std::vector<double> &times) { PPOCR::ocr(std::vector<cv::Mat> img_list, bool det, bool rec, bool cls) {
std::vector<std::vector<OCRPredictResult>> ocr_results;
if (!det) {
std::vector<OCRPredictResult> ocr_result;
ocr_result.resize(img_list.size());
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
if (rec) {
this->rec(img_list, ocr_result);
}
for (int i = 0; i < ocr_result.size(); ++i) {
std::vector<OCRPredictResult> ocr_result_tmp;
ocr_result_tmp.push_back(ocr_result[i]);
ocr_results.push_back(ocr_result_tmp);
}
} else {
for (int i = 0; i < img_list.size(); ++i) {
std::vector<OCRPredictResult> ocr_result =
this->ocr(img_list[i], true, rec, cls);
ocr_results.push_back(ocr_result);
}
}
return ocr_results;
}
std::vector<OCRPredictResult> PPOCR::ocr(cv::Mat img, bool det, bool rec,
bool cls) {
std::vector<OCRPredictResult> ocr_result;
// det
this->det(img, ocr_result);
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < ocr_result.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box);
img_list.push_back(crop_img);
}
// cls
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
// rec
if (rec) {
this->rec(img_list, ocr_result);
}
return ocr_result;
}
void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results) {
std::vector<std::vector<std::vector<int>>> boxes; std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times; std::vector<double> det_times;
...@@ -58,14 +121,13 @@ void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results, ...@@ -58,14 +121,13 @@ void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
} }
// sort boex from top to bottom, from left to right // sort boex from top to bottom, from left to right
Utility::sorted_boxes(ocr_results); Utility::sorted_boxes(ocr_results);
times[0] += det_times[0]; this->time_info_det[0] += det_times[0];
times[1] += det_times[1]; this->time_info_det[1] += det_times[1];
times[2] += det_times[2]; this->time_info_det[2] += det_times[2];
} }
void PPOCR::rec(std::vector<cv::Mat> img_list, void PPOCR::rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results, std::vector<OCRPredictResult> &ocr_results) {
std::vector<double> &times) {
std::vector<std::string> rec_texts(img_list.size(), ""); std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0); std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times; std::vector<double> rec_times;
...@@ -75,14 +137,13 @@ void PPOCR::rec(std::vector<cv::Mat> img_list, ...@@ -75,14 +137,13 @@ void PPOCR::rec(std::vector<cv::Mat> img_list,
ocr_results[i].text = rec_texts[i]; ocr_results[i].text = rec_texts[i];
ocr_results[i].score = rec_text_scores[i]; ocr_results[i].score = rec_text_scores[i];
} }
times[0] += rec_times[0]; this->time_info_rec[0] += rec_times[0];
times[1] += rec_times[1]; this->time_info_rec[1] += rec_times[1];
times[2] += rec_times[2]; this->time_info_rec[2] += rec_times[2];
} }
void PPOCR::cls(std::vector<cv::Mat> img_list, void PPOCR::cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results, std::vector<OCRPredictResult> &ocr_results) {
std::vector<double> &times) {
std::vector<int> cls_labels(img_list.size(), 0); std::vector<int> cls_labels(img_list.size(), 0);
std::vector<float> cls_scores(img_list.size(), 0); std::vector<float> cls_scores(img_list.size(), 0);
std::vector<double> cls_times; std::vector<double> cls_times;
...@@ -92,125 +153,43 @@ void PPOCR::cls(std::vector<cv::Mat> img_list, ...@@ -92,125 +153,43 @@ void PPOCR::cls(std::vector<cv::Mat> img_list,
ocr_results[i].cls_label = cls_labels[i]; ocr_results[i].cls_label = cls_labels[i];
ocr_results[i].cls_score = cls_scores[i]; ocr_results[i].cls_score = cls_scores[i];
} }
times[0] += cls_times[0]; this->time_info_cls[0] += cls_times[0];
times[1] += cls_times[1]; this->time_info_cls[1] += cls_times[1];
times[2] += cls_times[2]; this->time_info_cls[2] += cls_times[2];
} }
std::vector<std::vector<OCRPredictResult>> void PPOCR::reset_timer() {
PPOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec, this->time_info_det = {0, 0, 0};
bool cls) { this->time_info_rec = {0, 0, 0};
std::vector<double> time_info_det = {0, 0, 0}; this->time_info_cls = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0}; }
std::vector<double> time_info_cls = {0, 0, 0};
std::vector<std::vector<OCRPredictResult>> ocr_results;
if (!det) {
std::vector<OCRPredictResult> ocr_result;
// read image
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
img_list.push_back(srcimg);
OCRPredictResult res;
ocr_result.push_back(res);
}
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result_tmp;
ocr_result_tmp.push_back(ocr_result[i]);
ocr_results.push_back(ocr_result_tmp);
}
} else {
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
Utility::CreateDir(FLAGS_output);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result;
if (!FLAGS_benchmark) {
cout << "predict img: " << cv_all_img_names[i] << endl;
}
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
// det
this->det(srcimg, ocr_result, time_info_det);
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < ocr_result.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(srcimg, ocr_result[j].box);
img_list.push_back(crop_img);
}
// cls
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
// rec
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
ocr_results.push_back(ocr_result);
}
}
if (FLAGS_benchmark) {
this->log(time_info_det, time_info_rec, time_info_cls,
cv_all_img_names.size());
}
return ocr_results;
} // namespace PaddleOCR
void PPOCR::log(std::vector<double> &det_times, std::vector<double> &rec_times, void PPOCR::benchmark_log(int img_num) {
std::vector<double> &cls_times, int img_num) { if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] >
if (det_times[0] + det_times[1] + det_times[2] > 0) { 0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt, AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic", FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, det_times, img_num); FLAGS_precision, this->time_info_det, img_num);
autolog_det.report(); autolog_det.report();
} }
if (rec_times[0] + rec_times[1] + rec_times[2] > 0) { if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] >
0) {
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt, AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision, FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
rec_times, img_num); this->time_info_rec, img_num);
autolog_rec.report(); autolog_rec.report();
} }
if (cls_times[0] + cls_times[1] + cls_times[2] > 0) { if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] >
0) {
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt, AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision, FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
cls_times, img_num); this->time_info_cls, img_num);
autolog_cls.report(); autolog_cls.report();
} }
} }
PPOCR::~PPOCR() { PPOCR::~PPOCR() {
if (this->detector_ != nullptr) { if (this->detector_ != nullptr) {
delete this->detector_; delete this->detector_;
......
...@@ -16,14 +16,19 @@ ...@@ -16,14 +16,19 @@
#include <include/paddlestructure.h> #include <include/paddlestructure.h>
#include "auto_log/autolog.h" #include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>
namespace PaddleOCR { namespace PaddleOCR {
PaddleStructure::PaddleStructure() { PaddleStructure::PaddleStructure() {
if (FLAGS_layout) {
this->layout_model_ = new StructureLayoutRecognizer(
FLAGS_layout_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_layout_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_layout_score_threshold,
FLAGS_layout_nms_threshold);
}
if (FLAGS_table) { if (FLAGS_table) {
this->recognizer_ = new StructureTableRecognizer( this->table_model_ = new StructureTableRecognizer(
FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path, FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num, FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num,
...@@ -31,68 +36,63 @@ PaddleStructure::PaddleStructure() { ...@@ -31,68 +36,63 @@ PaddleStructure::PaddleStructure() {
} }
}; };
std::vector<std::vector<StructurePredictResult>> std::vector<StructurePredictResult>
PaddleStructure::structure(std::vector<cv::String> cv_all_img_names, PaddleStructure::structure(cv::Mat srcimg, bool layout, bool table, bool ocr) {
bool layout, bool table) { cv::Mat img;
std::vector<double> time_info_det = {0, 0, 0}; srcimg.copyTo(img);
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
std::vector<double> time_info_table = {0, 0, 0};
std::vector<std::vector<StructurePredictResult>> structure_results; std::vector<StructurePredictResult> structure_results;
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { if (layout) {
Utility::CreateDir(FLAGS_output); this->layout(img, structure_results);
} else {
StructurePredictResult res;
res.type = "table";
res.box = std::vector<float>(4, 0.0);
res.box[2] = img.cols;
res.box[3] = img.rows;
structure_results.push_back(res);
} }
for (int i = 0; i < cv_all_img_names.size(); ++i) { cv::Mat roi_img;
std::vector<StructurePredictResult> structure_result; for (int i = 0; i < structure_results.size(); i++) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); // crop image
if (!srcimg.data) { roi_img = Utility::crop_image(img, structure_results[i].box);
std::cerr << "[ERROR] image read failed! image path: " if (structure_results[i].type == "table" && table) {
<< cv_all_img_names[i] << endl; this->table(roi_img, structure_results[i]);
exit(1); } else if (ocr) {
} structure_results[i].text_res = this->ocr(roi_img, true, true, false);
if (layout) {
} else {
StructurePredictResult res;
res.type = "table";
res.box = std::vector<int>(4, 0);
res.box[2] = srcimg.cols;
res.box[3] = srcimg.rows;
structure_result.push_back(res);
}
cv::Mat roi_img;
for (int i = 0; i < structure_result.size(); i++) {
// crop image
roi_img = Utility::crop_image(srcimg, structure_result[i].box);
if (structure_result[i].type == "table") {
this->table(roi_img, structure_result[i], time_info_table,
time_info_det, time_info_rec, time_info_cls);
}
} }
structure_results.push_back(structure_result);
} }
return structure_results; return structure_results;
}; };
void PaddleStructure::layout(
cv::Mat img, std::vector<StructurePredictResult> &structure_result) {
std::vector<double> layout_times;
this->layout_model_->Run(img, structure_result, layout_times);
this->time_info_layout[0] += layout_times[0];
this->time_info_layout[1] += layout_times[1];
this->time_info_layout[2] += layout_times[2];
}
void PaddleStructure::table(cv::Mat img, void PaddleStructure::table(cv::Mat img,
StructurePredictResult &structure_result, StructurePredictResult &structure_result) {
std::vector<double> &time_info_table,
std::vector<double> &time_info_det,
std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls) {
// predict structure // predict structure
std::vector<std::vector<std::string>> structure_html_tags; std::vector<std::vector<std::string>> structure_html_tags;
std::vector<float> structure_scores(1, 0); std::vector<float> structure_scores(1, 0);
std::vector<std::vector<std::vector<int>>> structure_boxes; std::vector<std::vector<std::vector<int>>> structure_boxes;
std::vector<double> structure_imes; std::vector<double> structure_times;
std::vector<cv::Mat> img_list; std::vector<cv::Mat> img_list;
img_list.push_back(img); img_list.push_back(img);
this->recognizer_->Run(img_list, structure_html_tags, structure_scores,
structure_boxes, structure_imes); this->table_model_->Run(img_list, structure_html_tags, structure_scores,
time_info_table[0] += structure_imes[0]; structure_boxes, structure_times);
time_info_table[1] += structure_imes[1];
time_info_table[2] += structure_imes[2]; this->time_info_table[0] += structure_times[0];
this->time_info_table[1] += structure_times[1];
this->time_info_table[2] += structure_times[2];
std::vector<OCRPredictResult> ocr_result; std::vector<OCRPredictResult> ocr_result;
std::string html; std::string html;
...@@ -100,22 +100,22 @@ void PaddleStructure::table(cv::Mat img, ...@@ -100,22 +100,22 @@ void PaddleStructure::table(cv::Mat img,
for (int i = 0; i < img_list.size(); i++) { for (int i = 0; i < img_list.size(); i++) {
// det // det
this->det(img_list[i], ocr_result, time_info_det); this->det(img_list[i], ocr_result);
// crop image // crop image
std::vector<cv::Mat> rec_img_list; std::vector<cv::Mat> rec_img_list;
std::vector<int> ocr_box; std::vector<int> ocr_box;
for (int j = 0; j < ocr_result.size(); j++) { for (int j = 0; j < ocr_result.size(); j++) {
ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[j].box); ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[j].box);
ocr_box[0] = max(0, ocr_box[0] - expand_pixel); ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel);
ocr_box[1] = max(0, ocr_box[1] - expand_pixel), ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel),
ocr_box[2] = min(img_list[i].cols, ocr_box[2] + expand_pixel); ocr_box[2] = std::min(img_list[i].cols, ocr_box[2] + expand_pixel);
ocr_box[3] = min(img_list[i].rows, ocr_box[3] + expand_pixel); ocr_box[3] = std::min(img_list[i].rows, ocr_box[3] + expand_pixel);
cv::Mat crop_img = Utility::crop_image(img_list[i], ocr_box); cv::Mat crop_img = Utility::crop_image(img_list[i], ocr_box);
rec_img_list.push_back(crop_img); rec_img_list.push_back(crop_img);
} }
// rec // rec
this->rec(rec_img_list, ocr_result, time_info_rec); this->rec(rec_img_list, ocr_result);
// rebuild table // rebuild table
html = this->rebuild_table(structure_html_tags[i], structure_boxes[i], html = this->rebuild_table(structure_html_tags[i], structure_boxes[i],
ocr_result); ocr_result);
...@@ -130,8 +130,8 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags, ...@@ -130,8 +130,8 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
std::vector<std::vector<int>> structure_boxes, std::vector<std::vector<int>> structure_boxes,
std::vector<OCRPredictResult> &ocr_result) { std::vector<OCRPredictResult> &ocr_result) {
// match text in same cell // match text in same cell
std::vector<std::vector<string>> matched(structure_boxes.size(), std::vector<std::vector<std::string>> matched(structure_boxes.size(),
std::vector<std::string>()); std::vector<std::string>());
std::vector<int> ocr_box; std::vector<int> ocr_box;
std::vector<int> structure_box; std::vector<int> structure_box;
...@@ -150,7 +150,7 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags, ...@@ -150,7 +150,7 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
structure_box = structure_boxes[j]; structure_box = structure_boxes[j];
} }
dis_list[j][0] = this->dis(ocr_box, structure_box); dis_list[j][0] = this->dis(ocr_box, structure_box);
dis_list[j][1] = 1 - this->iou(ocr_box, structure_box); dis_list[j][1] = 1 - Utility::iou(ocr_box, structure_box);
dis_list[j][2] = j; dis_list[j][2] = j;
} }
// find min dis idx // find min dis idx
...@@ -216,28 +216,6 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags, ...@@ -216,28 +216,6 @@ PaddleStructure::rebuild_table(std::vector<std::string> structure_html_tags,
return html_str; return html_str;
} }
float PaddleStructure::iou(std::vector<int> &box1, std::vector<int> &box2) {
int area1 = max(0, box1[2] - box1[0]) * max(0, box1[3] - box1[1]);
int area2 = max(0, box2[2] - box2[0]) * max(0, box2[3] - box2[1]);
// computing the sum_area
int sum_area = area1 + area2;
// find the each point of intersect rectangle
int x1 = max(box1[0], box2[0]);
int y1 = max(box1[1], box2[1]);
int x2 = min(box1[2], box2[2]);
int y2 = min(box1[3], box2[3]);
// judge if there is an intersect
if (y1 >= y2 || x1 >= x2) {
return 0.0;
} else {
int intersect = (x2 - x1) * (y2 - y1);
return intersect / (sum_area - intersect + 0.00000001);
}
}
float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) { float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
int x1_1 = box1[0]; int x1_1 = box1[0];
int y1_1 = box1[1]; int y1_1 = box1[1];
...@@ -253,12 +231,64 @@ float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) { ...@@ -253,12 +231,64 @@ float PaddleStructure::dis(std::vector<int> &box1, std::vector<int> &box2) {
abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1); abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1); float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1);
float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1); float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1);
return dis + min(dis_2, dis_3); return dis + std::min(dis_2, dis_3);
}
void PaddleStructure::reset_timer() {
this->time_info_det = {0, 0, 0};
this->time_info_rec = {0, 0, 0};
this->time_info_cls = {0, 0, 0};
this->time_info_table = {0, 0, 0};
this->time_info_layout = {0, 0, 0};
}
void PaddleStructure::benchmark_log(int img_num) {
if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] >
0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, this->time_info_det, img_num);
autolog_det.report();
}
if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] >
0) {
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
this->time_info_rec, img_num);
autolog_rec.report();
}
if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] >
0) {
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_cls, img_num);
autolog_cls.report();
}
if (this->time_info_table[0] + this->time_info_table[1] +
this->time_info_table[2] >
0) {
AutoLogger autolog_table("table", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_table, img_num);
autolog_table.report();
}
if (this->time_info_layout[0] + this->time_info_layout[1] +
this->time_info_layout[2] >
0) {
AutoLogger autolog_layout("layout", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
this->time_info_layout, img_num);
autolog_layout.report();
}
} }
PaddleStructure::~PaddleStructure() { PaddleStructure::~PaddleStructure() {
if (this->recognizer_ != nullptr) { if (this->table_model_ != nullptr) {
delete this->recognizer_; delete this->table_model_;
} }
}; };
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <include/clipper.h>
#include <include/postprocess_op.h> #include <include/postprocess_op.h>
namespace PaddleOCR { namespace PaddleOCR {
...@@ -431,7 +430,7 @@ void TablePostProcessor::Run( ...@@ -431,7 +430,7 @@ void TablePostProcessor::Run(
} }
} }
score /= count; score /= count;
if (isnan(score) || rec_boxes.size() == 0) { if (std::isnan(score) || rec_boxes.size() == 0) {
score = -1; score = -1;
} }
rec_scores.push_back(score); rec_scores.push_back(score);
...@@ -440,4 +439,137 @@ void TablePostProcessor::Run( ...@@ -440,4 +439,137 @@ void TablePostProcessor::Run(
} }
} }
void PicodetPostProcessor::init(std::string label_path,
const double score_threshold,
const double nms_threshold,
const std::vector<int> &fpn_stride) {
this->label_list_ = Utility::ReadDict(label_path);
this->score_threshold_ = score_threshold;
this->nms_threshold_ = nms_threshold;
this->num_class_ = label_list_.size();
this->fpn_stride_ = fpn_stride;
}
void PicodetPostProcessor::Run(std::vector<StructurePredictResult> &results,
std::vector<std::vector<float>> outs,
std::vector<int> ori_shape,
std::vector<int> resize_shape, int reg_max) {
int in_h = resize_shape[0];
int in_w = resize_shape[1];
float scale_factor_h = resize_shape[0] / float(ori_shape[0]);
float scale_factor_w = resize_shape[1] / float(ori_shape[1]);
std::vector<std::vector<StructurePredictResult>> bbox_results;
bbox_results.resize(this->num_class_);
for (int i = 0; i < this->fpn_stride_.size(); ++i) {
int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]);
int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]);
for (int idx = 0; idx < feature_h * feature_w; idx++) {
// score and label
float score = 0;
int cur_label = 0;
for (int label = 0; label < this->num_class_; label++) {
if (outs[i][idx * this->num_class_ + label] > score) {
score = outs[i][idx * this->num_class_ + label];
cur_label = label;
}
}
// bbox
if (score > this->score_threshold_) {
int row = idx / feature_w;
int col = idx % feature_w;
std::vector<float> bbox_pred(
outs[i + this->fpn_stride_.size()].begin() + idx * 4 * reg_max,
outs[i + this->fpn_stride_.size()].begin() +
(idx + 1) * 4 * reg_max);
bbox_results[cur_label].push_back(
this->disPred2Bbox(bbox_pred, cur_label, score, col, row,
this->fpn_stride_[i], resize_shape, reg_max));
}
}
}
for (int i = 0; i < bbox_results.size(); i++) {
bool flag = bbox_results[i].size() <= 0;
}
for (int i = 0; i < bbox_results.size(); i++) {
bool flag = bbox_results[i].size() <= 0;
if (bbox_results[i].size() <= 0) {
continue;
}
this->nms(bbox_results[i], this->nms_threshold_);
for (auto box : bbox_results[i]) {
box.box[0] = box.box[0] / scale_factor_w;
box.box[2] = box.box[2] / scale_factor_w;
box.box[1] = box.box[1] / scale_factor_h;
box.box[3] = box.box[3] / scale_factor_h;
results.push_back(box);
}
}
}
StructurePredictResult
PicodetPostProcessor::disPred2Bbox(std::vector<float> bbox_pred, int label,
float score, int x, int y, int stride,
std::vector<int> im_shape, int reg_max) {
float ct_x = (x + 0.5) * stride;
float ct_y = (y + 0.5) * stride;
std::vector<float> dis_pred;
dis_pred.resize(4);
for (int i = 0; i < 4; i++) {
float dis = 0;
std::vector<float> bbox_pred_i(bbox_pred.begin() + i * reg_max,
bbox_pred.begin() + (i + 1) * reg_max);
std::vector<float> dis_after_sm =
Utility::activation_function_softmax(bbox_pred_i);
for (int j = 0; j < reg_max; j++) {
dis += j * dis_after_sm[j];
}
dis *= stride;
dis_pred[i] = dis;
}
float xmin = (std::max)(ct_x - dis_pred[0], .0f);
float ymin = (std::max)(ct_y - dis_pred[1], .0f);
float xmax = (std::min)(ct_x + dis_pred[2], (float)im_shape[1]);
float ymax = (std::min)(ct_y + dis_pred[3], (float)im_shape[0]);
StructurePredictResult result_item;
result_item.box = {xmin, ymin, xmax, ymax};
result_item.type = this->label_list_[label];
result_item.confidence = score;
return result_item;
}
void PicodetPostProcessor::nms(std::vector<StructurePredictResult> &input_boxes,
float nms_threshold) {
std::sort(input_boxes.begin(), input_boxes.end(),
[](StructurePredictResult a, StructurePredictResult b) {
return a.confidence > b.confidence;
});
std::vector<int> picked(input_boxes.size(), 1);
for (int i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 0) {
continue;
}
for (int j = i + 1; j < input_boxes.size(); ++j) {
if (picked[j] == 0) {
continue;
}
float iou = Utility::iou(input_boxes[i].box, input_boxes[j].box);
if (iou > nms_threshold) {
picked[j] = 0;
}
}
}
std::vector<StructurePredictResult> input_boxes_nms;
for (int i = 0; i < input_boxes.size(); ++i) {
if (picked[i] == 1) {
input_boxes_nms.push_back(input_boxes[i]);
}
}
input_boxes = input_boxes_nms;
}
} // namespace PaddleOCR } // namespace PaddleOCR
...@@ -12,21 +12,6 @@ ...@@ -12,21 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/preprocess_op.h> #include <include/preprocess_op.h>
namespace PaddleOCR { namespace PaddleOCR {
...@@ -69,13 +54,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean, ...@@ -69,13 +54,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
} }
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
string limit_type, int limit_side_len, float &ratio_h, std::string limit_type, int limit_side_len,
float &ratio_w, bool use_tensorrt) { float &ratio_h, float &ratio_w, bool use_tensorrt) {
int w = img.cols; int w = img.cols;
int h = img.rows; int h = img.rows;
float ratio = 1.f; float ratio = 1.f;
if (limit_type == "min") { if (limit_type == "min") {
int min_wh = min(h, w); int min_wh = std::min(h, w);
if (min_wh < limit_side_len) { if (min_wh < limit_side_len) {
if (h < w) { if (h < w) {
ratio = float(limit_side_len) / float(h); ratio = float(limit_side_len) / float(h);
...@@ -84,7 +69,7 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, ...@@ -84,7 +69,7 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
} }
} }
} else { } else {
int max_wh = max(h, w); int max_wh = std::max(h, w);
if (max_wh > limit_side_len) { if (max_wh > limit_side_len) {
if (h > w) { if (h > w) {
ratio = float(limit_side_len) / float(h); ratio = float(limit_side_len) / float(h);
...@@ -97,8 +82,8 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, ...@@ -97,8 +82,8 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
int resize_h = int(float(h) * ratio); int resize_h = int(float(h) * ratio);
int resize_w = int(float(w) * ratio); int resize_w = int(float(w) * ratio);
resize_h = max(int(round(float(resize_h) / 32) * 32), 32); resize_h = std::max(int(round(float(resize_h) / 32) * 32), 32);
resize_w = max(int(round(float(resize_w) / 32) * 32), 32); resize_w = std::max(int(round(float(resize_w) / 32) * 32), 32);
cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
ratio_h = float(resize_h) / float(h); ratio_h = float(resize_h) / float(h);
...@@ -175,4 +160,9 @@ void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img, ...@@ -175,4 +160,9 @@ void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
} }
void Resize::Run(const cv::Mat &img, cv::Mat &resize_img, const int h,
const int w) {
cv::resize(img, resize_img, cv::Size(w, h));
}
} // namespace PaddleOCR } // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/structure_layout.h>
namespace PaddleOCR {
void StructureLayoutRecognizer::Run(cv::Mat img,
std::vector<StructurePredictResult> &result,
std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> inference_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> postprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
// preprocess
auto preprocess_start = std::chrono::steady_clock::now();
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat resize_img;
this->resize_op_.Run(srcimg, resize_img, 800, 608);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
// Get output tensor
std::vector<std::vector<float>> out_tensor_list;
std::vector<std::vector<int>> output_shape_list;
auto output_names = this->predictor_->GetOutputNames();
for (int j = 0; j < output_names.size(); j++) {
auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]);
std::vector<int> output_shape = output_tensor->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
output_shape_list.push_back(output_shape);
std::vector<float> out_data;
out_data.resize(out_num);
output_tensor->CopyToCpu(out_data.data());
out_tensor_list.push_back(out_data);
}
auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// postprocess
auto postprocess_start = std::chrono::steady_clock::now();
std::vector<int> bbox_num;
int reg_max = 0;
for (int i = 0; i < out_tensor_list.size(); i++) {
if (i == this->post_processor_.fpn_stride_.size()) {
reg_max = output_shape_list[i][2] / 4;
break;
}
}
std::vector<int> ori_shape = {srcimg.rows, srcimg.cols};
std::vector<int> resize_shape = {resize_img.rows, resize_img.cols};
this->post_processor_.Run(result, out_tensor_list, ori_shape, resize_shape,
reg_max);
bbox_num.push_back(result.size());
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
times.push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
void StructureLayoutRecognizer::LoadModel(const std::string &model_dir) {
paddle_infer::Config config;
if (Utility::PathExists(model_dir + "/inference.pdmodel") &&
Utility::PathExists(model_dir + "/inference.pdiparams")) {
config.SetModel(model_dir + "/inference.pdmodel",
model_dir + "/inference.pdiparams");
} else if (Utility::PathExists(model_dir + "/model.pdmodel") &&
Utility::PathExists(model_dir + "/model.pdiparams")) {
config.SetModel(model_dir + "/model.pdmodel",
model_dir + "/model.pdiparams");
} else {
std::cerr << "[ERROR] not find model.pdiparams or inference.pdiparams in "
<< model_dir << std::endl;
exit(1);
}
if (this->use_gpu_) {
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
if (this->use_tensorrt_) {
auto precision = paddle_infer::Config::Precision::kFloat32;
if (this->precision_ == "fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
if (this->precision_ == "int8") {
precision = paddle_infer::Config::Precision::kInt8;
}
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
if (!Utility::PathExists("./trt_layout_shape.txt")) {
config.CollectShapeRangeInfo("./trt_layout_shape.txt");
} else {
config.EnableTunedTensorRtDynamicShape("./trt_layout_shape.txt", true);
}
}
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
// false for zero copy tensor
config.SwitchUseFeedFetchOps(false);
// true for multiple input
config.SwitchSpecifyInputNames(true);
config.SwitchIrOptim(true);
config.EnableMemoryOptim();
config.DisableGlogInfo();
this->predictor_ = paddle_infer::CreatePredictor(config);
}
} // namespace PaddleOCR
...@@ -34,7 +34,7 @@ void StructureTableRecognizer::Run( ...@@ -34,7 +34,7 @@ void StructureTableRecognizer::Run(
beg_img_no += this->table_batch_num_) { beg_img_no += this->table_batch_num_) {
// preprocess // preprocess
auto preprocess_start = std::chrono::steady_clock::now(); auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->table_batch_num_); int end_img_no = std::min(img_num, beg_img_no + this->table_batch_num_);
int batch_num = end_img_no - beg_img_no; int batch_num = end_img_no - beg_img_no;
std::vector<cv::Mat> norm_img_batch; std::vector<cv::Mat> norm_img_batch;
std::vector<int> width_list; std::vector<int> width_list;
...@@ -118,7 +118,7 @@ void StructureTableRecognizer::Run( ...@@ -118,7 +118,7 @@ void StructureTableRecognizer::Run(
} }
void StructureTableRecognizer::LoadModel(const std::string &model_dir) { void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
AnalysisConfig config; paddle_infer::Config config;
config.SetModel(model_dir + "/inference.pdmodel", config.SetModel(model_dir + "/inference.pdmodel",
model_dir + "/inference.pdiparams"); model_dir + "/inference.pdiparams");
...@@ -133,6 +133,11 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) { ...@@ -133,6 +133,11 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
if (!Utility::PathExists("./trt_table_shape.txt")) {
config.CollectShapeRangeInfo("./trt_table_shape.txt");
} else {
config.EnableTunedTensorRtDynamicShape("./trt_table_shape.txt", true);
}
} }
} else { } else {
config.DisableGpu(); config.DisableGpu();
...@@ -152,6 +157,6 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) { ...@@ -152,6 +157,6 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) {
config.EnableMemoryOptim(); config.EnableMemoryOptim();
config.DisableGlogInfo(); config.DisableGlogInfo();
this->predictor_ = CreatePredictor(config); this->predictor_ = paddle_infer::CreatePredictor(config);
} }
} // namespace PaddleOCR } // namespace PaddleOCR
...@@ -70,6 +70,7 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg, ...@@ -70,6 +70,7 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg,
const std::string &save_path) { const std::string &save_path) {
cv::Mat img_vis; cv::Mat img_vis;
srcimg.copyTo(img_vis); srcimg.copyTo(img_vis);
img_vis = crop_image(img_vis, structure_result.box);
for (int n = 0; n < structure_result.cell_box.size(); n++) { for (int n = 0; n < structure_result.cell_box.size(); n++) {
if (structure_result.cell_box[n].size() == 8) { if (structure_result.cell_box[n].size() == 8) {
cv::Point rook_points[4]; cv::Point rook_points[4];
...@@ -280,23 +281,29 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) { ...@@ -280,23 +281,29 @@ void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
} }
} }
cv::Mat Utility::crop_image(cv::Mat &img, std::vector<int> &area) { cv::Mat Utility::crop_image(cv::Mat &img, const std::vector<int> &box) {
cv::Mat crop_im; cv::Mat crop_im;
int crop_x1 = std::max(0, area[0]); int crop_x1 = std::max(0, box[0]);
int crop_y1 = std::max(0, area[1]); int crop_y1 = std::max(0, box[1]);
int crop_x2 = std::min(img.cols - 1, area[2] - 1); int crop_x2 = std::min(img.cols - 1, box[2] - 1);
int crop_y2 = std::min(img.rows - 1, area[3] - 1); int crop_y2 = std::min(img.rows - 1, box[3] - 1);
crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16); crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16);
cv::Mat crop_im_window = cv::Mat crop_im_window =
crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]), crop_im(cv::Range(crop_y1 - box[1], crop_y2 + 1 - box[1]),
cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0])); cv::Range(crop_x1 - box[0], crop_x2 + 1 - box[0]));
cv::Mat roi_img = cv::Mat roi_img =
img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1)); img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1));
crop_im_window += roi_img; crop_im_window += roi_img;
return crop_im; return crop_im;
} }
cv::Mat Utility::crop_image(cv::Mat &img, const std::vector<float> &box) {
std::vector<int> box_int = {(int)box[0], (int)box[1], (int)box[2],
(int)box[3]};
return crop_image(img, box_int);
}
void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) { void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box); std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
if (ocr_result.size() > 0) { if (ocr_result.size() > 0) {
...@@ -341,4 +348,78 @@ std::vector<int> Utility::xyxyxyxy2xyxy(std::vector<int> &box) { ...@@ -341,4 +348,78 @@ std::vector<int> Utility::xyxyxyxy2xyxy(std::vector<int> &box) {
return box1; return box1;
} }
float Utility::fast_exp(float x) {
union {
uint32_t i;
float f;
} v{};
v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f);
return v.f;
}
std::vector<float>
Utility::activation_function_softmax(std::vector<float> &src) {
int length = src.size();
std::vector<float> dst;
dst.resize(length);
const float alpha = float(*std::max_element(&src[0], &src[0 + length]));
float denominator{0};
for (int i = 0; i < length; ++i) {
dst[i] = fast_exp(src[i] - alpha);
denominator += dst[i];
}
for (int i = 0; i < length; ++i) {
dst[i] /= denominator;
}
return dst;
}
float Utility::iou(std::vector<int> &box1, std::vector<int> &box2) {
int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]);
int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]);
// computing the sum_area
int sum_area = area1 + area2;
// find the each point of intersect rectangle
int x1 = std::max(box1[0], box2[0]);
int y1 = std::max(box1[1], box2[1]);
int x2 = std::min(box1[2], box2[2]);
int y2 = std::min(box1[3], box2[3]);
// judge if there is an intersect
if (y1 >= y2 || x1 >= x2) {
return 0.0;
} else {
int intersect = (x2 - x1) * (y2 - y1);
return intersect / (sum_area - intersect + 0.00000001);
}
}
float Utility::iou(std::vector<float> &box1, std::vector<float> &box2) {
float area1 = std::max((float)0.0, box1[2] - box1[0]) *
std::max((float)0.0, box1[3] - box1[1]);
float area2 = std::max((float)0.0, box2[2] - box2[0]) *
std::max((float)0.0, box2[3] - box2[1]);
// computing the sum_area
float sum_area = area1 + area2;
// find the each point of intersect rectangle
float x1 = std::max(box1[0], box2[0]);
float y1 = std::max(box1[1], box2[1]);
float x2 = std::min(box1[2], box2[2]);
float y2 = std::min(box1[3], box2[3]);
// judge if there is an intersect
if (y1 >= y2 || x1 >= x2) {
return 0.0;
} else {
float intersect = (x2 - x1) * (y2 - y1);
return intersect / (sum_area - intersect + 0.00000001);
}
}
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
ppstructure/docs/layout/layout.png

178.6 KB | W: | H:

ppstructure/docs/layout/layout.png

1.2 MB | W: | H:

ppstructure/docs/layout/layout.png
ppstructure/docs/layout/layout.png
ppstructure/docs/layout/layout.png
ppstructure/docs/layout/layout.png
  • 2-up
  • Swipe
  • Onion skin
...@@ -23,7 +23,7 @@ English | [简体中文](README_ch.md) ...@@ -23,7 +23,7 @@ English | [简体中文](README_ch.md)
## 1. Introduction ## 1. Introduction
Layout analysis refers to the regional division of documents in the form of pictures and the positioning of key areas, such as text, title, table, picture, etc. The layout analysis algorithm is based on the lightweight model PP-picodet of [PaddleDetection]( https://github.com/PaddlePaddle/PaddleDetection ) Layout analysis refers to the regional division of documents in the form of pictures and the positioning of key areas, such as text, title, table, picture, etc. The layout analysis algorithm is based on the lightweight model PP-picodet of [PaddleDetection]( https://github.com/PaddlePaddle/PaddleDetection ), including English layout analysis, Chinese layout analysis and table layout analysis models. English layout analysis models can detect document layout elements such as text, title, table, figure, list. Chinese layout analysis models can detect document layout elements such as text, figure, figure caption, table, table caption, header, footer, reference, and equation. Table layout analysis models can detect table regions.
<div align="center"> <div align="center">
<img src="../docs/layout/layout.png" width="800"> <img src="../docs/layout/layout.png" width="800">
...@@ -152,7 +152,7 @@ We provide CDLA(Chinese layout analysis), TableBank(Table layout analysis)etc. d ...@@ -152,7 +152,7 @@ We provide CDLA(Chinese layout analysis), TableBank(Table layout analysis)etc. d
| [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | For form detection (TRACKA) and form identification (TRACKB).Image types include historical data sets (beginning with cTDaR_t0, such as CTDAR_T00872.jpg) and modern data sets (beginning with cTDaR_t1, CTDAR_T10482.jpg). | | [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | For form detection (TRACKA) and form identification (TRACKB).Image types include historical data sets (beginning with cTDaR_t0, such as CTDAR_T00872.jpg) and modern data sets (beginning with cTDaR_t1, CTDAR_T10482.jpg). |
| [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | Data sets constructed by manually annotating figures or pages from publicly available annual reports, containing 5 categories:table, figure, natural image, logo, and signature. | | [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | Data sets constructed by manually annotating figures or pages from publicly available annual reports, containing 5 categories:table, figure, natural image, logo, and signature. |
| [TableBank](https://github.com/doc-analysis/TableBank) | For table detection and recognition of large datasets, including Word and Latex document formats | | [TableBank](https://github.com/doc-analysis/TableBank) | For table detection and recognition of large datasets, including Word and Latex document formats |
| [CDLA](https://github.com/buptlihang/CDLA) | Chinese document layout analysis data set, for Chinese literature (paper) scenarios, including 10 categories:Table, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation | | [CDLA](https://github.com/buptlihang/CDLA) | Chinese document layout analysis data set, for Chinese literature (paper) scenarios, including 10 categories:Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation |
| [DocBank](https://github.com/doc-analysis/DocBank) | Large-scale dataset (500K document pages) constructed using weakly supervised methods for document layout analysis, containing 12 categories:Author, Caption, Date, Equation, Figure, Footer, List, Paragraph, Reference, Section, Table, Title | | [DocBank](https://github.com/doc-analysis/DocBank) | Large-scale dataset (500K document pages) constructed using weakly supervised methods for document layout analysis, containing 12 categories:Author, Caption, Date, Equation, Figure, Footer, List, Paragraph, Reference, Section, Table, Title |
...@@ -175,7 +175,7 @@ If the test image is Chinese, the pre-trained model of Chinese CDLA dataset can ...@@ -175,7 +175,7 @@ If the test image is Chinese, the pre-trained model of Chinese CDLA dataset can
### 5.1. Train ### 5.1. Train
Train: Start training with the PaddleDetection [layout analysis profile](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/picodet/legacy_model/application/layout_analysis)
* Modify Profile * Modify Profile
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
## 1. 简介 ## 1. 简介
版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等。版面分析算法基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的轻量模型PP-PicoDet进行开发 版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等。版面分析算法基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的轻量模型PP-PicoDet进行开发,包含英文、中文、表格版面分析3类模型。其中,英文模型支持Text、Title、Tale、Figure、List5类区域的检测,中文模型支持Text、Title、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation10类区域的检测,表格版面分析支持Table区域的检测,版面分析效果如下图所示:
<div align="center"> <div align="center">
<img src="../docs/layout/layout.png" width="800"> <img src="../docs/layout/layout.png" width="800">
...@@ -152,7 +152,7 @@ json文件包含所有图像的标注,数据以字典嵌套的方式存放, ...@@ -152,7 +152,7 @@ json文件包含所有图像的标注,数据以字典嵌套的方式存放,
| ------------------------------------------------------------ | ------------------------------------------------------------ | | ------------------------------------------------------------ | ------------------------------------------------------------ |
| [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | 用于表格检测(TRACKA)和表格识别(TRACKB)。图片类型包含历史数据集(以cTDaR_t0开头,如cTDaR_t00872.jpg)和现代数据集(以cTDaR_t1开头,cTDaR_t10482.jpg)。 | | [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | 用于表格检测(TRACKA)和表格识别(TRACKB)。图片类型包含历史数据集(以cTDaR_t0开头,如cTDaR_t00872.jpg)和现代数据集(以cTDaR_t1开头,cTDaR_t10482.jpg)。 |
| [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | 手动注释公开的年度报告中的图形或页面而构建的数据集,包含5类:table, figure, natural image, logo, and signature | | [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | 手动注释公开的年度报告中的图形或页面而构建的数据集,包含5类:table, figure, natural image, logo, and signature |
| [CDLA](https://github.com/buptlihang/CDLA) | 中文文档版面分析数据集,面向中文文献类(论文)场景,包含10类:Table、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation | | [CDLA](https://github.com/buptlihang/CDLA) | 中文文档版面分析数据集,面向中文文献类(论文)场景,包含10类:Text、Title、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation |
| [TableBank](https://github.com/doc-analysis/TableBank) | 用于表格检测和识别大型数据集,包含Word和Latex2种文档格式 | | [TableBank](https://github.com/doc-analysis/TableBank) | 用于表格检测和识别大型数据集,包含Word和Latex2种文档格式 |
| [DocBank](https://github.com/doc-analysis/DocBank) | 使用弱监督方法构建的大规模数据集(500K文档页面),用于文档布局分析,包含12类:Author、Caption、Date、Equation、Figure、Footer、List、Paragraph、Reference、Section、Table、Title | | [DocBank](https://github.com/doc-analysis/DocBank) | 使用弱监督方法构建的大规模数据集(500K文档页面),用于文档布局分析,包含12类:Author、Caption、Date、Equation、Figure、Footer、List、Paragraph、Reference、Section、Table、Title |
...@@ -161,7 +161,7 @@ json文件包含所有图像的标注,数据以字典嵌套的方式存放, ...@@ -161,7 +161,7 @@ json文件包含所有图像的标注,数据以字典嵌套的方式存放,
提供了训练脚本、评估脚本和预测脚本,本节将以PubLayNet预训练模型为例进行讲解。 提供了训练脚本、评估脚本和预测脚本,本节将以PubLayNet预训练模型为例进行讲解。
如果不希望训练,直接体验后面的模型评估、预测、动转静、推理的流程,可以下载提供的预训练模型(PubLayNet数据集),并跳过本部分 如果不希望训练,直接体验后面的模型评估、预测、动转静、推理的流程,可以下载提供的预训练模型(PubLayNet数据集),并跳过5.1和5.2
``` ```
mkdir pretrained_model mkdir pretrained_model
...@@ -176,7 +176,7 @@ wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_ ...@@ -176,7 +176,7 @@ wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_
### 5.1. 启动训练 ### 5.1. 启动训练
开始训练: 使用PaddleDetection[版面分析配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/picodet/legacy_model/application/layout_analysis)启动训练
* 修改配置文件 * 修改配置文件
......
...@@ -255,8 +255,7 @@ def main(args): ...@@ -255,8 +255,7 @@ def main(args):
if args.recovery and all_res != []: if args.recovery and all_res != []:
try: try:
convert_info_docx(img, all_res, save_folder, img_name, convert_info_docx(img, all_res, save_folder, img_name)
args.save_pdf)
except Exception as ex: except Exception as ex:
logger.error("error in layout recovery image:{}, err msg: {}". logger.error("error in layout recovery image:{}, err msg: {}".
format(image_file, ex)) format(image_file, ex))
......
...@@ -82,8 +82,11 @@ Through layout analysis, we divided the image/PDF documents into regions, locate ...@@ -82,8 +82,11 @@ Through layout analysis, we divided the image/PDF documents into regions, locate
We can restore the test picture through the layout information, OCR detection and recognition structure, table information, and saved pictures. We can restore the test picture through the layout information, OCR detection and recognition structure, table information, and saved pictures.
The whl package is also provided for quick use, see [quickstart](../docs/quickstart_en.md) for details. The whl package is also provided for quick use, follow the above code, for more infomation please refer to [quickstart](../docs/quickstart_en.md) for details.
```bash
paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true --lang='en'
```
<a name="3.1"></a> <a name="3.1"></a>
### 3.1 Download models ### 3.1 Download models
......
...@@ -83,7 +83,16 @@ python3 -m pip install -r ppstructure/recovery/requirements.txt ...@@ -83,7 +83,16 @@ python3 -m pip install -r ppstructure/recovery/requirements.txt
我们通过版面信息、OCR检测和识别结构、表格信息、保存的图片,对测试图片进行恢复即可。 我们通过版面信息、OCR检测和识别结构、表格信息、保存的图片,对测试图片进行恢复即可。
提供如下代码实现版面恢复,也提供了whl包的形式方便快速使用,详见 [quickstart](../docs/quickstart.md) 提供如下代码实现版面恢复,也提供了whl包的形式方便快速使用,代码如下,更多信息详见 [quickstart](../docs/quickstart.md)
```bash
# 中文测试图
paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true
# 英文测试图
paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true --lang='en'
# pdf测试文件
paddleocr --image_dir=ppstructure/recovery/UnrealText.pdf --type=structure --recovery=true --lang='en'
```
<a name="3.1"></a> <a name="3.1"></a>
......
...@@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger ...@@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
def convert_info_docx(img, res, save_folder, img_name, save_pdf=False): def convert_info_docx(img, res, save_folder, img_name):
doc = Document() doc = Document()
doc.styles['Normal'].font.name = 'Times New Roman' doc.styles['Normal'].font.name = 'Times New Roman'
doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体') doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体')
...@@ -60,14 +60,9 @@ def convert_info_docx(img, res, save_folder, img_name, save_pdf=False): ...@@ -60,14 +60,9 @@ def convert_info_docx(img, res, save_folder, img_name, save_pdf=False):
elif region['type'].lower() == 'title': elif region['type'].lower() == 'title':
doc.add_heading(region['res'][0]['text']) doc.add_heading(region['res'][0]['text'])
elif region['type'].lower() == 'table': elif region['type'].lower() == 'table':
paragraph = doc.add_paragraph() parser = HtmlToDocx()
new_parser = HtmlToDocx() parser.table_style = 'TableGrid'
new_parser.table_style = 'TableGrid' parser.handle_table(region['res']['html'], doc)
table = new_parser.handle_table(html=region['res']['html'])
new_table = deepcopy(table)
new_table.alignment = WD_TABLE_ALIGNMENT.CENTER
paragraph.add_run().element.addnext(new_table._tbl)
else: else:
paragraph = doc.add_paragraph() paragraph = doc.add_paragraph()
paragraph_format = paragraph.paragraph_format paragraph_format = paragraph.paragraph_format
...@@ -82,13 +77,6 @@ def convert_info_docx(img, res, save_folder, img_name, save_pdf=False): ...@@ -82,13 +77,6 @@ def convert_info_docx(img, res, save_folder, img_name, save_pdf=False):
doc.save(docx_path) doc.save(docx_path)
logger.info('docx save to {}'.format(docx_path)) logger.info('docx save to {}'.format(docx_path))
# save to pdf
if save_pdf:
pdf_path = os.path.join(save_folder, '{}.pdf'.format(img_name))
from docx2pdf import convert
convert(docx_path, pdf_path)
logger.info('pdf save to {}'.format(pdf_path))
def sorted_layout_boxes(res, w): def sorted_layout_boxes(res, w):
""" """
......
python-docx python-docx
docx2pdf
PyMuPDF PyMuPDF
beautifulsoup4 beautifulsoup4
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -13,62 +12,59 @@ ...@@ -13,62 +12,59 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
This code is refer from:https://github.com/pqzx/html2docx/blob/8f6695a778c68befb302e48ac0ed5201ddbd4524/htmldocx/h2d.py This code is refer from: https://github.com/weizwx/html2docx/blob/master/htmldocx/h2d.py
""" """
import re, argparse
import io, os
import urllib.request
from urllib.parse import urlparse
from html.parser import HTMLParser
import docx, docx.table import re
import docx
from docx import Document from docx import Document
from docx.shared import RGBColor, Pt, Inches
from docx.enum.text import WD_COLOR, WD_ALIGN_PARAGRAPH
from docx.oxml import OxmlElement
from docx.oxml.ns import qn
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from html.parser import HTMLParser
# values in inches
INDENT = 0.25
LIST_INDENT = 0.5
MAX_INDENT = 5.5 # To stop indents going off the page
# Style to use with tables. By default no style is used. def get_table_rows(table_soup):
DEFAULT_TABLE_STYLE = None table_row_selectors = [
'table > tr', 'table > thead > tr', 'table > tbody > tr',
'table > tfoot > tr'
]
# If there's a header, body, footer or direct child tr tags, add row dimensions from there
return table_soup.select(', '.join(table_row_selectors), recursive=False)
# Style to use with paragraphs. By default no style is used.
DEFAULT_PARAGRAPH_STYLE = None
def get_table_columns(row):
# Get all columns for the specified row tag.
return row.find_all(['th', 'td'], recursive=False) if row else []
def get_filename_from_url(url):
return os.path.basename(urlparse(url).path)
def is_url(url): def get_table_dimensions(table_soup):
""" # Get rows for the table
Not to be used for actually validating a url, but in our use case we only rows = get_table_rows(table_soup)
care if it's a url or a file path, and they're pretty distinguishable # Table is either empty or has non-direct children between table and tr tags
""" # Thus the row dimensions and column dimensions are assumed to be 0
parts = urlparse(url)
return all([parts.scheme, parts.netloc, parts.path])
def fetch_image(url): cols = get_table_columns(rows[0]) if rows else []
""" # Add colspan calculation column number
Attempts to fetch an image from a url. col_count = 0
If successful returns a bytes object, else returns None for col in cols:
:return: colspan = col.attrs.get('colspan', 1)
""" col_count += int(colspan)
try:
with urllib.request.urlopen(url) as response: return rows, col_count
# security flaw?
return io.BytesIO(response.read())
except urllib.error.URLError: def get_cell_html(soup):
return None # Returns string of td element with opening and closing <td> tags removed
# Cannot use find_all as it only finds element tags and does not find text which
# is not inside an element
return ' '.join([str(i) for i in soup.contents])
def delete_paragraph(paragraph):
# https://github.com/python-openxml/python-docx/issues/33#issuecomment-77661907
p = paragraph._element
p.getparent().remove(p)
p._p = p._element = None
def remove_last_occurence(ls, x):
ls.pop(len(ls) - ls[::-1].index(x) - 1)
def remove_whitespace(string, leading=False, trailing=False): def remove_whitespace(string, leading=False, trailing=False):
"""Remove white space from a string. """Remove white space from a string.
...@@ -122,11 +118,6 @@ def remove_whitespace(string, leading=False, trailing=False): ...@@ -122,11 +118,6 @@ def remove_whitespace(string, leading=False, trailing=False):
# TODO need some way to get rid of extra spaces in e.g. text <span> </span> text # TODO need some way to get rid of extra spaces in e.g. text <span> </span> text
return re.sub(r'\s+', ' ', string) return re.sub(r'\s+', ' ', string)
def delete_paragraph(paragraph):
# https://github.com/python-openxml/python-docx/issues/33#issuecomment-77661907
p = paragraph._element
p.getparent().remove(p)
p._p = p._element = None
font_styles = { font_styles = {
'b': 'bold', 'b': 'bold',
...@@ -145,13 +136,8 @@ font_names = { ...@@ -145,13 +136,8 @@ font_names = {
'pre': 'Courier', 'pre': 'Courier',
} }
styles = {
'LIST_BULLET': 'List Bullet',
'LIST_NUMBER': 'List Number',
}
class HtmlToDocx(HTMLParser): class HtmlToDocx(HTMLParser):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.options = { self.options = {
...@@ -161,13 +147,11 @@ class HtmlToDocx(HTMLParser): ...@@ -161,13 +147,11 @@ class HtmlToDocx(HTMLParser):
'styles': True, 'styles': True,
} }
self.table_row_selectors = [ self.table_row_selectors = [
'table > tr', 'table > tr', 'table > thead > tr', 'table > tbody > tr',
'table > thead > tr',
'table > tbody > tr',
'table > tfoot > tr' 'table > tfoot > tr'
] ]
self.table_style = DEFAULT_TABLE_STYLE self.table_style = None
self.paragraph_style = DEFAULT_PARAGRAPH_STYLE self.paragraph_style = None
def set_initial_attrs(self, document=None): def set_initial_attrs(self, document=None):
self.tags = { self.tags = {
...@@ -178,9 +162,10 @@ class HtmlToDocx(HTMLParser): ...@@ -178,9 +162,10 @@ class HtmlToDocx(HTMLParser):
self.doc = document self.doc = document
else: else:
self.doc = Document() self.doc = Document()
self.bs = self.options['fix-html'] # whether or not to clean with BeautifulSoup self.bs = self.options[
'fix-html'] # whether or not to clean with BeautifulSoup
self.document = self.doc self.document = self.doc
self.include_tables = True #TODO add this option back in? self.include_tables = True #TODO add this option back in?
self.include_images = self.options['images'] self.include_images = self.options['images']
self.include_styles = self.options['styles'] self.include_styles = self.options['styles']
self.paragraph = None self.paragraph = None
...@@ -193,55 +178,52 @@ class HtmlToDocx(HTMLParser): ...@@ -193,55 +178,52 @@ class HtmlToDocx(HTMLParser):
self.table_style = other.table_style self.table_style = other.table_style
self.paragraph_style = other.paragraph_style self.paragraph_style = other.paragraph_style
def get_cell_html(self, soup): def ignore_nested_tables(self, tables_soup):
# Returns string of td element with opening and closing <td> tags removed """
# Cannot use find_all as it only finds element tags and does not find text which Returns array containing only the highest level tables
# is not inside an element Operates on the assumption that bs4 returns child elements immediately after
return ' '.join([str(i) for i in soup.contents]) the parent element in `find_all`. If this changes in the future, this method will need to be updated
:return:
def add_styles_to_paragraph(self, style): """
if 'text-align' in style: new_tables = []
align = style['text-align'] nest = 0
if align == 'center': for table in tables_soup:
self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.CENTER if nest:
elif align == 'right': nest -= 1
self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.RIGHT continue
elif align == 'justify': new_tables.append(table)
self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY nest = len(table.find_all('table'))
if 'margin-left' in style: return new_tables
margin = style['margin-left']
units = re.sub(r'[0-9]+', '', margin) def get_tables(self):
margin = int(float(re.sub(r'[a-z]+', '', margin))) if not hasattr(self, 'soup'):
if units == 'px': self.include_tables = False
self.paragraph.paragraph_format.left_indent = Inches(min(margin // 10 * INDENT, MAX_INDENT)) return
# TODO handle non px units # find other way to do it, or require this dependency?
self.tables = self.ignore_nested_tables(self.soup.find_all('table'))
def add_styles_to_run(self, style): self.table_no = 0
if 'color' in style:
if 'rgb' in style['color']: def run_process(self, html):
color = re.sub(r'[a-z()]+', '', style['color']) if self.bs and BeautifulSoup:
colors = [int(x) for x in color.split(',')] self.soup = BeautifulSoup(html, 'html.parser')
elif '#' in style['color']: html = str(self.soup)
color = style['color'].lstrip('#') if self.include_tables:
colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4)) self.get_tables()
else: self.feed(html)
colors = [0, 0, 0]
# TODO map colors to named colors (and extended colors...) def add_html_to_cell(self, html, cell):
# For now set color to black to prevent crashing if not isinstance(cell, docx.table._Cell):
self.run.font.color.rgb = RGBColor(*colors) raise ValueError('Second argument needs to be a %s' %
docx.table._Cell)
if 'background-color' in style: unwanted_paragraph = cell.paragraphs[0]
if 'rgb' in style['background-color']: if unwanted_paragraph.text == "":
color = color = re.sub(r'[a-z()]+', '', style['background-color']) delete_paragraph(unwanted_paragraph)
colors = [int(x) for x in color.split(',')] self.set_initial_attrs(cell)
elif '#' in style['background-color']: self.run_process(html)
color = style['background-color'].lstrip('#') # cells must end with a paragraph or will get message about corrupt file
colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4)) # https://stackoverflow.com/a/29287121
else: if not self.doc.paragraphs:
colors = [0, 0, 0] self.doc.add_paragraph('')
# TODO map colors to named colors (and extended colors...)
# For now set color to black to prevent crashing
self.run.font.highlight_color = WD_COLOR.GRAY_25 #TODO: map colors
def apply_paragraph_style(self, style=None): def apply_paragraph_style(self, style=None):
try: try:
...@@ -250,69 +232,10 @@ class HtmlToDocx(HTMLParser): ...@@ -250,69 +232,10 @@ class HtmlToDocx(HTMLParser):
elif self.paragraph_style: elif self.paragraph_style:
self.paragraph.style = self.paragraph_style self.paragraph.style = self.paragraph_style
except KeyError as e: except KeyError as e:
raise ValueError(f"Unable to apply style {self.paragraph_style}.") from e raise ValueError(
f"Unable to apply style {self.paragraph_style}.") from e
def parse_dict_string(self, string, separator=';'):
new_string = string.replace(" ", '').split(separator)
string_dict = dict([x.split(':') for x in new_string if ':' in x])
return string_dict
def handle_li(self):
# check list stack to determine style and depth
list_depth = len(self.tags['list'])
if list_depth:
list_type = self.tags['list'][-1]
else:
list_type = 'ul' # assign unordered if no tag
if list_type == 'ol': def handle_table(self, html, doc):
list_style = styles['LIST_NUMBER']
else:
list_style = styles['LIST_BULLET']
self.paragraph = self.doc.add_paragraph(style=list_style)
self.paragraph.paragraph_format.left_indent = Inches(min(list_depth * LIST_INDENT, MAX_INDENT))
self.paragraph.paragraph_format.line_spacing = 1
def add_image_to_cell(self, cell, image):
# python-docx doesn't have method yet for adding images to table cells. For now we use this
paragraph = cell.add_paragraph()
run = paragraph.add_run()
run.add_picture(image)
def handle_img(self, current_attrs):
if not self.include_images:
self.skip = True
self.skip_tag = 'img'
return
src = current_attrs['src']
# fetch image
src_is_url = is_url(src)
if src_is_url:
try:
image = fetch_image(src)
except urllib.error.URLError:
image = None
else:
image = src
# add image to doc
if image:
try:
if isinstance(self.doc, docx.document.Document):
self.doc.add_picture(image)
else:
self.add_image_to_cell(self.doc, image)
except FileNotFoundError:
image = None
if not image:
if src_is_url:
self.doc.add_paragraph("<image: %s>" % src)
else:
# avoid exposing filepaths in document
self.doc.add_paragraph("<image: %s>" % get_filename_from_url(src))
def handle_table(self, html):
""" """
To handle nested tables, we will parse tables manually as follows: To handle nested tables, we will parse tables manually as follows:
Get table soup Get table soup
...@@ -320,194 +243,42 @@ class HtmlToDocx(HTMLParser): ...@@ -320,194 +243,42 @@ class HtmlToDocx(HTMLParser):
Iterate over soup and fill docx table with new instances of this parser Iterate over soup and fill docx table with new instances of this parser
Tell HTMLParser to ignore any tags until the corresponding closing table tag Tell HTMLParser to ignore any tags until the corresponding closing table tag
""" """
doc = Document()
table_soup = BeautifulSoup(html, 'html.parser') table_soup = BeautifulSoup(html, 'html.parser')
rows, cols_len = self.get_table_dimensions(table_soup) rows, cols_len = get_table_dimensions(table_soup)
table = doc.add_table(len(rows), cols_len) table = doc.add_table(len(rows), cols_len)
table.style = doc.styles['Table Grid'] table.style = doc.styles['Table Grid']
cell_row = 0 cell_row = 0
for index, row in enumerate(rows): for index, row in enumerate(rows):
cols = self.get_table_columns(row) cols = get_table_columns(row)
cell_col = 0 cell_col = 0
for col in cols: for col in cols:
colspan = int(col.attrs.get('colspan', 1)) colspan = int(col.attrs.get('colspan', 1))
rowspan = int(col.attrs.get('rowspan', 1)) rowspan = int(col.attrs.get('rowspan', 1))
cell_html = self.get_cell_html(col) cell_html = get_cell_html(col)
if col.name == 'th': if col.name == 'th':
cell_html = "<b>%s</b>" % cell_html cell_html = "<b>%s</b>" % cell_html
docx_cell = table.cell(cell_row, cell_col) docx_cell = table.cell(cell_row, cell_col)
while docx_cell.text != '': # Skip the merged cell while docx_cell.text != '': # Skip the merged cell
cell_col += 1 cell_col += 1
docx_cell = table.cell(cell_row, cell_col) docx_cell = table.cell(cell_row, cell_col)
cell_to_merge = table.cell(cell_row + rowspan - 1, cell_col + colspan - 1) cell_to_merge = table.cell(cell_row + rowspan - 1,
cell_col + colspan - 1)
if docx_cell != cell_to_merge: if docx_cell != cell_to_merge:
docx_cell.merge(cell_to_merge) docx_cell.merge(cell_to_merge)
child_parser = HtmlToDocx() child_parser = HtmlToDocx()
child_parser.copy_settings_from(self) child_parser.copy_settings_from(self)
child_parser.add_html_to_cell(cell_html or ' ', docx_cell)
child_parser.add_html_to_cell(cell_html or ' ', docx_cell) # occupy the position
cell_col += colspan cell_col += colspan
cell_row += 1 cell_row += 1
# skip all tags until corresponding closing tag
self.instances_to_skip = len(table_soup.find_all('table'))
self.skip_tag = 'table'
self.skip = True
self.table = None
return table
def handle_link(self, href, text):
# Link requires a relationship
is_external = href.startswith('http')
rel_id = self.paragraph.part.relate_to(
href,
docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK,
is_external=True # don't support anchor links for this library yet
)
# Create the w:hyperlink tag and add needed values
hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
hyperlink.set(docx.oxml.shared.qn('r:id'), rel_id)
# Create sub-run
subrun = self.paragraph.add_run()
rPr = docx.oxml.shared.OxmlElement('w:rPr')
# add default color
c = docx.oxml.shared.OxmlElement('w:color')
c.set(docx.oxml.shared.qn('w:val'), "0000EE")
rPr.append(c)
# add underline
u = docx.oxml.shared.OxmlElement('w:u')
u.set(docx.oxml.shared.qn('w:val'), 'single')
rPr.append(u)
subrun._r.append(rPr)
subrun._r.text = text
# Add subrun to hyperlink
hyperlink.append(subrun._r)
# Add hyperlink to run
self.paragraph._p.append(hyperlink)
def handle_starttag(self, tag, attrs):
if self.skip:
return
if tag == 'head':
self.skip = True
self.skip_tag = tag
self.instances_to_skip = 0
return
elif tag == 'body':
return
current_attrs = dict(attrs)
if tag == 'span':
self.tags['span'].append(current_attrs)
return
elif tag == 'ol' or tag == 'ul':
self.tags['list'].append(tag)
return # don't apply styles for now
elif tag == 'br':
self.run.add_break()
return
self.tags[tag] = current_attrs
if tag in ['p', 'pre']:
self.paragraph = self.doc.add_paragraph()
self.apply_paragraph_style()
elif tag == 'li':
self.handle_li()
elif tag == "hr":
# This implementation was taken from:
# https://github.com/python-openxml/python-docx/issues/105#issuecomment-62806373
self.paragraph = self.doc.add_paragraph()
pPr = self.paragraph._p.get_or_add_pPr()
pBdr = OxmlElement('w:pBdr')
pPr.insert_element_before(pBdr,
'w:shd', 'w:tabs', 'w:suppressAutoHyphens', 'w:kinsoku', 'w:wordWrap',
'w:overflowPunct', 'w:topLinePunct', 'w:autoSpaceDE', 'w:autoSpaceDN',
'w:bidi', 'w:adjustRightInd', 'w:snapToGrid', 'w:spacing', 'w:ind',
'w:contextualSpacing', 'w:mirrorIndents', 'w:suppressOverlap', 'w:jc',
'w:textDirection', 'w:textAlignment', 'w:textboxTightWrap',
'w:outlineLvl', 'w:divId', 'w:cnfStyle', 'w:rPr', 'w:sectPr',
'w:pPrChange'
)
bottom = OxmlElement('w:bottom')
bottom.set(qn('w:val'), 'single')
bottom.set(qn('w:sz'), '6')
bottom.set(qn('w:space'), '1')
bottom.set(qn('w:color'), 'auto')
pBdr.append(bottom)
elif re.match('h[1-9]', tag):
if isinstance(self.doc, docx.document.Document):
h_size = int(tag[1])
self.paragraph = self.doc.add_heading(level=min(h_size, 9))
else:
self.paragraph = self.doc.add_paragraph()
elif tag == 'img':
self.handle_img(current_attrs)
return
elif tag == 'table':
self.handle_table()
return
# set new run reference point in case of leading line breaks doc.save('1.docx')
if tag in ['p', 'li', 'pre']:
self.run = self.paragraph.add_run()
# add style
if not self.include_styles:
return
if 'style' in current_attrs and self.paragraph:
style = self.parse_dict_string(current_attrs['style'])
self.add_styles_to_paragraph(style)
def handle_endtag(self, tag):
if self.skip:
if not tag == self.skip_tag:
return
if self.instances_to_skip > 0:
self.instances_to_skip -= 1
return
self.skip = False
self.skip_tag = None
self.paragraph = None
if tag == 'span':
if self.tags['span']:
self.tags['span'].pop()
return
elif tag == 'ol' or tag == 'ul':
remove_last_occurence(self.tags['list'], tag)
return
elif tag == 'table':
self.table_no += 1
self.table = None
self.doc = self.document
self.paragraph = None
if tag in self.tags:
self.tags.pop(tag)
# maybe set relevant reference to None?
def handle_data(self, data): def handle_data(self, data):
if self.skip: if self.skip:
...@@ -546,87 +317,3 @@ class HtmlToDocx(HTMLParser): ...@@ -546,87 +317,3 @@ class HtmlToDocx(HTMLParser):
if tag in font_names: if tag in font_names:
font_name = font_names[tag] font_name = font_names[tag]
self.run.font.name = font_name self.run.font.name = font_name
def ignore_nested_tables(self, tables_soup):
"""
Returns array containing only the highest level tables
Operates on the assumption that bs4 returns child elements immediately after
the parent element in `find_all`. If this changes in the future, this method will need to be updated
:return:
"""
new_tables = []
nest = 0
for table in tables_soup:
if nest:
nest -= 1
continue
new_tables.append(table)
nest = len(table.find_all('table'))
return new_tables
def get_table_rows(self, table_soup):
# If there's a header, body, footer or direct child tr tags, add row dimensions from there
return table_soup.select(', '.join(self.table_row_selectors), recursive=False)
def get_table_columns(self, row):
# Get all columns for the specified row tag.
return row.find_all(['th', 'td'], recursive=False) if row else []
def get_table_dimensions(self, table_soup):
# Get rows for the table
rows = self.get_table_rows(table_soup)
# Table is either empty or has non-direct children between table and tr tags
# Thus the row dimensions and column dimensions are assumed to be 0
cols = self.get_table_columns(rows[0]) if rows else []
# Add colspan calculation column number
col_count = 0
for col in cols:
colspan = col.attrs.get('colspan', 1)
col_count += int(colspan)
# return len(rows), col_count
return rows, col_count
def get_tables(self):
if not hasattr(self, 'soup'):
self.include_tables = False
return
# find other way to do it, or require this dependency?
self.tables = self.ignore_nested_tables(self.soup.find_all('table'))
self.table_no = 0
def run_process(self, html):
if self.bs and BeautifulSoup:
self.soup = BeautifulSoup(html, 'html.parser')
html = str(self.soup)
if self.include_tables:
self.get_tables()
self.feed(html)
def add_html_to_document(self, html, document):
if not isinstance(html, str):
raise ValueError('First argument needs to be a %s' % str)
elif not isinstance(document, docx.document.Document) and not isinstance(document, docx.table._Cell):
raise ValueError('Second argument needs to be a %s' % docx.document.Document)
self.set_initial_attrs(document)
self.run_process(html)
def add_html_to_cell(self, html, cell):
self.set_initial_attrs(cell)
self.run_process(html)
def parse_html_file(self, filename_html, filename_docx=None):
with open(filename_html, 'r') as infile:
html = infile.read()
self.set_initial_attrs()
self.run_process(html)
if not filename_docx:
path, filename = os.path.split(filename_html)
filename_docx = '%s/new_docx_file_%s' % (path, filename)
self.doc.save('%s.docx' % filename_docx)
def parse_html_string(self, html):
self.set_initial_attrs()
self.run_process(html)
return self.doc
\ No newline at end of file
...@@ -92,11 +92,6 @@ def init_args(): ...@@ -92,11 +92,6 @@ def init_args():
type=str2bool, type=str2bool,
default=False, default=False,
help='Whether to enable layout of recovery') help='Whether to enable layout of recovery')
parser.add_argument(
"--save_pdf",
type=str2bool,
default=False,
help='Whether to save pdf file')
return parser return parser
...@@ -110,7 +105,38 @@ def draw_structure_result(image, result, font_path): ...@@ -110,7 +105,38 @@ def draw_structure_result(image, result, font_path):
if isinstance(image, np.ndarray): if isinstance(image, np.ndarray):
image = Image.fromarray(image) image = Image.fromarray(image)
boxes, txts, scores = [], [], [] boxes, txts, scores = [], [], []
img_layout = image.copy()
draw_layout = ImageDraw.Draw(img_layout)
text_color = (255, 255, 255)
text_background_color = (80, 127, 255)
catid2color = {}
font_size = 15
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for region in result: for region in result:
if region['type'] not in catid2color:
box_color = (random.randint(0, 255), random.randint(0, 255),
random.randint(0, 255))
catid2color[region['type']] = box_color
else:
box_color = catid2color[region['type']]
box_layout = region['bbox']
draw_layout.rectangle(
[(box_layout[0], box_layout[1]), (box_layout[2], box_layout[3])],
outline=box_color,
width=3)
text_w, text_h = font.getsize(region['type'])
draw_layout.rectangle(
[(box_layout[0], box_layout[1]),
(box_layout[0] + text_w, box_layout[1] + text_h)],
fill=text_background_color)
draw_layout.text(
(box_layout[0], box_layout[1]),
region['type'],
fill=text_color,
font=font)
if region['type'] == 'table': if region['type'] == 'table':
pass pass
else: else:
...@@ -118,6 +144,7 @@ def draw_structure_result(image, result, font_path): ...@@ -118,6 +144,7 @@ def draw_structure_result(image, result, font_path):
boxes.append(np.array(text_result['text_region'])) boxes.append(np.array(text_result['text_region']))
txts.append(text_result['text']) txts.append(text_result['text'])
scores.append(text_result['confidence']) scores.append(text_result['confidence'])
im_show = draw_ocr_box_txt( im_show = draw_ocr_box_txt(
image, boxes, txts, scores, font_path=font_path, drop_score=0) img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
return im_show return im_show
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册