提交 300452c7 编写于 作者: L LDOUBLEV

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

......@@ -70,7 +70,7 @@ cmake安装完后后系统里会有一个cmake-gui程序,打开cmake-gui,在
* cpu版本,仅需考虑OPENCV_DIR、OpenCV_DIR、PADDLE_LIB三个参数
- OPENCV_DIR:填写opencv lib文件夹所在位置
- OpenCV_DIR:同填写opencv lib文件夹所在位
- OpenCV_DIR:同填写opencv lib文件夹所在位
- PADDLE_LIB:paddle_inference文件夹所在位置
* GPU版本,在cpu版本的基础上,还需填写以下变量
......@@ -78,7 +78,7 @@ CUDA_LIB、CUDNN_LIB、TENSORRT_DIR、WITH_GPU、WITH_TENSORRT
- CUDA_LIB: CUDA地址,如 `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\lib\x64`
- CUDNN_LIB: 和CUDA_LIB一致
- TENSORRT_DIR:TRT下载后解压缩的位置
- TENSORRT_DIR:TRT下载后解压缩的位置,如 `D:\TensorRT-8.0.1.6`
- WITH_GPU: 打钩
- WITH_TENSORRT:打勾
......@@ -110,10 +110,11 @@ CUDA_LIB、CUDNN_LIB、TENSORRT_DIR、WITH_GPU、WITH_TENSORRT
运行之前,将下面文件拷贝到`build/Release/`文件夹下
1. `paddle_inference/paddle/lib/paddle_inference.dll`
2. `opencv/build/x64/vc15/bin/opencv_world455.dll`
3. 如果使用openblas版本的预测库还需要拷贝 `paddle_inference/third_party/install/openblas/lib/openblas.dll`
### Step4: 预测
上述`Visual Studio 2019`编译产出的可执行文件在`out\build\x64-Release\Release`目录下,打开`cmd`,并切换到`D:\projects\cpp\PaddleOCR\deploy\cpp_infer\`:
上述`Visual Studio 2019`编译产出的可执行文件在`build/Release/`目录下,打开`cmd`,并切换到`D:\projects\cpp\PaddleOCR\deploy\cpp_infer\`:
```
cd /d D:\projects\cpp\PaddleOCR\deploy\cpp_infer
......@@ -128,7 +129,7 @@ CHCP 65001
```
识别结果如下
![result](imgs/result.png)
![result](imgs/result.jpg)
## FAQ
......
......@@ -46,6 +46,8 @@ DECLARE_int32(cls_batch_num);
DECLARE_string(rec_model_dir);
DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path);
DECLARE_int32(rec_img_h);
DECLARE_int32(rec_img_w);
// forward related
DECLARE_bool(det);
DECLARE_bool(rec);
......
......@@ -45,7 +45,8 @@ public:
const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt,
const std::string &precision,
const int &rec_batch_num) {
const int &rec_batch_num, const int &rec_img_h,
const int &rec_img_w) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
......@@ -54,6 +55,10 @@ public:
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
this->rec_batch_num_ = rec_batch_num;
this->rec_img_h_ = rec_img_h;
this->rec_img_w_ = rec_img_w;
std::vector<int> rec_image_shape = {3, rec_img_h, rec_img_w};
this->rec_image_shape_ = rec_image_shape;
this->label_list_ = Utility::ReadDict(label_path);
this->label_list_.insert(this->label_list_.begin(),
......@@ -86,7 +91,9 @@ private:
bool use_tensorrt_ = false;
std::string precision_ = "fp32";
int rec_batch_num_ = 6;
int rec_img_h_ = 32;
int rec_img_w_ = 320;
std::vector<int> rec_image_shape_ = {3, rec_img_h_, rec_img_w_};
// pre-process
CrnnResizeImg resize_op_;
Normalize normalize_op_;
......
......@@ -39,10 +39,10 @@ using namespace paddle_infer;
namespace PaddleOCR {
class PaddleOCR {
class PPOCR {
public:
explicit PaddleOCR();
~PaddleOCR();
explicit PPOCR();
~PPOCR();
std::vector<std::vector<OCRPredictResult>>
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
bool rec = true, bool cls = true);
......
......@@ -65,6 +65,8 @@ public:
static bool PathExists(const std::string &path);
static void CreateDir(const std::string &path);
static void print_result(const std::vector<OCRPredictResult> &ocr_result);
};
......
......@@ -323,6 +323,8 @@ More parameters are as follows,
|rec_model_dir|string|-|Address of recognition inference model|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
|rec_batch_num|int|6|batch size of recognition|
|rec_img_h|int|32|image height of recognition|
|rec_img_w|int|320|image width of recognition|
* Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `rec_char_dict_path` and `rec_model_dir`.
......
......@@ -336,6 +336,8 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|rec_model_dir|string|-|识别模型inference model地址|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|rec_batch_num|int|6|识别模型batchsize|
|rec_img_h|int|32|识别模型输入图像高度|
|rec_img_w|int|320|识别模型输入图像宽度|
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
......
......@@ -47,6 +47,8 @@ DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
"Path of dictionary.");
DEFINE_int32(rec_img_h, 32, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width");
// ocr forward related
DEFINE_bool(det, true, "Whether use det in forward.");
......
......@@ -69,7 +69,7 @@ int main(int argc, char **argv) {
cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl;
PaddleOCR::PaddleOCR ocr = PaddleOCR::PaddleOCR();
PPOCR ocr = PPOCR();
std::vector<std::vector<OCRPredictResult>> ocr_results =
ocr.ocr(cv_all_img_names, FLAGS_det, FLAGS_rec, FLAGS_cls);
......
......@@ -39,7 +39,9 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
int batch_num = end_img_no - beg_img_no;
float max_wh_ratio = 0;
int imgH = this->rec_image_shape_[1];
int imgW = this->rec_image_shape_[2];
float max_wh_ratio = imgW * 1.0 / imgH;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
int h = img_list[indices[ino]].rows;
int w = img_list[indices[ino]].cols;
......@@ -47,28 +49,28 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
max_wh_ratio = max(max_wh_ratio, wh_ratio);
}
int batch_width = 0;
int batch_width = imgW;
std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
cv::Mat srcimg;
img_list[indices[ino]].copyTo(srcimg);
cv::Mat resize_img;
this->resize_op_.Run(srcimg, resize_img, max_wh_ratio,
this->use_tensorrt_);
this->use_tensorrt_, this->rec_image_shape_);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
norm_img_batch.push_back(resize_img);
batch_width = max(resize_img.cols, batch_width);
}
std::vector<float> input(batch_num * 3 * 32 * batch_width, 0.0f);
std::vector<float> input(batch_num * 3 * imgH * batch_width, 0.0f);
this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// Inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({batch_num, 3, 32, batch_width});
input_t->Reshape({batch_num, 3, imgH, batch_width});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
......@@ -142,13 +144,14 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
precision = paddle_infer::Config::Precision::kInt8;
}
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
int imgH = this->rec_image_shape_[1];
int imgW = this->rec_image_shape_[2];
std::map<std::string, std::vector<int>> min_input_shape = {
{"x", {1, 3, 32, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}};
{"x", {1, 3, imgH, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {1, 3, 32, 2000}}, {"lstm_0.tmp_0", {1000, 1, 96}}};
{"x", {1, 3, imgH, 2000}}, {"lstm_0.tmp_0", {1000, 1, 96}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"x", {1, 3, 32, 320}}, {"lstm_0.tmp_0", {25, 1, 96}}};
{"x", {1, 3, imgH, imgW}}, {"lstm_0.tmp_0", {25, 1, 96}}};
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
......
......@@ -17,11 +17,9 @@
#include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>
namespace PaddleOCR {
PaddleOCR::PaddleOCR() {
PPOCR::PPOCR() {
if (FLAGS_det) {
this->detector_ = new DBDetector(
FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
......@@ -41,12 +39,13 @@ PaddleOCR::PaddleOCR() {
this->recognizer_ = new CRNNRecognizer(
FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_rec_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num);
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num,
FLAGS_rec_img_h, FLAGS_rec_img_w);
}
};
void PaddleOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
void PPOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
......@@ -63,9 +62,9 @@ void PaddleOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
times[2] += det_times[2];
}
void PaddleOCR::rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
void PPOCR::rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times;
......@@ -80,9 +79,9 @@ void PaddleOCR::rec(std::vector<cv::Mat> img_list,
times[2] += rec_times[2];
}
void PaddleOCR::cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
void PPOCR::cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<int> cls_labels(img_list.size(), 0);
std::vector<float> cls_scores(img_list.size(), 0);
std::vector<double> cls_times;
......@@ -98,8 +97,8 @@ void PaddleOCR::cls(std::vector<cv::Mat> img_list,
}
std::vector<std::vector<OCRPredictResult>>
PaddleOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
bool cls) {
PPOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
bool cls) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
......@@ -139,7 +138,7 @@ PaddleOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
}
} else {
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
mkdir(FLAGS_output.c_str(), 0777);
Utility::CreateDir(FLAGS_output);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
......@@ -188,9 +187,8 @@ PaddleOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
return ocr_results;
} // namespace PaddleOCR
void PaddleOCR::log(std::vector<double> &det_times,
std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num) {
void PPOCR::log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num) {
if (det_times[0] + det_times[1] + det_times[2] > 0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
......@@ -212,7 +210,7 @@ void PaddleOCR::log(std::vector<double> &det_times,
autolog_cls.report();
}
}
PaddleOCR::~PaddleOCR() {
PPOCR::~PPOCR() {
if (this->detector_ != nullptr) {
delete this->detector_;
}
......
......@@ -41,16 +41,17 @@ void Permute::Run(const cv::Mat *im, float *data) {
}
void PermuteBatch::Run(const std::vector<cv::Mat> imgs, float *data) {
for (int j = 0; j < imgs.size(); j ++){
int rh = imgs[j].rows;
int rw = imgs[j].cols;
int rc = imgs[j].channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i);
}
for (int j = 0; j < imgs.size(); j++) {
int rh = imgs[j].rows;
int rw = imgs[j].cols;
int rc = imgs[j].channels();
for (int i = 0; i < rc; ++i) {
cv::extractChannel(
imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i);
}
}
}
void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
const std::vector<float> &scale, const bool is_scale) {
double e = 1.0;
......@@ -101,8 +102,8 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
imgW = int(32 * wh_ratio);
imgW = int(imgH * wh_ratio);
float ratio = float(img.cols) / float(img.rows);
int resize_w, resize_h;
......@@ -111,7 +112,7 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
......
......@@ -16,10 +16,15 @@
#include <include/utility.h>
#include <iostream>
#include <ostream>
#include <sys/stat.h>
#include <sys/types.h>
#include <vector>
#ifdef _WIN32
#include <direct.h>
#else
#include <sys/stat.h>
#endif
namespace PaddleOCR {
std::vector<std::string> Utility::ReadDict(const std::string &path) {
......@@ -206,6 +211,14 @@ bool Utility::PathExists(const std::string &path) {
#endif // !_WIN32
}
void Utility::CreateDir(const std::string &path) {
#ifdef _WIN32
_mkdir(path.c_str());
#else
mkdir(path.c_str(), 0777);
#endif // !_WIN32
}
void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
for (int i = 0; i < ocr_result.size(); i++) {
std::cout << i << "\t";
......
......@@ -36,7 +36,6 @@ PaddleOCR operating environment and Paddle Serving operating environment are nee
1. Please prepare PaddleOCR operating environment reference [link](../../doc/doc_ch/installation.md).
Download the corresponding paddlepaddle whl package according to the environment, it is recommended to install version 2.2.2.
2. The steps of PaddleServing operating environment prepare are as follows:
......@@ -194,6 +193,52 @@ The recognition model is the same.
2021-05-13 03:42:36,979 chl2(In: ['rec'], Out: ['@DAGExecutor']) size[0/0]
```
## C++ Serving
Service deployment based on python obviously has the advantage of convenient secondary development. However, the real application often needs to pursue better performance. PaddleServing also provides a more performant C++ deployment version.
The C++ service deployment is the same as python in the environment setup and data preparation stages, the difference is when the service is started and the client sends requests.
| Language | Speed ​​| Secondary development | Do you need to compile |
|-----|-----|---------|------------|
| C++ | fast | Slightly difficult | Single model prediction does not need to be compiled, multi-model concatenation needs to be compiled |
| python | general | easy | single-model/multi-model no compilation required |
1. Compile Serving
To improve predictive performance, C++ services also provide multiple model concatenation services. Unlike Python Pipeline services, multiple model concatenation requires the pre - and post-model processing code to be written on the server side, so local recompilation is required to generate serving. Specific may refer to the official document: [how to compile Serving](https://github.com/PaddlePaddle/Serving/blob/v0.8.3/doc/Compile_EN.md)
2. Run the following command to start the service.
```
# Start the service and save the running log in log.txt
python3 -m paddle_serving_server.serve --model ppocrv2_det_serving ppocrv2_rec_serving --op GeneralDetectionOp GeneralInferOp --port 9293 &>log.txt &
```
After the service is successfully started, a log similar to the following will be printed in log.txt
![](./imgs/start_server.png)
3. Send service request
Due to the need for pre and post-processing in the C++Server part, in order to speed up the input to the C++Server is only the base64 encoded string of the picture, it needs to be manually modified
Change the feed_type field and shape field in ppocrv2_det_client/serving_client_conf.prototxt to the following:
```
feed_var {
name: "x"
alias_name: "x"
is_lod_tensor: false
feed_type: 20
shape: 1
}
```
start the client:
```
python3 ocr_cpp_client.py ppocrv2_det_client ppocrv2_rec_client
```
After successfully running, the predicted result of the model will be printed in the cmd window. An example of the result is:
![](./imgs/results.png)
## WINDOWS Users
Windows does not support Pipeline Serving, if we want to lauch paddle serving on Windows, we should use Web Service, for more infomation please refer to [Paddle Serving for Windows Users](https://github.com/PaddlePaddle/Serving/blob/develop/doc/Windows_Tutorial_EN.md)
......
......@@ -6,6 +6,7 @@ PaddleOCR提供2种服务部署方式:
- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",使用方法参考[文档](../../deploy/hubserving/readme.md)
- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",按照本教程使用。
# 基于PaddleServing的服务部署
本文档将介绍如何使用[PaddleServing](https://github.com/PaddlePaddle/Serving/blob/develop/README_CN.md)工具部署PP-OCR动态图模型的pipeline在线服务。
......@@ -17,6 +18,8 @@ PaddleOCR提供2种服务部署方式:
更多有关PaddleServing服务化部署框架介绍和使用教程参考[文档](https://github.com/PaddlePaddle/Serving/blob/develop/README_CN.md)
AIStudio演示案例可参考 [基于PaddleServing的OCR服务化部署实战](https://aistudio.baidu.com/aistudio/projectdetail/3630726)
## 目录
- [环境准备](#环境准备)
- [模型转换](#模型转换)
......@@ -30,7 +33,6 @@ PaddleOCR提供2种服务部署方式:
需要准备PaddleOCR的运行环境和Paddle Serving的运行环境。
- 准备PaddleOCR的运行环境[链接](../../doc/doc_ch/installation.md)
根据环境下载对应的paddlepaddle whl包,推荐安装2.2.2版本
- 准备PaddleServing的运行环境,步骤如下
......@@ -106,7 +108,7 @@ python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_rec_infer/ \
1. 下载PaddleOCR代码,若已下载可跳过此步骤
```
git clone https://github.com/PaddlePaddle/PaddleOCR
# 进入到工作目录
cd PaddleOCR/deploy/pdserving/
```
......@@ -132,7 +134,7 @@ python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_rec_infer/ \
python3 pipeline_http_client.py
```
成功运行后,模型预测的结果会打印在cmd窗口中,结果示例为:
![](./imgs/results.png)
![](./imgs/pipeline_result.png)
调整 config.yml 中的并发个数获得最大的QPS, 一般检测和识别的并发数为2:1
```
......@@ -187,6 +189,73 @@ python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_rec_infer/ \
2021-05-13 03:42:36,979 chl2(In: ['rec'], Out: ['@DAGExecutor']) size[0/0]
```
<a name="C++"></a>
## Paddle Serving C++ 部署
基于python的服务部署,显然具有二次开发便捷的优势,然而真正落地应用,往往需要追求更优的性能。PaddleServing 也提供了性能更优的C++部署版本。
C++ 服务部署在环境搭建和数据准备阶段与 python 相同,区别在于启动服务和客户端发送请求时不同。
| 语言 | 速度 | 二次开发 | 是否需要编译 |
|-----|-----|---------|------------|
| C++ | 很快 | 略有难度 | 单模型预测无需编译,多模型串联需要编译 |
| python | 一般 | 容易 | 单模型/多模型 均无需编译|
1. 准备 Serving 环境
为了提高预测性能,C++ 服务同样提供了多模型串联服务。与python pipeline服务不同,多模型串联的过程中需要将模型前后处理代码写在服务端,因此需要在本地重新编译生成serving。
首先需要下载Serving代码库, 把OCR文本检测预处理相关代码替换到Serving库中
```
git clone https://github.com/PaddlePaddle/Serving
cp -rf general_detection_op.cpp Serving/core/general-server/op
```
具体可参考官方文档:[如何编译Serving](https://github.com/PaddlePaddle/Serving/blob/v0.8.3/doc/Compile_CN.md),注意需要开启 WITH_OPENCV 选项。
完成编译后,注意要安装编译出的三个whl包,并设置SERVING_BIN环境变量。
2. 启动服务可运行如下命令:
一个服务启动两个模型串联,只需要在--model后依次按顺序传入模型文件夹的相对路径,且需要在--op后依次传入自定义C++OP类名称:
```
# 启动服务,运行日志保存在log.txt
python3 -m paddle_serving_server.serve --model ppocrv2_det_serving ppocrv2_rec_serving --op GeneralDetectionOp GeneralInferOp --port 9293 &>log.txt &
```
成功启动服务后,log.txt中会打印类似如下日志
![](./imgs/start_server.png)
3. 发送服务请求:
由于需要在C++Server部分进行前后处理,为了加速传入C++Server的仅仅是图片的base64编码的字符串,故需要手动修改
ppocrv2_det_client/serving_client_conf.prototxt 中 feed_type 字段 和 shape 字段,修改成如下内容:
```
feed_var {
name: "x"
alias_name: "x"
is_lod_tensor: false
feed_type: 20
shape: 1
}
```
启动客户端
```
python3 ocr_cpp_client.py ppocrv2_det_client ppocrv2_rec_client
```
成功运行后,模型预测的结果会打印在cmd窗口中,结果示例为:
![](./imgs/results.png)
在浏览器中输入服务器 ip:端口号,可以看到当前服务的实时QPS。(端口号范围需要是8000-9000)
在200张真实图片上测试,把检测长边限制为960。T4 GPU 上 QPS 峰值可达到51左右,约为pipeline的 2.12 倍。
![](./imgs/c++_qps.png)
<a name="Windows用户"></a>
## Windows用户
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_detection_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
/*
#include "opencv2/imgcodecs/legacy/constants_c.h"
#include "opencv2/imgproc/types_c.h"
*/
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralDetectionOp::inference() {
VLOG(2) << "Going to run inference";
const std::vector<std::string> pre_node_names = pre_names();
if (pre_node_names.size() != 1) {
LOG(ERROR) << "This op(" << op_name()
<< ") can only have one predecessor op, but received "
<< pre_node_names.size();
return -1;
}
const std::string pre_name = pre_node_names[0];
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name);
if (!input_blob) {
LOG(ERROR) << "input_blob is nullptr,error";
return -1;
}
uint64_t log_id = input_blob->GetLogId();
VLOG(2) << "(logid=" << log_id << ") Get precedent op name: " << pre_name;
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!output_blob) {
LOG(ERROR) << "output_blob is nullptr,error";
return -1;
}
output_blob->SetLogId(log_id);
if (!input_blob) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed mutable depended argument, op:" << pre_name;
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->_batch_size;
VLOG(2) << "(logid=" << log_id << ") input batch size: " << batch_size;
output_blob->_batch_size = batch_size;
std::vector<int> input_shape;
int in_num = 0;
void *databuf_data = NULL;
char *databuf_char = NULL;
size_t databuf_size = 0;
// now only support single string
char *total_input_ptr = static_cast<char *>(in->at(0).data.data());
std::string base64str = total_input_ptr;
float ratio_h{};
float ratio_w{};
cv::Mat img = Base2Mat(base64str);
cv::Mat srcimg;
cv::Mat resize_img;
cv::Mat resize_img_rec;
cv::Mat crop_img;
img.copyTo(srcimg);
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_det, this->scale_det,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
TensorVector *real_in = new TensorVector();
if (!real_in) {
LOG(ERROR) << "real_in is nullptr,error";
return -1;
}
for (int i = 0; i < in->size(); ++i) {
input_shape = {1, 3, resize_img.rows, resize_img.cols};
in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int>());
databuf_size = in_num * sizeof(float);
databuf_data = MempoolWrapper::instance().malloc(databuf_size);
if (!databuf_data) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size;
return -1;
}
memcpy(databuf_data, input.data(), databuf_size);
databuf_char = reinterpret_cast<char *>(databuf_data);
paddle::PaddleBuf paddleBuf(databuf_char, databuf_size);
paddle::PaddleTensor tensor_in;
tensor_in.name = in->at(i).name;
tensor_in.dtype = paddle::PaddleDType::FLOAT32;
tensor_in.shape = {1, 3, resize_img.rows, resize_img.cols};
tensor_in.lod = in->at(i).lod;
tensor_in.data = paddleBuf;
real_in->push_back(tensor_in);
}
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
if (InferManager::instance().infer(engine_name().c_str(), real_in, out,
batch_size)) {
LOG(ERROR) << "(logid=" << log_id
<< ") Failed do infer in fluid model: " << engine_name().c_str();
return -1;
}
delete real_in;
std::vector<int> output_shape;
int out_num = 0;
void *databuf_data_out = NULL;
char *databuf_char_out = NULL;
size_t databuf_size_out = 0;
// this is special add for PaddleOCR postprecess
int infer_outnum = out->size();
for (int k = 0; k < infer_outnum; ++k) {
int n2 = out->at(k).shape[2];
int n3 = out->at(k).shape[3];
int n = n2 * n3;
float *out_data = static_cast<float *>(out->at(k).data.data());
std::vector<float> pred(n, 0.0);
std::vector<unsigned char> cbuf(n, ' ');
for (int i = 0; i < n; i++) {
pred[i] = float(out_data[i]);
cbuf[i] = (unsigned char)((out_data[i]) * 255);
}
cv::Mat cbuf_map(n2, n3, CV_8UC1, (unsigned char *)cbuf.data());
cv::Mat pred_map(n2, n3, CV_32F, (float *)pred.data());
const double threshold = this->det_db_thresh_ * 255;
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
cv::Mat dilation_map;
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, dilation_map, dila_ele);
boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
this->det_db_box_thresh_,
this->det_db_unclip_ratio_);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
float max_wh_ratio = 0.0f;
std::vector<cv::Mat> crop_imgs;
std::vector<cv::Mat> resize_imgs;
int max_resize_w = 0;
int max_resize_h = 0;
int box_num = boxes.size();
std::vector<std::vector<float>> output_rec;
for (int i = 0; i < box_num; ++i) {
cv::Mat line_img = GetRotateCropImage(img, boxes[i]);
float wh_ratio = float(line_img.cols) / float(line_img.rows);
max_wh_ratio = max_wh_ratio > wh_ratio ? max_wh_ratio : wh_ratio;
crop_imgs.push_back(line_img);
}
for (int i = 0; i < box_num; ++i) {
cv::Mat resize_img;
crop_img = crop_imgs[i];
this->resize_op_rec.Run(crop_img, resize_img, max_wh_ratio,
this->use_tensorrt_);
this->normalize_op_.Run(&resize_img, this->mean_rec, this->scale_rec,
this->is_scale_);
max_resize_w = std::max(max_resize_w, resize_img.cols);
max_resize_h = std::max(max_resize_h, resize_img.rows);
resize_imgs.push_back(resize_img);
}
int buf_size = 3 * max_resize_h * max_resize_w;
output_rec = std::vector<std::vector<float>>(
box_num, std::vector<float>(buf_size, 0.0f));
for (int i = 0; i < box_num; ++i) {
resize_img_rec = resize_imgs[i];
this->permute_op_.Run(&resize_img_rec, output_rec[i].data());
}
// Inference.
output_shape = {box_num, 3, max_resize_h, max_resize_w};
out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
databuf_size_out = out_num * sizeof(float);
databuf_data_out = MempoolWrapper::instance().malloc(databuf_size_out);
if (!databuf_data_out) {
LOG(ERROR) << "Malloc failed, size: " << databuf_size_out;
return -1;
}
int offset = buf_size * sizeof(float);
for (int i = 0; i < box_num; ++i) {
memcpy(databuf_data_out + i * offset, output_rec[i].data(), offset);
}
databuf_char_out = reinterpret_cast<char *>(databuf_data_out);
paddle::PaddleBuf paddleBuf(databuf_char_out, databuf_size_out);
paddle::PaddleTensor tensor_out;
tensor_out.name = "x";
tensor_out.dtype = paddle::PaddleDType::FLOAT32;
tensor_out.shape = output_shape;
tensor_out.data = paddleBuf;
out->push_back(tensor_out);
}
out->erase(out->begin(), out->begin() + infer_outnum);
int64_t end = timeline.TimeStampUS();
CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end);
return 0;
}
cv::Mat GeneralDetectionOp::Base2Mat(std::string &base64_data) {
cv::Mat img;
std::string s_mat;
s_mat = base64Decode(base64_data.data(), base64_data.size());
std::vector<char> base64_img(s_mat.begin(), s_mat.end());
img = cv::imdecode(base64_img, cv::IMREAD_COLOR); // CV_LOAD_IMAGE_COLOR
return img;
}
std::string GeneralDetectionOp::base64Decode(const char *Data, int DataByte) {
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
std::string strDecode;
int nValue;
int i = 0;
while (i < DataByte) {
if (*Data != '\r' && *Data != '\n') {
nValue = DecodeTable[*Data++] << 18;
nValue += DecodeTable[*Data++] << 12;
strDecode += (nValue & 0x00FF0000) >> 16;
if (*Data != '=') {
nValue += DecodeTable[*Data++] << 6;
strDecode += (nValue & 0x0000FF00) >> 8;
if (*Data != '=') {
nValue += DecodeTable[*Data++];
strDecode += nValue & 0x000000FF;
}
}
i += 4;
} else // 回车换行,跳过
{
Data++;
i++;
}
}
return strDecode;
}
cv::Mat
GeneralDetectionOp::GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box) {
cv::Mat image;
srcimage.copyTo(image);
std::vector<std::vector<int>> points = box;
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
cv::Mat img_crop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
for (int i = 0; i < points.size(); i++) {
points[i][0] -= left;
points[i][1] -= top;
}
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
pow(points[0][1] - points[1][1], 2)));
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
pow(points[0][1] - points[3][1], 2)));
cv::Point2f pts_std[4];
pts_std[0] = cv::Point2f(0., 0.);
pts_std[1] = cv::Point2f(img_crop_width, 0.);
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
pts_std[3] = cv::Point2f(0.f, img_crop_height);
cv::Point2f pointsf[4];
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M,
cv::Size(img_crop_width, img_crop_height),
cv::BORDER_REPLICATE);
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
cv::transpose(dst_img, srcCopy);
cv::flip(srcCopy, srcCopy, 0);
return srcCopy;
} else {
return dst_img;
}
}
DEFINE_OP(GeneralDetectionOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -45,10 +45,8 @@ for img_file in os.listdir(test_img_dir):
image_data = file.read()
image = cv2_to_base64(image_data)
res_list = []
#print(image)
fetch_map = client.predict(
feed={"x": image}, fetch=["save_infer_model/scale_0.tmp_1"], batch=True)
print("fetrch map:", fetch_map)
one_batch_res = ocr_reader.postprocess(fetch_map, with_score=True)
for res in one_batch_res:
res_list.append(res[0])
......
......@@ -34,12 +34,28 @@ test_img_dir = args.image_dir
for idx, img_file in enumerate(os.listdir(test_img_dir)):
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
image_data1 = file.read()
# print file name
print('{}{}{}'.format('*' * 10, img_file, '*' * 10))
image = cv2_to_base64(image_data1)
for i in range(1):
data = {"key": ["image"], "value": [image]}
r = requests.post(url=url, data=json.dumps(data))
print(r.json())
data = {"key": ["image"], "value": [image]}
r = requests.post(url=url, data=json.dumps(data))
result = r.json()
print("erro_no:{}, err_msg:{}".format(result["err_no"], result["err_msg"]))
# check success
if result["err_no"] == 0:
ocr_result = result["value"][0]
try:
for item in eval(ocr_result):
# return transcription and points
print("{}, {}".format(item[0], item[1]))
except Exception as e:
print("No results")
continue
else:
print(
"For details about error message, see PipelineServingLogs/pipeline.log"
)
print("==> total number of test imgs: ", len(os.listdir(test_img_dir)))
......@@ -15,6 +15,7 @@ from paddle_serving_server.web_service import WebService, Op
import logging
import numpy as np
import copy
import cv2
import base64
# from paddle_serving_app.reader import OCRReader
......@@ -36,7 +37,7 @@ class DetOp(Op):
self.filter_func = FilterBoxes(10, 10)
self.post_func = DBPostProcess({
"thresh": 0.3,
"box_thresh": 0.5,
"box_thresh": 0.6,
"max_candidates": 1000,
"unclip_ratio": 1.5,
"min_size": 3
......@@ -79,8 +80,10 @@ class RecOp(Op):
raw_im = input_dict["image"]
data = np.frombuffer(raw_im, np.uint8)
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
dt_boxes = input_dict["dt_boxes"]
dt_boxes = self.sorted_boxes(dt_boxes)
self.dt_list = input_dict["dt_boxes"]
self.dt_list = self.sorted_boxes(self.dt_list)
# deepcopy to save origin dt_boxes
dt_boxes = copy.deepcopy(self.dt_list)
feed_list = []
img_list = []
max_wh_ratio = 0
......@@ -126,25 +129,29 @@ class RecOp(Op):
imgs[id] = norm_img
feed = {"x": imgs.copy()}
feed_list.append(feed)
return feed_list, False, None, ""
def postprocess(self, input_dicts, fetch_data, data_id, log_id):
res_list = []
rec_list = []
dt_num = len(self.dt_list)
if isinstance(fetch_data, dict):
if len(fetch_data) > 0:
rec_batch_res = self.ocr_reader.postprocess(
fetch_data, with_score=True)
for res in rec_batch_res:
res_list.append(res[0])
rec_list.append(res)
elif isinstance(fetch_data, list):
for one_batch in fetch_data:
one_batch_res = self.ocr_reader.postprocess(
one_batch, with_score=True)
for res in one_batch_res:
res_list.append(res[0])
res = {"res": str(res_list)}
rec_list.append(res)
result_list = []
for i in range(dt_num):
text = rec_list[i]
dt_box = self.dt_list[i]
result_list.append([text, dt_box.tolist()])
res = {"result": str(result_list)}
return res, None, ""
......
## 版面分析数据集
这里整理了常用版面分析数据集,持续更新中,欢迎各位小伙伴贡献数据集~
- [publaynet数据集](#publaynet)
- [CDLA数据集](#CDLA)
- [TableBank数据集](#TableBank)
版面分析数据集多为目标检测数据集,除了开源数据,用户还可使用合成工具自行合成,如[labelme](https://github.com/wkentaro/labelme)等。
<a name="publaynet"></a>
#### 1、publaynet数据集
- **数据来源**:https://github.com/ibm-aur-nlp/PubLayNet
- **数据简介**:publaynet数据集的训练集合中包含35万张图像,验证集合中包含1.1万张图像。总共包含5个类别,分别是: `text, title, list, table, figure`。部分图像以及标注框可视化如下所示。
<div align="center">
<img src="../datasets/publaynet_demo/gt_PMC3724501_00006.jpg" width="500">
<img src="../datasets/publaynet_demo/gt_PMC5086060_00002.jpg" width="500">
</div>
- **下载地址**:https://developer.ibm.com/exchanges/data/all/publaynet/
- **说明**:使用该数据集时,需要遵守[CDLA-Permissive](https://cdla.io/permissive-1-0/)协议。
<a name="CDLA"></a>
#### 2、CDLA数据集
- **数据来源**:https://github.com/buptlihang/CDLA
- **数据简介**:publaynet数据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`。部分图像以及标注框可视化如下所示。
<div align="center">
<img src="../datasets/CDLA_demo/val_0633.jpg" width="500">
<img src="../datasets/CDLA_demo/val_0941.jpg" width="500">
</div>
- **下载地址**:https://github.com/buptlihang/CDLA
- **说明**:基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection/tree/develop)套件,在该数据集上训练目标检测模型时,在转换label时,需要将`label.txt`中的`__ignore__``_background_`去除。
<a name="TableBank"></a>
#### 3、TableBank数据集
- **数据来源**:https://doc-analysis.github.io/tablebank-page/index.html
- **数据简介**:TableBank数据集包含Latex(训练集187199张,验证集7265张,测试集5719张)与Word(训练集73383张,验证集2735张,测试集2281张)两种类别的文档。仅包含`Table` 1个类别。部分图像以及标注框可视化如下所示。
<div align="center">
<img src="../datasets/tablebank_demo/004.png" height="700">
<img src="../datasets/tablebank_demo/005.png" height="700">
</div>
- **下载地址**:https://doc-analysis.github.io/tablebank-page/index.html
- **说明**:使用该数据集时,需要遵守[Apache-2.0](https://github.com/doc-analysis/TableBank/blob/master/LICENSE)协议。
......@@ -47,7 +47,7 @@ __all__ = [
]
SUPPORT_DET_MODEL = ['DB']
VERSION = '2.4.0.4'
VERSION = '2.5'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
......@@ -442,7 +442,7 @@ class PPStructure(StructureSystem):
logger.debug(params)
super().__init__(params)
def __call__(self, img):
def __call__(self, img, return_ocr_result_in_table=False):
if isinstance(img, str):
# download net image
if img.startswith('http'):
......@@ -460,7 +460,7 @@ class PPStructure(StructureSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
res = super().__call__(img, return_ocr_result_in_table)
return res
......
......@@ -73,7 +73,7 @@ class BaseRecLabelDecode(object):
conf_list = [0]
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
......@@ -196,7 +196,7 @@ class NRTRLabelDecode(BaseRecLabelDecode):
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list)))
result_list.append((text.lower(), np.mean(conf_list).tolist()))
return result_list
......@@ -241,7 +241,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
......@@ -333,7 +333,7 @@ class SEEDLabelDecode(BaseRecLabelDecode):
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
......@@ -417,7 +417,7 @@ class SRNLabelDecode(BaseRecLabelDecode):
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def add_special_char(self, dict_character):
......@@ -636,7 +636,7 @@ class SARLabelDecode(BaseRecLabelDecode):
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
text = text.lower()
text = comp.sub('', text)
result_list.append((text, np.mean(conf_list)))
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
......@@ -699,7 +699,7 @@ class PRENLabelDecode(BaseRecLabelDecode):
text = ''.join(char_list)
if len(text) > 0:
result_list.append((text, np.mean(conf_list)))
result_list.append((text, np.mean(conf_list).tolist()))
else:
# here confidence of empty recog result is 1
result_list.append(('', 1))
......
# 基于Python预测引擎推理
- [版面分析+表格识别](#1)
- [DocVQA](#2)
- [1. Structure](#1)
- [1.1 版面分析+表格识别](#1.1)
- [1.2 版面分析](#1.2)
- [1.3 表格识别](#1.3)
- [2. DocVQA](#2)
<a name="1"></a>
## 1. 版面分析+表格识别
## 1. Structure
进入`ppstructure`目录
```bash
cd ppstructure
# 下载模型
````
下载模型
```bash
mkdir inference && cd inference
# 下载PP-OCRv2文本检测模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar
......@@ -18,17 +23,42 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant
# 下载超轻量级英文表格预测模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
```
<a name="1.1"></a>
### 1.1 版面分析+表格识别
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=../doc/table/1.png \
--image_dir=./docs/table/1.png \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output/table \
--output=../output \
--vis_font_path=../doc/fonts/simfang.ttf
```
运行完成后,每张图片会在`output`字段指定的目录下的`talbe`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。
运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。详细的结果会存储在`res.txt`文件中。
<a name="1.2"></a>
### 1.2 版面分析
```bash
python3 predict_system.py --image_dir=./docs/table/1.png --table=false --ocr=false --output=../output/
```
运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片区域会被裁剪之后保存下来,图片名为表格在图片里的坐标。版面分析结果会存储在`res.txt`文件中。
<a name="1.3"></a>
### 1.3 表格识别
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=./docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output \
--vis_font_path=../doc/fonts/simfang.ttf \
--layout=false
```
运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,表格会存储为一个excel,excel文件名为`[0,0,img_h,img_w]`。
<a name="2"></a>
## 2. DocVQA
......@@ -47,4 +77,4 @@ python3 predict_system.py --model_name_or_path=vqa/PP-Layout_v1.0_ser_pretrained
--image_dir=vqa/images/input/zh_val_0.jpg \
--vis_font_path=../doc/fonts/simfang.ttf
```
运行完成后,每张图片会在`output`字段指定的目录下的`vqa`目录下存放可视化之后的图片,图片名和输入图片名一致。
\ No newline at end of file
运行完成后,每张图片会在`output`字段指定的目录下的`vqa`目录下存放可视化之后的图片,图片名和输入图片名一致。
# 基于Python预测引擎推理
# Python Inference
- [版面分析+表格识别](#1)
- [DocVQA](#2)
- [1. Structure](#1)
- [1.1 layout analysis + table recognition](#1.1)
- [1.2 layout analysis](#1.2)
- [1.3 table recognition](#1.3)
- [2. DocVQA](#2)
<a name="1"></a>
## 1. 版面分析+表格识别
## 1. Structure
Go to the `ppstructure` directory
```bash
cd ppstructure
````
# 下载模型
download model
```bash
mkdir inference && cd inference
# 下载PP-OCRv2文本检测模型并解压
# Download the PP-OCRv2 text detection model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar
# 下载PP-OCRv2文本识别模型并解压
# Download the PP-OCRv2 text recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar
# 下载超轻量级英文表格预测模型并解压
# Download the ultra-lightweight English table structure model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
```
<a name="1.1"></a>
### 1.1 layout analysis + table recognition
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=../doc/table/1.png \
--image_dir=./docs/table/1.png \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output/table \
--output=../output \
--vis_font_path=../doc/fonts/simfang.ttf
```
运行完成后,每张图片会在`output`字段指定的目录下的`talbe`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。
After the operation is completed, each image will have a directory with the same name in the `structure` directory under the directory specified by the `output` field. Each table in the image will be stored as an excel, and the picture area will be cropped and saved. The filename of excel and picture is their coordinates in the image. Detailed results are stored in the `res.txt` file.
<a name="1.2"></a>
### 1.2 layout analysis
```bash
python3 predict_system.py --image_dir=./docs/table/1.png --table=false --ocr=false --output=../output/
```
After the operation is completed, each image will have a directory with the same name in the `structure` directory under the directory specified by the `output` field. Each picture in image will be cropped and saved. The filename of picture area is their coordinates in the image. Layout analysis results will be stored in the `res.txt` file
<a name="1.3"></a>
### 1.3 table recognition
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=./docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output \
--vis_font_path=../doc/fonts/simfang.ttf \
--layout=false
```
After the operation is completed, each image will have a directory with the same name in the `structure` directory under the directory specified by the `output` field. Each table in the image will be stored as an excel. The filename of excel is their coordinates in the image.
<a name="2"></a>
## 2. DocVQA
......@@ -36,9 +68,8 @@ python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_i
```bash
cd ppstructure
# 下载模型
# download model
mkdir inference && cd inference
# 下载SER xfun 模型并解压
wget https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar && tar xf PP-Layout_v1.0_ser_pretrained.tar
cd ..
......@@ -47,4 +78,4 @@ python3 predict_system.py --model_name_or_path=vqa/PP-Layout_v1.0_ser_pretrained
--image_dir=vqa/images/input/zh_val_0.jpg \
--vis_font_path=../doc/fonts/simfang.ttf
```
运行完成后,每张图片会在`output`字段指定的目录下的`vqa`目录下存放可视化之后的图片,图片名和输入图片名一致。
\ No newline at end of file
After the operation is completed, each image will store the visualized image in the `vqa` directory under the directory specified by the `output` field, and the image name is the same as the input image name.
# PP-Structure 系列模型列表
# PP-Structure Model list
- [1. 版面分析模型](#1)
- [2. OCR和表格识别模型](#2)
- [1. Layout Analysis](#1)
- [2. OCR and Table Recognition](#2)
- [2.1 OCR](#21)
- [2.2 表格识别模型](#22)
- [3. VQA模型](#3)
- [4. KIE模型](#4)
- [2.2 Table Recognition](#22)
- [3. VQA](#3)
- [4. KIE](#4)
<a name="1"></a>
## 1. 版面分析模型
## 1. Layout Analysis
|模型名称|模型简介|下载地址|label_map|
| --- | --- | --- | --- |
| ppyolov2_r50vd_dcn_365e_publaynet | PubLayNet 数据集训练的版面分析模型,可以划分**文字、标题、表格、图片以及列表**5类区域 | [推理模型](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) / [训练模型](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet_pretrained.pdparams) |{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}|
| ppyolov2_r50vd_dcn_365e_tableBank_word | TableBank Word 数据集训练的版面分析模型,只能检测表格 | [推理模型](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) | {0:"Table"}|
| ppyolov2_r50vd_dcn_365e_tableBank_latex | TableBank Latex 数据集训练的版面分析模型,只能检测表格 | [推理模型](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) | {0:"Table"}|
|model name| description |download|label_map|
| --- |---------------------------------------------------------------------------------------------------------------------------------------------------------| --- | --- |
| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis model trained on the PubLayNet dataset, the model can recognition 5 types of areas such as **text, title, table, picture and list** | [inference model](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) / [trained model](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet_pretrained.pdparams) |{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}|
| ppyolov2_r50vd_dcn_365e_tableBank_word | The layout analysis model trained on the TableBank Word dataset, the model can only detect tables | [inference model](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) | {0:"Table"}|
| ppyolov2_r50vd_dcn_365e_tableBank_latex | The layout analysis model trained on the TableBank Latex dataset, the model can only detect tables | [inference model](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) | {0:"Table"}|
<a name="2"></a>
## 2. OCR和表格识别模型
## 2. OCR and Table Recognition
<a name="21"></a>
### 2.1 OCR
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|model name| description | inference model size |download|
| --- |---|---| --- |
|en_ppocr_mobile_v2.0_table_det| Text detection model of English table scenes trained on PubTabNet dataset | 4.7M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
|en_ppocr_mobile_v2.0_table_rec| Text recognition model of English table scenes trained on PubTabNet dataset | 6.9M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
如需要使用其他OCR模型,可以在 [PP-OCR model_list](../../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到 `det_model_dir`, `rec_model_dir`两个字段即可。
If you need to use other OCR models, you can download the model in [PP-OCR model_list](../../doc/doc_ch/models_list.md) or use the model you trained yourself to configure to `det_model_dir`, `rec_model_dir` field.
<a name="22"></a>
### 2.2 表格识别模型
### 2.2 Table Recognition
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|model| description |inference model size|download|
| --- |-----------------------------------------------------------------------------| --- | --- |
|en_ppocr_mobile_v2.0_table_structure| Table structure model for English table scenes trained on PubTabNet dataset |18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
<a name="3"></a>
## 3. VQA模型
## 3. VQA
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
|model| description |inference model size|download|
| --- |----------------------------------------------------------------| --- | --- |
|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a>
## 4. KIE模型
## 4. KIE
|模型名称|模型简介|模型大小|下载地址|
|model|description|model size|download|
| --- | --- | --- | --- |
|SDMGR|关键信息提取模型|78M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)|
|SDMGR|Key Information Extraction Model|78M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)|
......@@ -4,10 +4,14 @@
- [2. 便捷使用](#2)
- [2.1 命令行使用](#21)
- [2.1.1 版面分析+表格识别](#211)
- [2.1.2 DocVQA](#212)
- [2.2 Python脚本使用](#22)
- [2.1.2 版面分析](#212)
- [2.1.3 表格识别](#213)
- [2.1.4 DocVQA](#214)
- [2.2 代码使用](#22)
- [2.2.1 版面分析+表格识别](#221)
- [2.2.2 DocVQA](#222)
- [2.2.2 版面分析](#222)
- [2.2.3 表格识别](#223)
- [2.2.4 DocVQA](#224)
- [2.3 返回结果说明](#23)
- [2.3.1 版面分析+表格识别](#231)
- [2.3.2 DocVQA](#232)
......@@ -18,10 +22,10 @@
## 1. 安装依赖包
```bash
# 安装 paddleocr,推荐使用2.3.0.2+版本
pip3 install "paddleocr>=2.3.0.2"
# 安装 paddleocr,推荐使用2.5+版本
pip3 install "paddleocr>=2.5"
# 安装 版面分析依赖包layoutparser(如不需要版面分析功能,可跳过)
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
# 安装 DocVQA依赖包paddlenlp(如不需要DocVQA功能,可跳过)
pip install paddlenlp
......@@ -32,20 +36,32 @@ pip install paddlenlp
<a name="21"></a>
### 2.1 命令行使用
<a name="211"></a>
#### 2.1.1 版面分析+表格识别
```bash
paddleocr --image_dir=../doc/table/1.png --type=structure
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
<a name="212"></a>
#### 2.1.2 DocVQA
#### 2.1.2 版面分析
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="213"></a>
#### 2.1.3 表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
请参考:[文档视觉问答](../vqa/README.md)
<a name="22"></a>
### 2.2 Python脚本使用
### 2.2 代码使用
<a name="221"></a>
#### 2.2.1 版面分析+表格识别
......@@ -57,8 +73,8 @@ from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True)
save_folder = './output/table'
img_path = '../doc/table/1.png'
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
......@@ -69,7 +85,7 @@ for line in result:
from PIL import Image
font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
......@@ -77,7 +93,49 @@ im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 DocVQA
#### 2.2.2 版面分析
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(table=False, ocr=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="223"></a>
#### 2.2.3 表格识别
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(layout=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/table.jpg'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
请参考:[文档视觉问答](../vqa/README.md)
......@@ -98,11 +156,11 @@ PP-Structure的返回结果为一个dict组成的list,示例如下
```
dict 里各个字段说明如下
| 字段 | 说明 |
| --------------- | -------------|
|type|图片区域的类型|
|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组|
| 字段 | 说明 |
| --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|type| 图片区域的类型 |
|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y] |
|res| 图片区域的OCR或表格识别结果。<br> 表格: 一个dict,字段说明如下<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: 表格的HTML字符串<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; 在代码使用模式下,前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果,对应为如下字段: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: 文本检测坐标<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: 文本识别结果。<br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 |
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
......@@ -110,8 +168,8 @@ dict 里各个字段说明如下
/output/table/1/
└─ res.txt
└─ [454, 360, 824, 658].xlsx 表格识别结果
└─ [16, 2, 828, 305].jpg 被裁剪出的图片区域
└─ [17, 361, 404, 711].xlsx 表格识别结果
└─ [16, 2, 828, 305].jpg 被裁剪出的图片区域
└─ [17, 361, 404, 711].xlsx 表格识别结果
```
<a name="232"></a>
......@@ -122,17 +180,19 @@ dict 里各个字段说明如下
<a name="24"></a>
### 2.4 参数说明
| 字段 | 说明 | 默认值 |
| --------------- | ---------------------------------------- | ------------------------------------------- |
| output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None |
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | 版面分析模型模型label映射字典 | None |
| model_name_or_path | VQA SER模型地址 | None |
| max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
| mode | pipeline预测模式,structure: 版面分析+表格识别; VQA: SER文档信息抽取 | structure |
| 字段 | 说明 | 默认值 |
|----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None |
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | 版面分析模型模型label映射字典 | None |
| model_name_or_path | VQA SER模型地址 | None |
| max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
| layout | 前向中是否执行版面分析 | True |
| table | 前向中是否执行表格识别 | True |
| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False | True |
大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
# PP-Structure 快速开始
- [1. 安装依赖包](#1)
- [2. 便捷使用](#2)
- [2.1 命令行使用](#21)
- [2.1.1 版面分析+表格识别](#211)
- [2.1.2 DocVQA](#212)
- [2.2 Python脚本使用](#22)
- [2.2.1 版面分析+表格识别](#221)
- [2.2.2 DocVQA](#222)
- [2.3 返回结果说明](#23)
- [2.3.1 版面分析+表格识别](#231)
# PP-Structure Quick Start
- [1. Install package](#1)
- [2. Use](#2)
- [2.1 Use by command line](#21)
- [2.1.1 layout analysis + table recognition](#211)
- [2.1.2 layout analysis](#212)
- [2.1.3 table recognition](#213)
- [2.1.4 DocVQA](#214)
- [2.2 Use by code](#22)
- [2.2.1 layout analysis + table recognition](#221)
- [2.2.2 layout analysis](#222)
- [2.2.3 table recognition](#223)
- [2.2.4 DocVQA](#224)
- [2.3 Result description](#23)
- [2.3.1 layout analysis + table recognition](#231)
- [2.3.2 DocVQA](#232)
- [2.4 参数说明](#24)
- [2.4 Parameter Description](#24)
<a name="1"></a>
## 1. 安装依赖包
## 1. Install package
```bash
# 安装 paddleocr,推荐使用2.3.0.2+版本
pip3 install "paddleocr>=2.3.0.2"
# 安装 版面分析依赖包layoutparser(如不需要版面分析功能,可跳过)
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
# 安装 DocVQA依赖包paddlenlp(如不需要DocVQA功能,可跳过)
# Install paddleocr, version 2.5+ is recommended
pip3 install "paddleocr>=2.5"
# Install layoutparser (if you do not use the layout analysis, you can skip it)
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
# Install the DocVQA dependency package paddlenlp (if you do not use the DocVQA, you can skip it)
pip install paddlenlp
```
<a name="2"></a>
## 2. 便捷使用
## 2. Use
<a name="21"></a>
### 2.1 命令行使用
### 2.1 Use by command line
<a name="211"></a>
#### 2.1.1 版面分析+表格识别
#### 2.1.1 layout analysis + table recognition
```bash
paddleocr --image_dir=../doc/table/1.png --type=structure
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
<a name="212"></a>
#### 2.1.2 DocVQA
#### 2.1.2 layout analysis
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="213"></a>
#### 2.1.3 table recognition
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
请参考:[文档视觉问答](../vqa/README.md)
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
<a name="22"></a>
### 2.2 Python脚本使用
### 2.2 Use by code
<a name="221"></a>
#### 2.2.1 版面分析+表格识别
#### 2.2.1 layout analysis + table recognition
```python
import os
......@@ -57,8 +73,8 @@ from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True)
save_folder = './output/table'
img_path = '../doc/table/1.png'
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
......@@ -69,7 +85,7 @@ for line in result:
from PIL import Image
font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
......@@ -77,16 +93,59 @@ im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 DocVQA
#### 2.2.2 layout analysis
请参考:[文档视觉问答](../vqa/README.md)
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(table=False, ocr=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="223"></a>
#### 2.2.3 table recognition
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(layout=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/table.jpg'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
<a name="23"></a>
### 2.3 返回结果说明
PP-Structure的返回结果为一个dict组成的list,示例如下
### 2.3 Result description
The return of PP-Structure is a list of dicts, the example is as follows:
<a name="231"></a>
#### 2.3.1 版面分析+表格识别
#### 2.3.1 layout analysis + table recognition
```shell
[
{ 'type': 'Text',
......@@ -96,43 +155,44 @@ PP-Structure的返回结果为一个dict组成的list,示例如下
}
]
```
dict 里各个字段说明如下
| 字段 | 说明 |
| --------------- | -------------|
|type|图片区域的类型|
|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组|
Each field in dict is described as follows:
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
| field | description |
| --------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|type| Type of image area. |
|bbox| The coordinates of the image area in the original image, respectively [upper left corner x, upper left corner y, lower right corner x, lower right corner y]. |
|res| OCR or table recognition result of the image area. <br> table: a dict with field descriptions as follows: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: html str of table.<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; In the code usage mode, set return_ocr_result_in_table=True whrn call can get the detection and recognition results of each text in the table area, corresponding to the following fields: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: text detection boxes.<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: text recognition results.<br> OCR: A tuple containing the detection boxes and recognition results of each single text. |
After the recognition is completed, each image will have a directory with the same name under the directory specified by the `output` field. Each table in the image will be stored as an excel, and the picture area will be cropped and saved. The filename of excel and picture is their coordinates in the image.
```
/output/table/1/
└─ res.txt
└─ [454, 360, 824, 658].xlsx 表格识别结果
└─ [16, 2, 828, 305].jpg 被裁剪出的图片区域
└─ [17, 361, 404, 711].xlsx 表格识别结果
└─ [454, 360, 824, 658].xlsx table recognition result
└─ [16, 2, 828, 305].jpg picture in Image
└─ [17, 361, 404, 711].xlsx table recognition result
```
<a name="232"></a>
#### 2.3.2 DocVQA
请参考:[文档视觉问答](../vqa/README.md)
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
<a name="24"></a>
### 2.4 参数说明
| 字段 | 说明 | 默认值 |
| --------------- | ---------------------------------------- | ------------------------------------------- |
| output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None |
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | 版面分析模型模型label映射字典 | None |
| model_name_or_path | VQA SER模型地址 | None |
| max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
| mode | pipeline预测模式,structure: 版面分析+表格识别; VQA: SER文档信息抽取 | structure |
大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
### 2.4 Parameter Description
| field | description | default |
|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | The save path of result | ./output/table |
| table_max_len | When the table structure model predicts, the long side of the image | 488 |
| table_model_dir | the path of table structure model | None |
| table_char_dict_path | the dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | The model path of the layout analysis model, which can be an online address or a local path. When it is a local path, layout_label_map needs to be set. In command line mode, use --layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | Layout analysis model model label mapping dictionary path | None |
| model_name_or_path | the model path of VQA SER model | None |
| max_seq_length | the max token length of VQA SER model | 512 |
| label_map_path | the label path of VQA SER model | ./vqa/labels/labels_ser.txt |
| layout | Whether to perform layout analysis in forward | True |
| table | Whether to perform table recognition in forward | True |
| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False | True |
Most of the parameters are consistent with the PaddleOCR whl package, see [whl package documentation](../../doc/doc_en/whl.md)
......@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import numpy as np
import time
import logging
from copy import deepcopy
from attrdict import AttrDict
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
......@@ -40,97 +41,122 @@ class StructureSystem(object):
def __init__(self, args):
self.mode = args.mode
if self.mode == 'structure':
import layoutparser as lp
# args.det_limit_type = 'resize_long'
args.drop_score = 0
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_system = TextSystem(args)
self.table_system = TableSystem(args,
self.text_system.text_detector,
self.text_system.text_recognizer)
config_path = None
model_path = None
if os.path.isdir(args.layout_path_model):
model_path = args.layout_path_model
if args.layout == False and args.ocr == True:
args.ocr = False
logger.warning(
"When args.layout is false, args.ocr is automatically set to false"
)
args.drop_score = 0
# init layout and ocr model
self.text_system = None
if args.layout:
import layoutparser as lp
config_path = None
model_path = None
if os.path.isdir(args.layout_path_model):
model_path = args.layout_path_model
else:
config_path = args.layout_path_model
self.table_layout = lp.PaddleDetectionLayoutModel(
config_path=config_path,
model_path=model_path,
label_map=args.layout_label_map,
threshold=0.5,
enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu,
thread_num=args.cpu_threads)
if args.ocr:
self.text_system = TextSystem(args)
else:
self.table_layout = None
if args.table:
if self.text_system is not None:
self.table_system = TableSystem(
args, self.text_system.text_detector,
self.text_system.text_recognizer)
else:
self.table_system = TableSystem(args)
else:
config_path = args.layout_path_model
self.table_layout = lp.PaddleDetectionLayoutModel(
config_path=config_path,
model_path=model_path,
label_map=args.layout_label_map,
threshold=0.5,
enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu,
thread_num=args.cpu_threads)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
self.table_system = None
elif self.mode == 'vqa':
raise NotImplementedError
def __call__(self, img):
def __call__(self, img, return_ocr_result_in_table=False):
if self.mode == 'structure':
ori_im = img.copy()
layout_res = self.table_layout.detect(img[..., ::-1])
if self.table_layout is not None:
layout_res = self.table_layout.detect(img[..., ::-1])
else:
h, w = ori_im.shape[:2]
layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')]
res_list = []
for region in layout_res:
res = ''
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region.type == 'Table':
res = self.table_system(roi_img)
if self.table_system is not None:
res = self.table_system(roi_img,
return_ocr_result_in_table)
else:
filter_boxes, filter_rec_res = self.text_system(roi_img)
# remove style char
style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>',
'</b>', '<sub>', '</sup>', '<overline>', '</overline>',
'<underline>', '</underline>', '<i>', '</i>'
]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
box += [x1, y1]
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
if self.text_system is not None:
filter_boxes, filter_rec_res = self.text_system(roi_img)
# remove style char
style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>',
'</b>', '<sub>', '</sup>', '<overline>',
'</overline>', '<underline>', '</underline>', '<i>',
'</i>'
]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
box += [x1, y1]
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
res_list.append({
'type': region.type,
'bbox': [x1, y1, x2, y2],
'img': roi_img,
'res': res
})
return res_list
elif self.mode == 'vqa':
raise NotImplementedError
return res_list
return None
def save_structure_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
res_cp = deepcopy(res)
# save res
with open(
os.path.join(excel_save_folder, 'res.txt'), 'w',
encoding='utf8') as f:
for region in res:
if region['type'] == 'Table':
for region in res_cp:
roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region)))
if region['type'] == 'Table' and len(region[
'res']) > 0 and 'html' in region['res']:
excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
to_excel(region['res']['html'], excel_path)
elif region['type'] == 'Figure':
roi_img = region['img']
img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img)
else:
for text_result in region['res']:
f.write('{}\n'.format(json.dumps(text_result)))
def main(args):
......
......@@ -51,7 +51,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
......
......@@ -61,7 +61,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# 执行预测
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
......
......@@ -54,16 +54,20 @@ def expand(pix, det_box, shape):
class TableSystem(object):
def __init__(self, args, text_detector=None, text_recognizer=None):
self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
self.text_detector = predict_det.TextDetector(
args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(
args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args)
def __call__(self, img):
def __call__(self, img, return_ocr_result_in_table=False):
result = dict()
ori_im = img.copy()
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes)
if return_ocr_result_in_table:
result['boxes'] = [x.tolist() for x in dt_boxes]
r_boxes = []
for box in dt_boxes:
x_min = box[:, 0].min() - 1
......@@ -88,14 +92,17 @@ class TableSystem(object):
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
if return_ocr_result_in_table:
result['rec_res'] = rec_res
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
return pred_html
result['html'] = pred_html
return result
def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html, pred
def match_result(self, dt_boxes, pred_bboxes):
......@@ -104,11 +111,13 @@ class TableSystem(object):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances = []
for j, pred_box in enumerate(pred_bboxes):
distances.append(
(distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU
distances.append((distance(gt_box, pred_box),
1. - compute_iou(gt_box, pred_box)
)) # 获取两两cell之间的L1距离和 1- IOU
sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell
sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
......@@ -122,7 +131,8 @@ class TableSystem(object):
if '</td>' in tag:
if td_index in matched_index.keys():
b_with = False
if '<b>' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1:
if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True
end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]):
......@@ -138,7 +148,8 @@ class TableSystem(object):
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' '
end_html.extend(content)
if b_with:
......@@ -187,18 +198,19 @@ def main(args):
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
excel_path = os.path.join(args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
excel_path = os.path.join(
args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
pred_html = text_sys(img)
pred_res = text_sys(img)
pred_html = pred_res['html']
logger.info(pred_html)
to_excel(pred_html, excel_path)
logger.info('excel saved to {}'.format(excel_path))
logger.info(pred_html)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
......
......@@ -15,7 +15,7 @@
import ast
from PIL import Image
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args
from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args
def init_args():
......@@ -30,6 +30,7 @@ def init_args():
"--table_char_dict_path",
type=str,
default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout
parser.add_argument(
"--layout_path_model",
type=str,
......@@ -39,11 +40,27 @@ def init_args():
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
# params for inference
parser.add_argument(
"--mode",
type=str,
default='structure',
help='structure and vqa is supported')
parser.add_argument(
"--layout",
type=str2bool,
default=True,
help='Whether to enable layout analysis')
parser.add_argument(
"--table",
type=str2bool,
default=True,
help='In the forward, whether the table area uses table recognition')
parser.add_argument(
"--ocr",
type=str2bool,
default=True,
help='In the forward, whether the non-table area is recognition by ocr')
return parser
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册