提交 4862d30c 编写于 作者: M MissPenguin

Merge branch 'dygraph' of https://github.com/MissPenguin/PaddleOCR into dygraph

......@@ -86,15 +86,8 @@ PPOCRLabel # [Normal mode] for [detection + recognition] labeling
PPOCRLabel --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
#### 1.2.2 Build and Install the Whl Package Locally
```bash
cd PaddleOCR/PPOCRLabel
python3 setup.py bdist_wheel
pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl
```
#### 1.2.3 Run PPOCRLabel by Python Script
#### 1.2.2 Run PPOCRLabel by Python Script
If you modify the PPOCRLabel file (for example, specifying a new built-in model), it will be more convenient to see the results by running the Python script. If you still want to start with the whl package, you need to uninstall the whl package in the current environment and then recompile it according to the next section.
```bash
cd ./PPOCRLabel # Switch to the PPOCRLabel directory
......@@ -104,6 +97,13 @@ python PPOCRLabel.py # [Normal mode] for [detection + recognition] labeling
python PPOCRLabel.py --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
#### 1.2.3 Build and Install the Whl Package Locally
Compile and install a new whl package, where 1.0.2 is the version number, you can specify the new version in 'setup.py'.
```bash
cd PaddleOCR/PPOCRLabel
python3 setup.py bdist_wheel
pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl
```
## 2. Usage
......
......@@ -86,28 +86,26 @@ PPOCRLabel --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+
> 如果上述安装出现问题,可以参考3.6节 错误提示
#### 1.2.2 本地构建whl包并安装
#### 1.2.2 通过Python脚本运行PPOCRLabel
如果您对PPOCRLabel文件有所更改(例如指定新的内置模型),通过Python脚本运行会更加方便的看到更改的结果。如果仍然需要通过whl包启动,则需要先卸载当前环境中的whl包,然后参考下节重新编译whl包。
```bash
cd PaddleOCR/PPOCRLabel
python3 setup.py bdist_wheel
pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.com/pypi/simple
cd ./PPOCRLabel # 切换到PPOCRLabel目录
python PPOCRLabel.py --lang ch
```
#### 1.2.3 通过Python脚本运行PPOCRLabel
#### 1.2.3 本地构建whl包并安装
如果您对PPOCRLabel文件有所更改,通过Python脚本运行会更加方面的看到更改的结果
编译与安装新的whl包,其中1.0.2为版本号,可在 `setup.py` 中指定新版本。
```bash
cd ./PPOCRLabel # 切换到PPOCRLabel目录
# 选择标签模式来启动
python PPOCRLabel.py --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
python PPOCRLabel.py --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
cd PaddleOCR/PPOCRLabel
python3 setup.py bdist_wheel
pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.com/pypi/simple
```
## 2. 使用
### 2.1 操作步骤
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
// common args
DECLARE_bool(use_gpu);
DECLARE_bool(use_tensorrt);
DECLARE_int32(gpu_id);
DECLARE_int32(gpu_mem);
DECLARE_int32(cpu_threads);
DECLARE_bool(enable_mkldnn);
DECLARE_string(precision);
DECLARE_bool(benchmark);
DECLARE_string(output);
DECLARE_string(image_dir);
DECLARE_string(type);
// detection related
DECLARE_string(det_model_dir);
DECLARE_int32(max_side_len);
DECLARE_double(det_db_thresh);
DECLARE_double(det_db_box_thresh);
DECLARE_double(det_db_unclip_ratio);
DECLARE_bool(use_dilation);
DECLARE_string(det_db_score_mode);
DECLARE_bool(visualize);
// classification related
DECLARE_bool(use_angle_cls);
DECLARE_string(cls_model_dir);
DECLARE_double(cls_thresh);
DECLARE_int32(cls_batch_num);
// recognition related
DECLARE_string(rec_model_dir);
DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path);
// forward related
DECLARE_bool(det);
DECLARE_bool(rec);
DECLARE_bool(cls);
......@@ -42,7 +42,8 @@ public:
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const double &cls_thresh,
const bool &use_tensorrt, const std::string &precision) {
const bool &use_tensorrt, const std::string &precision,
const int &cls_batch_num) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
......@@ -52,14 +53,17 @@ public:
this->cls_thresh = cls_thresh;
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
this->cls_batch_num_ = cls_batch_num;
LoadModel(model_dir);
}
double cls_thresh = 0.9;
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
cv::Mat Run(cv::Mat &img);
void Run(std::vector<cv::Mat> img_list, std::vector<int> &cls_labels,
std::vector<float> &cls_scores, std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
......@@ -69,17 +73,17 @@ private:
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
double cls_thresh = 0.5;
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
bool is_scale_ = true;
bool use_tensorrt_ = false;
std::string precision_ = "fp32";
int cls_batch_num_ = 1;
// pre-process
ClsResizeImg resize_op_;
Normalize normalize_op_;
Permute permute_op_;
PermuteBatch permute_op_;
}; // class Classifier
......
......@@ -45,8 +45,9 @@ public:
const double &det_db_thresh,
const double &det_db_box_thresh,
const double &det_db_unclip_ratio,
const bool &use_polygon_score, const bool &use_dilation,
const bool &use_tensorrt, const std::string &precision) {
const std::string &det_db_score_mode,
const bool &use_dilation, const bool &use_tensorrt,
const std::string &precision) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
......@@ -58,7 +59,7 @@ public:
this->det_db_thresh_ = det_db_thresh;
this->det_db_box_thresh_ = det_db_box_thresh;
this->det_db_unclip_ratio_ = det_db_unclip_ratio;
this->use_polygon_score_ = use_polygon_score;
this->det_db_score_mode_ = det_db_score_mode;
this->use_dilation_ = use_dilation;
this->use_tensorrt_ = use_tensorrt;
......@@ -72,7 +73,7 @@ public:
// Run predictor
void Run(cv::Mat &img, std::vector<std::vector<std::vector<int>>> &boxes,
std::vector<double> *times);
std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
......@@ -88,7 +89,7 @@ private:
double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.5;
double det_db_unclip_ratio_ = 2.0;
bool use_polygon_score_ = false;
std::string det_db_score_mode_ = "slow";
bool use_dilation_ = false;
bool visualize_ = true;
......
......@@ -30,7 +30,6 @@
#include <numeric>
#include <include/ocr_cls.h>
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
......@@ -68,7 +67,7 @@ public:
void LoadModel(const std::string &model_dir);
void Run(std::vector<cv::Mat> img_list, std::vector<std::string> &rec_texts,
std::vector<float> &rec_text_scores, std::vector<double> *times);
std::vector<float> &rec_text_scores, std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
......@@ -93,9 +92,6 @@ private:
Normalize normalize_op_;
PermuteBatch permute_op_;
// post-process
PostProcessor post_processor_;
}; // class CrnnRecognizer
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/ocr_cls.h>
#include <include/ocr_det.h>
#include <include/ocr_rec.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR {
class PaddleOCR {
public:
explicit PaddleOCR();
~PaddleOCR();
std::vector<std::vector<OCRPredictResult>>
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
bool rec = true, bool cls = true);
private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num);
};
} // namespace PaddleOCR
......@@ -56,7 +56,7 @@ public:
std::vector<std::vector<std::vector<int>>>
BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
const float &box_thresh, const float &det_db_unclip_ratio,
const bool &use_polygon_score);
const std::string &det_db_score_mode);
std::vector<std::vector<std::vector<int>>>
FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
......
......@@ -32,14 +32,21 @@
namespace PaddleOCR {
struct OCRPredictResult {
std::vector<std::vector<int>> box;
std::string text;
float score = -1.0;
float cls_score;
int cls_label = -1;
};
class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
static void
VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes,
const std::string &save_path);
static void VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<OCRPredictResult> &ocr_result,
const std::string &save_path);
template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
......@@ -55,6 +62,10 @@ public:
static std::vector<int> argsort(const std::vector<float> &array);
static std::string basename(const std::string &filename);
static bool PathExists(const std::string &path);
static void print_result(const std::vector<OCRPredictResult> &ocr_result);
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -9,9 +9,12 @@
- [2.1 将模型导出为inference model](#21-将模型导出为inference-model)
- [2.2 编译PaddleOCR C++预测demo](#22-编译paddleocr-c预测demo)
- [2.3 运行demo](#23-运行demo)
- [1. 只调用检测:](#1-只调用检测)
- [2. 只调用识别:](#2-只调用识别)
- [3. 调用串联:](#3-调用串联)
- [1. 检测+分类+识别:](#1-检测分类识别)
- [2. 检测+识别:](#2-检测识别)
- [3. 检测:](#3-检测)
- [4. 分类+识别:](#4-分类识别)
- [5. 识别:](#5-识别)
- [6. 分类:](#6-分类)
- [3. FAQ](#3-faq)
# 服务器端C++预测
......@@ -181,6 +184,9 @@ inference/
|-- rec_rcnn
| |--inference.pdiparams
| |--inference.pdmodel
|-- cls
| |--inference.pdiparams
| |--inference.pdmodel
```
<a name="22"></a>
......@@ -213,36 +219,71 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
运行方式:
```shell
./build/ppocr <mode> [--param1] [--param2] [...]
./build/ppocr [--param1] [--param2] [...]
```
具体命令如下:
##### 1. 检测+分类+识别:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=true \
--det=true \
--rec=true \
--cls=true \
```
##### 2. 检测+识别:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=false \
--det=true \
--rec=true \
--cls=false \
```
##### 3. 检测:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--image_dir=../../doc/imgs/12.jpg \
--det=true \
--rec=false
```
其中,`mode`为必选参数,表示选择的功能,取值范围['det', 'rec', 'system'],分别表示调用检测、识别、检测识别串联(包括方向分类器)。具体命令如下:
##### 1. 只调用检测
##### 4. 分类+识别
```shell
./build/ppocr det \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--image_dir=../../doc/imgs/12.jpg
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--det=false \
--rec=true \
--cls=true \
```
##### 2. 只调用识别:
##### 5. 识别:
```shell
./build/ppocr rec \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs_words/ch/
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=false \
--det=false \
--rec=true \
--cls=false \
```
##### 3. 调用串联:
##### 6. 分类:
```shell
# 不使用方向分类器
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
# 使用方向分类器
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
./build/ppocr --cls_model_dir=inference/cls \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--cls_model_dir=inference/ch_ppocr_mobile_v2.0_cls_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
--det=false \
--rec=false \
--cls=true \
```
更多支持的可调节参数解释如下:
......@@ -258,6 +299,15 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|enable_mkldnn|bool|true|是否使用mkldnn库|
|output|str|./output|可视化结果保存的路径|
- 前向相关
|参数名称|类型|默认参数|意义|
| :---: | :---: | :---: | :---: |
|det|bool|true|前向是否执行文字检测|
|rec|bool|true|前向是否执行文字识别|
|cls|bool|false|前向是否执行文字方向分类|
- 检测模型相关
|参数名称|类型|默认参数|意义|
......@@ -267,7 +317,7 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|det_db_thresh|float|0.3|用于过滤DB预测的二值化图像,设置为0.-0.3对结果影响不明显|
|det_db_box_thresh|float|0.5|DB后处理过滤box的阈值,如果检测存在漏框情况,可酌情减小|
|det_db_unclip_ratio|float|1.6|表示文本框的紧致程度,越小则文本框更靠近文本|
|use_polygon_score|bool|false|是否使用多边形框计算bbox score,false表示使用矩形框计算。矩形框计算速度更快,多边形框对弯曲文本区域计算更准确。|
|det_db_score_mode|string|slow|slow:使用多边形框计算bbox score,fast:使用矩形框计算。矩形框计算速度更快,多边形框对弯曲文本区域计算更准确。|
|visualize|bool|true|是否对结果进行可视化,为1时,预测结果会保存在`output`字段指定的文件夹下和输入图像同名的图像上。|
- 方向分类器相关
......@@ -277,6 +327,7 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|use_angle_cls|bool|false|是否使用方向分类器|
|cls_model_dir|string|-|方向分类器inference model地址|
|cls_thresh|float|0.9|方向分类器的得分阈值|
|cls_batch_num|int|1|方向分类器batchsize|
- 识别模型相关
......@@ -284,15 +335,22 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
| :---: | :---: | :---: | :---: |
|rec_model_dir|string|-|识别模型inference model地址|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|rec_batch_num|int|6|识别模型batchsize|
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
最终屏幕上会输出检测结果如下。
<div align="center">
<img src="./imgs/cpp_infer_pred_12.png" width="600">
</div>
```bash
predict img: ../../doc/imgs/12.jpg
../../doc/imgs/12.jpg
0 det boxes: [[79,553],[399,541],[400,573],[80,585]] rec text: 打浦路252935号 rec score: 0.933757
1 det boxes: [[31,509],[510,488],[511,529],[33,549]] rec text: 绿洲仕格维花园公寓 rec score: 0.951745
2 det boxes: [[181,456],[395,448],[396,480],[182,488]] rec text: 打浦路15号 rec score: 0.91956
3 det boxes: [[43,413],[480,391],[481,428],[45,450]] rec text: 上海斯格威铂尔多大酒店 rec score: 0.915914
The detection visualized image saved in ./output//12.jpg
```
## 3. FAQ
......
......@@ -9,9 +9,12 @@
- [2.1 Export the inference model](#21-export-the-inference-model)
- [2.2 Compile PaddleOCR C++ inference demo](#22-compile-paddleocr-c-inference-demo)
- [Run the demo](#run-the-demo)
- [1. run det demo:](#1-run-det-demo)
- [2. run rec demo:](#2-run-rec-demo)
- [3. run system demo:](#3-run-system-demo)
- [1. det+cls+rec:](#1-detclsrec)
- [2. det+rec:](#2-detrec)
- [3. det](#3-det)
- [4. cls+rec:](#4-clsrec)
- [5. rec](#5-rec)
- [6. cls](#6-cls)
- [3. FAQ](#3-faq)
# Server-side C++ Inference
......@@ -166,6 +169,9 @@ inference/
|-- rec_rcnn
| |--inference.pdiparams
| |--inference.pdmodel
|-- cls
| |--inference.pdiparams
| |--inference.pdmodel
```
......@@ -198,44 +204,72 @@ or the generated Paddle inference library path (`build/paddle_inference_install_
### Run the demo
Execute the built executable file:
```shell
./build/ppocr <mode> [--param1] [--param2] [...]
./build/ppocr [--param1] [--param2] [...]
```
`mode` is a required parameter,and the valid values are
mode value | Model used
-----|------
det | Detection only
rec | Recognition only
system | End-to-end system
Specifically,
##### 1. run det demo:
##### 1. det+cls+rec:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=true \
--det=true \
--rec=true \
--cls=true \
```
##### 2. det+rec:
```shell
./build/ppocr det \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--image_dir=../../doc/imgs/12.jpg
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=false \
--det=true \
--rec=true \
--cls=false \
```
##### 2. run rec demo:
##### 3. det
```shell
./build/ppocr rec \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs_words/ch/
./build/ppocr --det_model_dir=inference/det_db \
--image_dir=../../doc/imgs/12.jpg \
--det=true \
--rec=false
```
##### 3. run system demo:
##### 4. cls+rec:
```shell
# without text direction classifier
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
# with text direction classifier
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--cls_model_dir=inference/ch_ppocr_mobile_v2.0_cls_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
--det=false \
--rec=true \
--cls=true \
```
##### 5. rec
```shell
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=false \
--det=false \
--rec=true \
--cls=false \
```
##### 6. cls
```shell
./build/ppocr --cls_model_dir=inference/cls \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--det=false \
--rec=false \
--cls=true \
```
More parameters are as follows,
......@@ -251,6 +285,16 @@ More parameters are as follows,
|enable_mkldnn|bool|true|Whether to use mkdlnn library|
|output|str|./output|Path where visualization results are saved|
- forward
|parameter|data type|default|meaning|
| :---: | :---: | :---: | :---: |
|det|bool|true|前向是否执行文字检测|
|rec|bool|true|前向是否执行文字识别|
|cls|bool|false|前向是否执行文字方向分类|
- Detection related parameters
|parameter|data type|default|meaning|
......@@ -260,7 +304,7 @@ More parameters are as follows,
|det_db_thresh|float|0.3|Used to filter the binarized image of DB prediction, setting 0.-0.3 has no obvious effect on the result|
|det_db_box_thresh|float|0.5|DB post-processing filter box threshold, if there is a missing box detected, it can be reduced as appropriate|
|det_db_unclip_ratio|float|1.6|Indicates the compactness of the text box, the smaller the value, the closer the text box to the text|
|use_polygon_score|bool|false|Whether to use polygon box to calculate bbox score, false means to use rectangle box to calculate. Use rectangular box to calculate faster, and polygonal box more accurate for curved text area.|
|det_db_score_mode|string|slow| slow: use polygon box to calculate bbox score, fast: use rectangle box to calculate. Use rectangular box to calculate faster, and polygonal box more accurate for curved text area.|
|visualize|bool|true|Whether to visualize the results,when it is set as true, the prediction results will be saved in the folder specified by the `output` field on an image with the same name as the input image.|
- Classifier related parameters
......@@ -270,6 +314,7 @@ More parameters are as follows,
|use_angle_cls|bool|false|Whether to use the direction classifier|
|cls_model_dir|string|-|Address of direction classifier inference model|
|cls_thresh|float|0.9|Score threshold of the direction classifier|
|cls_batch_num|int|1|batch size of classifier|
- Recognition related parameters
......@@ -277,15 +322,22 @@ More parameters are as follows,
| --- | --- | --- | --- |
|rec_model_dir|string|-|Address of recognition inference model|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
|rec_batch_num|int|6|batch size of recognition|
* Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `rec_char_dict_path` and `rec_model_dir`.
The detection results will be shown on the screen, which is as follows.
<div align="center">
<img src="./imgs/cpp_infer_pred_12.png" width="600">
</div>
```bash
predict img: ../../doc/imgs/12.jpg
../../doc/imgs/12.jpg
0 det boxes: [[79,553],[399,541],[400,573],[80,585]] rec text: 打浦路252935号 rec score: 0.933757
1 det boxes: [[31,509],[510,488],[511,529],[33,549]] rec text: 绿洲仕格维花园公寓 rec score: 0.951745
2 det boxes: [[181,456],[395,448],[396,480],[182,488]] rec text: 打浦路15号 rec score: 0.91956
3 det boxes: [[43,413],[480,391],[481,428],[45,450]] rec text: 上海斯格威铂尔多大酒店 rec score: 0.915914
The detection visualized image saved in ./output//12.jpg
```
## 3. FAQ
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
// common args
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU.");
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU.");
DEFINE_int32(cpu_threads, 10, "Num of threads with CPU.");
DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
DEFINE_bool(benchmark, false, "Whether use benchmark.");
DEFINE_string(output, "./output/", "Save benchmark log path.");
DEFINE_string(image_dir, "", "Dir of input image.");
DEFINE_string(
type, "ocr",
"Perform ocr or structure, the value is selected in ['ocr','structure'].");
// detection related
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
DEFINE_bool(use_dilation, false, "Whether use the dilation on output map.");
DEFINE_string(det_db_score_mode, "slow", "Whether use polygon score.");
DEFINE_bool(visualize, true, "Whether show the detection results.");
// classification related
DEFINE_bool(use_angle_cls, false, "Whether use use_angle_cls.");
DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
DEFINE_int32(cls_batch_num, 1, "cls_batch_num.");
// recognition related
DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
"Path of dictionary.");
// ocr forward related
DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
\ No newline at end of file
......@@ -11,273 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "omp.h"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sys/stat.h>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/ocr_cls.h>
#include <include/ocr_det.h>
#include <include/ocr_rec.h>
#include <include/utility.h>
#include <sys/stat.h>
#include "auto_log/autolog.h"
#include <gflags/gflags.h>
#include <include/args.h>
#include <include/paddleocr.h>
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU.");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU.");
DEFINE_int32(cpu_threads, 10, "Num of threads with CPU.");
DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
DEFINE_bool(benchmark, false, "Whether use benchmark.");
DEFINE_string(output, "./output/", "Save benchmark log path.");
// detection related
DEFINE_string(image_dir, "", "Dir of input image.");
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
DEFINE_bool(use_polygon_score, false, "Whether use polygon score.");
DEFINE_bool(use_dilation, false, "Whether use the dilation on output map.");
DEFINE_bool(visualize, true, "Whether show the detection results.");
// classification related
DEFINE_bool(use_angle_cls, false, "Whether use use_angle_cls.");
DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
// recognition related
DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
"Path of dictionary.");
using namespace std;
using namespace cv;
using namespace PaddleOCR;
static bool PathExists(const std::string &path) {
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
#else
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0);
#endif // !_WIN32
}
int main_det(std::vector<cv::String> cv_all_img_names) {
std::vector<double> time_info = {0, 0, 0};
DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_max_side_len, FLAGS_det_db_thresh,
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_use_polygon_score, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
if (!PathExists(FLAGS_output)) {
mkdir(FLAGS_output.c_str(), 0777);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
if (!FLAGS_benchmark) {
cout << "The predict img: " << cv_all_img_names[i] << endl;
}
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
det.Run(srcimg, boxes, &det_times);
// visualization
if (FLAGS_visualize) {
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
}
time_info[0] += det_times[0];
time_info[1] += det_times[1];
time_info[2] += det_times[2];
if (FLAGS_benchmark) {
cout << cv_all_img_names[i] << '\t';
for (int n = 0; n < boxes.size(); n++) {
for (int m = 0; m < boxes[n].size(); m++) {
cout << boxes[n][m][0] << ' ' << boxes[n][m][1] << ' ';
}
}
cout << endl;
}
}
if (FLAGS_benchmark) {
AutoLogger autolog("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, time_info, cv_all_img_names.size());
autolog.report();
}
return 0;
}
int main_rec(std::vector<cv::String> cv_all_img_names) {
std::vector<double> time_info = {0, 0, 0};
std::string rec_char_dict_path = FLAGS_rec_char_dict_path;
if (FLAGS_benchmark)
rec_char_dict_path = FLAGS_rec_char_dict_path.substr(6);
cout << "label file: " << rec_char_dict_path << endl;
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
rec_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision,
FLAGS_rec_batch_num);
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
img_list.push_back(srcimg);
}
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times;
rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
cout << "The predict img: " << cv_all_img_names[i] << "\t" << rec_texts[i]
<< "\t" << rec_text_scores[i] << endl;
}
time_info[0] += rec_times[0];
time_info[1] += rec_times[1];
time_info[2] += rec_times[2];
if (FLAGS_benchmark) {
AutoLogger autolog("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
time_info, cv_all_img_names.size());
autolog.report();
}
return 0;
}
int main_system(std::vector<cv::String> cv_all_img_names) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
if (!PathExists(FLAGS_output)) {
mkdir(FLAGS_output.c_str(), 0777);
}
DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_max_side_len, FLAGS_det_db_thresh,
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_use_polygon_score, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
Classifier *cls = nullptr;
if (FLAGS_use_angle_cls) {
cls = new Classifier(FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_cls_thresh, FLAGS_use_tensorrt, FLAGS_precision);
}
std::string rec_char_dict_path = FLAGS_rec_char_dict_path;
if (FLAGS_benchmark)
rec_char_dict_path = FLAGS_rec_char_dict_path.substr(6);
cout << "label file: " << rec_char_dict_path << endl;
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
rec_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision,
FLAGS_rec_batch_num);
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cout << "The predict img: " << cv_all_img_names[i] << endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
// det
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
std::vector<double> rec_times;
det.Run(srcimg, boxes, &det_times);
if (FLAGS_visualize) {
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
}
time_info_det[0] += det_times[0];
time_info_det[1] += det_times[1];
time_info_det[2] += det_times[2];
// rec
std::vector<cv::Mat> img_list;
for (int j = 0; j < boxes.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]);
if (cls != nullptr) {
crop_img = cls->Run(crop_img);
}
img_list.push_back(crop_img);
}
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
std::cout << i << "\t" << rec_texts[i] << "\t" << rec_text_scores[i]
<< std::endl;
}
time_info_rec[0] += rec_times[0];
time_info_rec[1] += rec_times[1];
time_info_rec[2] += rec_times[2];
}
if (FLAGS_benchmark) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, time_info_det,
cv_all_img_names.size());
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
time_info_rec, cv_all_img_names.size());
autolog_det.report();
std::cout << endl;
autolog_rec.report();
}
return 0;
}
void check_params(char *mode) {
if (strcmp(mode, "det") == 0) {
void check_params() {
if (FLAGS_det) {
if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[det]: ./ppocr "
"--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
......@@ -285,7 +31,7 @@ void check_params(char *mode) {
exit(1);
}
}
if (strcmp(mode, "rec") == 0) {
if (FLAGS_rec) {
if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[rec]: ./ppocr "
"--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
......@@ -293,19 +39,10 @@ void check_params(char *mode) {
exit(1);
}
}
if (strcmp(mode, "system") == 0) {
if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() ||
FLAGS_image_dir.empty()) ||
(FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) {
std::cout << "Usage[system without angle cls]: ./ppocr "
"--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
std::cout << "Usage[system with angle cls]: ./ppocr "
"--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
<< "--use_angle_cls=true "
<< "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ "
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
if (FLAGS_cls && FLAGS_use_angle_cls) {
if (FLAGS_cls_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[cls]: ./ppocr "
<< "--cls_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
exit(1);
}
......@@ -318,19 +55,11 @@ void check_params(char *mode) {
}
int main(int argc, char **argv) {
if (argc <= 1 ||
(strcmp(argv[1], "det") != 0 && strcmp(argv[1], "rec") != 0 &&
strcmp(argv[1], "system") != 0)) {
std::cout << "Please choose one mode of [det, rec, system] !" << std::endl;
return -1;
}
std::cout << "mode: " << argv[1] << endl;
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
check_params(argv[1]);
check_params();
if (!PathExists(FLAGS_image_dir)) {
if (!Utility::PathExists(FLAGS_image_dir)) {
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
<< endl;
exit(1);
......@@ -340,13 +69,37 @@ int main(int argc, char **argv) {
cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl;
if (strcmp(argv[1], "det") == 0) {
return main_det(cv_all_img_names);
}
if (strcmp(argv[1], "rec") == 0) {
return main_rec(cv_all_img_names);
}
if (strcmp(argv[1], "system") == 0) {
return main_system(cv_all_img_names);
PaddleOCR::PaddleOCR ocr = PaddleOCR::PaddleOCR();
std::vector<std::vector<OCRPredictResult>> ocr_results =
ocr.ocr(cv_all_img_names, FLAGS_det, FLAGS_rec, FLAGS_cls);
for (int i = 0; i < cv_all_img_names.size(); ++i) {
if (FLAGS_benchmark) {
cout << cv_all_img_names[i] << '\t';
for (int n = 0; n < ocr_results[i].size(); n++) {
for (int m = 0; m < ocr_results[i][n].box.size(); m++) {
cout << ocr_results[i][n].box[m][0] << ' '
<< ocr_results[i][n].box[m][1] << ' ';
}
}
cout << endl;
} else {
cout << cv_all_img_names[i] << "\n";
Utility::print_result(ocr_results[i]);
if (FLAGS_visualize && FLAGS_det) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, ocr_results[i],
FLAGS_output + "/" + file_name);
}
cout << "***************************" << endl;
}
}
}
......@@ -16,57 +16,84 @@
namespace PaddleOCR {
cv::Mat Classifier::Run(cv::Mat &img) {
cv::Mat src_img;
img.copyTo(src_img);
cv::Mat resize_img;
void Classifier::Run(std::vector<cv::Mat> img_list,
std::vector<int> &cls_labels,
std::vector<float> &cls_scores,
std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> inference_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> postprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
int img_num = img_list.size();
std::vector<int> cls_image_shape = {3, 48, 192};
int index = 0;
float wh_ratio = float(img.cols) / float(img.rows);
this->resize_op_.Run(img, resize_img, this->use_tensorrt_, cls_image_shape);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
// Inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
std::vector<float> softmax_out;
std::vector<int64_t> label_out;
auto output_names = this->predictor_->GetOutputNames();
auto softmax_out_t = this->predictor_->GetOutputHandle(output_names[0]);
auto softmax_shape_out = softmax_out_t->shape();
int softmax_out_num =
std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
std::multiplies<int>());
softmax_out.resize(softmax_out_num);
softmax_out_t->CopyToCpu(softmax_out.data());
float score = 0;
int label = 0;
for (int i = 0; i < softmax_out_num; i++) {
if (softmax_out[i] > score) {
score = softmax_out[i];
label = i;
for (int beg_img_no = 0; beg_img_no < img_num;
beg_img_no += this->cls_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->cls_batch_num_);
int batch_num = end_img_no - beg_img_no;
// preprocess
std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
cv::Mat srcimg;
img_list[ino].copyTo(srcimg);
cv::Mat resize_img;
this->resize_op_.Run(srcimg, resize_img, this->use_tensorrt_,
cls_image_shape);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
norm_img_batch.push_back(resize_img);
}
std::vector<float> input(batch_num * cls_image_shape[0] *
cls_image_shape[1] * cls_image_shape[2],
0.0f);
this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({batch_num, cls_image_shape[0], cls_image_shape[1],
cls_image_shape[2]});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
std::vector<float> predict_batch;
auto output_names = this->predictor_->GetOutputNames();
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
auto predict_shape = output_t->shape();
int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1,
std::multiplies<int>());
predict_batch.resize(out_num);
output_t->CopyToCpu(predict_batch.data());
auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// postprocess
auto postprocess_start = std::chrono::steady_clock::now();
for (int batch_idx = 0; batch_idx < predict_shape[0]; batch_idx++) {
int label = int(
Utility::argmax(&predict_batch[batch_idx * predict_shape[1]],
&predict_batch[(batch_idx + 1) * predict_shape[1]]));
float score = float(*std::max_element(
&predict_batch[batch_idx * predict_shape[1]],
&predict_batch[(batch_idx + 1) * predict_shape[1]]));
cls_labels[beg_img_no + batch_idx] = label;
cls_scores[beg_img_no + batch_idx] = score;
}
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
}
if (label % 2 == 1 && score > this->cls_thresh) {
cv::rotate(src_img, src_img, 1);
}
return src_img;
times.push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
void Classifier::LoadModel(const std::string &model_dir) {
......@@ -81,13 +108,10 @@ void Classifier::LoadModel(const std::string &model_dir) {
if (this->precision_ == "fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
if (this->precision_ == "int8") {
if (this->precision_ == "int8") {
precision = paddle_infer::Config::Precision::kInt8;
}
config.EnableTensorRtEngine(
1 << 20, 10, 3,
precision,
false, false);
}
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
}
} else {
config.DisableGpu();
......
......@@ -94,7 +94,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
void DBDetector::Run(cv::Mat &img,
std::vector<std::vector<std::vector<int>>> &boxes,
std::vector<double> *times) {
std::vector<double> &times) {
float ratio_h{};
float ratio_w{};
......@@ -161,20 +161,19 @@ void DBDetector::Run(cv::Mat &img,
boxes = post_processor_.BoxesFromBitmap(
pred_map, bit_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_,
this->use_polygon_score_);
this->det_db_score_mode_);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
auto postprocess_end = std::chrono::steady_clock::now();
std::cout << "Detected boxes num: " << boxes.size() << endl;
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
times->push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = inference_end - inference_start;
times->push_back(double(inference_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
} // namespace PaddleOCR
......@@ -19,7 +19,7 @@ namespace PaddleOCR {
void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
std::vector<std::string> &rec_texts,
std::vector<float> &rec_text_scores,
std::vector<double> *times) {
std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> inference_diff =
......@@ -38,6 +38,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
beg_img_no += this->rec_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
int batch_num = end_img_no - beg_img_no;
float max_wh_ratio = 0;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
int h = img_list[indices[ino]].rows;
......@@ -45,6 +46,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
float wh_ratio = w * 1.0 / h;
max_wh_ratio = max(max_wh_ratio, wh_ratio);
}
int batch_width = 0;
std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
......@@ -59,15 +61,14 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
batch_width = max(resize_img.cols, batch_width);
}
std::vector<float> input(this->rec_batch_num_ * 3 * 32 * batch_width, 0.0f);
std::vector<float> input(batch_num * 3 * 32 * batch_width, 0.0f);
this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// Inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({this->rec_batch_num_, 3, 32, batch_width});
input_t->Reshape({batch_num, 3, 32, batch_width});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
......@@ -84,7 +85,6 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
output_t->CopyToCpu(predict_batch.data());
auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// ctc decode
auto postprocess_start = std::chrono::steady_clock::now();
for (int m = 0; m < predict_shape[0]; m++) {
......@@ -120,9 +120,9 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
}
times->push_back(double(preprocess_diff.count() * 1000));
times->push_back(double(inference_diff.count() * 1000));
times->push_back(double(postprocess_diff.count() * 1000));
times.push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
void CRNNRecognizer::LoadModel(const std::string &model_dir) {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/args.h>
#include <include/paddleocr.h>
#include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>
namespace PaddleOCR {
PaddleOCR::PaddleOCR() {
if (FLAGS_det) {
this->detector_ = new DBDetector(
FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len,
FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt,
FLAGS_precision);
}
if (FLAGS_cls && FLAGS_use_angle_cls) {
this->classifier_ = new Classifier(
FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_cls_thresh,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_cls_batch_num);
}
if (FLAGS_rec) {
this->recognizer_ = new CRNNRecognizer(
FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_rec_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num);
}
};
void PaddleOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
this->detector_->Run(img, boxes, det_times);
for (int i = 0; i < boxes.size(); i++) {
OCRPredictResult res;
res.box = boxes[i];
ocr_results.push_back(res);
}
times[0] += det_times[0];
times[1] += det_times[1];
times[2] += det_times[2];
}
void PaddleOCR::rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times;
this->recognizer_->Run(img_list, rec_texts, rec_text_scores, rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
ocr_results[i].text = rec_texts[i];
ocr_results[i].score = rec_text_scores[i];
}
times[0] += rec_times[0];
times[1] += rec_times[1];
times[2] += rec_times[2];
}
void PaddleOCR::cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<int> cls_labels(img_list.size(), 0);
std::vector<float> cls_scores(img_list.size(), 0);
std::vector<double> cls_times;
this->classifier_->Run(img_list, cls_labels, cls_scores, cls_times);
// output cls results
for (int i = 0; i < cls_labels.size(); i++) {
ocr_results[i].cls_label = cls_labels[i];
ocr_results[i].cls_score = cls_scores[i];
}
times[0] += cls_times[0];
times[1] += cls_times[1];
times[2] += cls_times[2];
}
std::vector<std::vector<OCRPredictResult>>
PaddleOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
bool cls) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
std::vector<std::vector<OCRPredictResult>> ocr_results;
if (!det) {
std::vector<OCRPredictResult> ocr_result;
// read image
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
img_list.push_back(srcimg);
OCRPredictResult res;
ocr_result.push_back(res);
}
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result_tmp;
ocr_result_tmp.push_back(ocr_result[i]);
ocr_results.push_back(ocr_result_tmp);
}
} else {
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
mkdir(FLAGS_output.c_str(), 0777);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result;
if (!FLAGS_benchmark) {
cout << "predict img: " << cv_all_img_names[i] << endl;
}
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
// det
this->det(srcimg, ocr_result, time_info_det);
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < ocr_result.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(srcimg, ocr_result[j].box);
img_list.push_back(crop_img);
}
// cls
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
// rec
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
ocr_results.push_back(ocr_result);
}
}
if (FLAGS_benchmark) {
this->log(time_info_det, time_info_rec, time_info_cls,
cv_all_img_names.size());
}
return ocr_results;
} // namespace PaddleOCR
void PaddleOCR::log(std::vector<double> &det_times,
std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num) {
if (det_times[0] + det_times[1] + det_times[2] > 0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, det_times, img_num);
autolog_det.report();
}
if (rec_times[0] + rec_times[1] + rec_times[2] > 0) {
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
rec_times, img_num);
autolog_rec.report();
}
if (cls_times[0] + cls_times[1] + cls_times[2] > 0) {
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
cls_times, img_num);
autolog_cls.report();
}
}
PaddleOCR::~PaddleOCR() {
if (this->detector_ != nullptr) {
delete this->detector_;
}
if (this->classifier_ != nullptr) {
delete this->classifier_;
}
if (this->recognizer_ != nullptr) {
delete this->recognizer_;
}
};
} // namespace PaddleOCR
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/clipper.h>
#include <include/postprocess_op.h>
#include <include/clipper.cpp>
namespace PaddleOCR {
......@@ -187,23 +187,22 @@ float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point *rook_point = new cv::Point[contour.size()];
cv::Point* rook_point = new cv::Point[contour.size()];
for (int i = 0; i < contour.size(); ++i) {
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
}
const cv::Point *ppt[1] = {rook_point};
int npt[] = {int(contour.size())};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1)).copyTo(croppedImg);
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
.copyTo(croppedImg);
float score = cv::mean(croppedImg, mask)[0];
delete []rook_point;
delete[] rook_point;
return score;
}
......@@ -247,7 +246,7 @@ float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
const float &det_db_unclip_ratio, const bool &use_polygon_score) {
const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
const int min_size = 3;
const int max_candidates = 1000;
......@@ -281,7 +280,7 @@ std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
}
float score;
if (use_polygon_score)
if (det_db_score_mode == "slow")
/* compute using polygon*/
score = PolygonScoreAcc(contours[_i], pred);
else
......
......@@ -38,16 +38,16 @@ std::vector<std::string> Utility::ReadDict(const std::string &path) {
return m_vec;
}
void Utility::VisualizeBboxes(
const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes,
const std::string &save_path) {
void Utility::VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<OCRPredictResult> &ocr_result,
const std::string &save_path) {
cv::Mat img_vis;
srcimg.copyTo(img_vis);
for (int n = 0; n < boxes.size(); n++) {
for (int n = 0; n < ocr_result.size(); n++) {
cv::Point rook_points[4];
for (int m = 0; m < boxes[n].size(); m++) {
rook_points[m] = cv::Point(int(boxes[n][m][0]), int(boxes[n][m][1]));
for (int m = 0; m < ocr_result[n].box.size(); m++) {
rook_points[m] =
cv::Point(int(ocr_result[n].box[m][0]), int(ocr_result[n].box[m][1]));
}
const cv::Point *ppt[1] = {rook_points};
......@@ -196,4 +196,43 @@ std::string Utility::basename(const std::string &filename) {
return filename.substr(index + 1, len - index);
}
bool Utility::PathExists(const std::string &path) {
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
#else
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0);
#endif // !_WIN32
}
void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
for (int i = 0; i < ocr_result.size(); i++) {
std::cout << i << "\t";
// det
std::vector<std::vector<int>> boxes = ocr_result[i].box;
if (boxes.size() > 0) {
std::cout << "det boxes: [";
for (int n = 0; n < boxes.size(); n++) {
std::cout << '[' << boxes[n][0] << ',' << boxes[n][1] << "]";
if (n != boxes.size() - 1) {
std::cout << ',';
}
}
std::cout << "] ";
}
// rec
if (ocr_result[i].score != -1.0) {
std::cout << "rec text: " << ocr_result[i].text
<< " rec score: " << ocr_result[i].score << " ";
}
// cls
if (ocr_result[i].cls_label != -1) {
std::cout << "cls label: " << ocr_result[i].cls_label
<< " cls score: " << ocr_result[i].cls_score;
}
std::cout << std::endl;
}
}
} // namespace PaddleOCR
\ No newline at end of file
......@@ -29,8 +29,7 @@ def read_params():
cfg.rec_model_dir = "./inference/ch_PP-OCRv2_rec_infer/"
cfg.rec_image_shape = "3, 32, 320"
cfg.rec_char_type = 'ch'
cfg.rec_batch_num = 30
cfg.rec_batch_num = 6
cfg.max_text_length = 25
cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
......
......@@ -47,8 +47,7 @@ def read_params():
cfg.rec_model_dir = "./inference/ch_PP-OCRv2_rec_infer/"
cfg.rec_image_shape = "3, 32, 320"
cfg.rec_char_type = 'ch'
cfg.rec_batch_num = 30
cfg.rec_batch_num = 6
cfg.max_text_length = 25
cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
......
......@@ -188,7 +188,7 @@ hub serving start -c deploy/hubserving/ocr_system/config.json
- **output**:可视化结果保存路径,默认为`./hubserving_result`
访问示例:
```python tools/test_hubserving.py --server_url=http://127.0.0.1:8868/predict/ocr_system --image_dir./doc/imgs/ --visualize=false```
```python tools/test_hubserving.py --server_url=http://127.0.0.1:8868/predict/ocr_system --image_dir=./doc/imgs/ --visualize=false```
## 4. 返回结果格式说明
返回结果为列表(list),列表中的每一项为词典(dict),词典一共可能包含3种字段,信息如下:
......
......@@ -196,7 +196,7 @@ For example, if using the configuration file to start the text angle classificat
**Eg.**
```shell
python tools/test_hubserving.py --server_url=http://127.0.0.1:8868/predict/ocr_system --image_dir./doc/imgs/ --visualize=false`
python tools/test_hubserving.py --server_url=http://127.0.0.1:8868/predict/ocr_system --image_dir=./doc/imgs/ --visualize=false`
```
## 4. Returned result format
......
......@@ -25,7 +25,6 @@ def read_params():
# params for table structure model
cfg.table_max_len = 488
cfg.table_model_dir = './inference/en_ppocr_mobile_v2.0_table_structure_infer/'
cfg.table_char_type = 'en'
cfg.table_char_dict_path = './ppocr/utils/dict/table_structure_dict.txt'
cfg.show_log = False
return cfg
......@@ -42,7 +42,7 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
# 比如下载提供的训练模型
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
tar -xf ch_ppocr_mobile_v2.0_det_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_inference_model
python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model
```
如果要训练识别模型的量化,修改配置文件和加载的模型参数即可。
......
......@@ -127,11 +127,13 @@ def main():
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
for idx, name in enumerate(model.model_name_list):
model.model_list[idx].eval()
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(quanter, model.model_list[idx], infer_shape,
sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
model.eval()
export_single_model(quanter, model, infer_shape, save_path, logger)
......
......@@ -3,12 +3,13 @@
本文介绍针对PP-OCR模型库的Python推理引擎使用方法,内容依次为文本检测、文本识别、方向分类器以及三者串联在CPU、GPU上的预测方法。
- [1. 文本检测模型推理](#文本检测模型推理)
- [2. 文本识别模型推理](#文本识别模型推理)
- [2.1 超轻量中文识别模型推理](#超轻量中文识别模型推理)
- [2.2 多语言模型的推理](#多语言模型的推理)
- [3. 方向分类模型推理](#方向分类模型推理)
- [4. 文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
- [基于Python引擎的PP-OCR模型库推理](#基于python引擎的pp-ocr模型库推理)
- [1. 文本检测模型推理](#1-文本检测模型推理)
- [2. 文本识别模型推理](#2-文本识别模型推理)
- [2.1 超轻量中文识别模型推理](#21-超轻量中文识别模型推理)
- [2.2 多语言模型的推理](#22-多语言模型的推理)
- [3. 方向分类模型推理](#3-方向分类模型推理)
- [4. 文本检测、方向分类和文字识别串联推理](#4-文本检测方向分类和文字识别串联推理)
<a name="文本检测模型推理"></a>
......@@ -82,7 +83,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.98458153)
如果您需要预测的是其他语言模型,可以在[此链接](./models_list.md#%E5%A4%9A%E8%AF%AD%E8%A8%80%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B)中找到对应语言的inference模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
......
......@@ -13,6 +13,7 @@
- [2.2.1 中英文与多语言使用](#221)
- [3.小结](#3)
<a name="1"></a>
## 1. 安装
......
......@@ -2,19 +2,20 @@
本文提供了PaddleOCR文本识别任务的全流程指南,包括数据准备、模型训练、调优、评估、预测,各个阶段的详细说明:
- [1 数据准备](#数据准备)
- [1.1 自定义数据集](#自定义数据集)
- [1.2 数据下载](#数据下载)
- [1.3 字典](#字典)
- [1.4 支持空格](#支持空格)
- [2 启动训练](#启动训练)
- [2.1 数据增强](#数据增强)
- [2.2 通用模型训练](#通用模型训练)
- [2.3 多语言模型训练](#多语言模型训练)
- [2.4 知识蒸馏训练](#知识蒸馏训练)
- [3 评估](#评估)
- [4 预测](#预测)
- [5 转Inference模型测试](#Inference)
- [文字识别](#文字识别)
- [1. 数据准备](#1-数据准备)
- [1.1 自定义数据集](#11-自定义数据集)
- [1.2 数据下载](#12-数据下载)
- [1.3 字典](#13-字典)
- [1.4 添加空格类别](#14-添加空格类别)
- [2. 启动训练](#2-启动训练)
- [2.1 数据增强](#21-数据增强)
- [2.2 通用模型训练](#22-通用模型训练)
- [2.3 多语言模型训练](#23-多语言模型训练)
- [2.4 知识蒸馏训练](#24-知识蒸馏训练)
- [3 评估](#3-评估)
- [4 预测](#4-预测)
- [5. 转Inference模型测试](#5-转inference模型测试)
<a name="数据准备"></a>
......@@ -477,8 +478,8 @@ python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_trai
- 自定义模型推理
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch`
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
......@@ -98,7 +98,6 @@ def read_params():
cfg.rec_model_dir = "./ocr_rec_server/" # 识别算法模型路径
cfg.rec_image_shape = "3, 32, 320"
cfg.rec_char_type = 'ch'
cfg.rec_batch_num = 30
cfg.max_text_length = 25
......
......@@ -401,7 +401,6 @@ im_show.save('result.jpg')
| rec_algorithm | 使用的识别算法类型 | CRNN |
| rec_model_dir | 识别模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/rec`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None |
| rec_image_shape | 识别算法的输入图片尺寸 | "3,32,320" |
| rec_char_type | 识别算法的字符类型,中英文(ch)、英文(en)、法语(french)、德语(german)、韩语(korean)、日语(japan) | ch |
| rec_batch_num | 进行识别时,同时前向的图片数 | 30 |
| max_text_length | 识别算法能识别的最大文字长度 | 25 |
| rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt |
......
......@@ -296,7 +296,7 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
- The image resolution used in training is different: the image resolution used in training the above model is [3,32,100], while during our Chinese model training, in order to ensure the recognition effect of long text, the image resolution used in training is [3, 32, 320]. The default shape parameter of the inference stage is the image resolution used in training phase, that is [3, 32, 320]. Therefore, when running inference of the above English model here, you need to set the shape of the recognition image through the parameter `rec_image_shape`.
- Character list: the experiment in the DTRB paper is only for 26 lowercase English characters and 10 numbers, a total of 36 characters. All upper and lower case characters are converted to lower case characters, and characters not in the above list are ignored and considered as spaces. Therefore, no characters dictionary file is used here, but a dictionary is generated by the below command. Therefore, the parameter `rec_char_type` needs to be set during inference, which is specified as "en" in English.
- Character list: the experiment in the DTRB paper is only for 26 lowercase English characters and 10 numbers, a total of 36 characters. All upper and lower case characters are converted to lower case characters, and characters not in the above list are ignored and considered as spaces. Therefore, no characters dictionary file is used here, but a dictionary is generated by the below command.
```
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
......@@ -320,7 +320,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
<a name="USING_CUSTOM_CHARACTERS"></a>
### 3.4 Text Recognition Model Inference Using Custom Characters Dictionary
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
......
......@@ -4,12 +4,13 @@
This article introduces the use of the Python inference engine for the PP-OCR model library. The content is in order of text detection, text recognition, direction classifier and the prediction method of the three in series on the CPU and GPU.
- [Text Detection Model Inference](#DETECTION_MODEL_INFERENCE)
- [Text Recognition Model Inference](#RECOGNITION_MODEL_INFERENCE)
- [1. Lightweight Chinese Recognition Model Inference](#LIGHTWEIGHT_RECOGNITION)
- [2. Multilingual Model Inference](#MULTILINGUAL_MODEL_INFERENCE)
- [Angle Classification Model Inference](#ANGLE_CLASS_MODEL_INFERENCE)
- [Text Detection Angle Classification and Recognition Inference Concatenation](#CONCATENATION)
- [Python Inference for PP-OCR Model Zoo](#python-inference-for-pp-ocr-model-zoo)
- [Text Detection Model Inference](#text-detection-model-inference)
- [Text Recognition Model Inference](#text-recognition-model-inference)
- [1. Lightweight Chinese Recognition Model Inference](#1-lightweight-chinese-recognition-model-inference)
- [2. Multilingual Model Inference](#2-multilingual-model-inference)
- [Angle Classification Model Inference](#angle-classification-model-inference)
- [Text Detection Angle Classification and Recognition Inference Concatenation](#text-detection-angle-classification-and-recognition-inference-concatenation)
<a name="DETECTION_MODEL_INFERENCE"></a>
......@@ -82,7 +83,7 @@ You need to specify the visual font path through `--vis_font_path`. There are sm
```
wget wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
......
- [PaddleOCR Quick Start](#paddleocr-quick-start)
- [1. Installation](#1-installation)
- [1.1 Install PaddlePaddle](#11-install-paddlepaddle)
- [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package)
- [2. Easy-to-Use](#2-easy-to-use)
- [2.1 Use by Command Line](#21-use-by-command-line)
- [2.1.1 Chinese and English Model](#211-chinese-and-english-model)
- [2.1.2 Multi-language Model](#212-multi-language-model)
- [2.1.3 Layout Analysis](#213-layout-analysis)
- [2.2 Use by Code](#22-use-by-code)
- [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese--english-model-and-multilingual-model)
- [2.2.2 Layout Analysis](#222-layout-analysis)
- [3. Summary](#3-summary)
# PaddleOCR Quick Start
+ [1. Installation](#1installation)
+ [1.1 Install PaddlePaddle](#11-install-paddlepaddle)
+ [1.2 Install PaddleOCR Whl Package](#12-install-paddleocr-whl-package)
* [2. Easy-to-Use](#2-easy-to-use)
+ [2.1 Use by Command Line](#21-use-by-command-line)
- [2.1.1 English and Chinese Model](#211-english-and-chinese-model)
- [2.1.2 Multi-language Model](#212-multi-language-model)
- [2.1.3 Layout Analysis](#213-layoutAnalysis)
+ [2.2 Use by Code](#22-use-by-code)
- [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese---english-model-and-multilingual-model)
- [2.2.2 Layout Analysis](#222-layoutAnalysis)
* [3. Summary](#3)
<a name="1nstallation"></a>
......@@ -196,7 +197,7 @@ paddleocr --image_dir=../doc/table/1.png --type=structure
| output | The path where excel and recognition results are saved | ./output/table |
| table_max_len | The long side of the image is resized in table structure model | 488 |
| table_model_dir | inference model path of table structure model | None |
| table_char_type | dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
| table_char_dict_path | dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
<a name="22-use-by-code"></a>
......
......@@ -470,8 +470,8 @@ inference/det_db/
- Text recognition model Inference using custom characters dictionary
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
......@@ -348,7 +348,6 @@ im_show.save('result.jpg')
| rec_algorithm | Type of recognition algorithm selected | CRNN |
| rec_model_dir | the text recognition inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/rec`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None |
| rec_image_shape | image shape of recognition algorithm | "3,32,320" |
| rec_char_type | Character type of recognition algorithm, Chinese (ch) or English (en) | ch |
| rec_batch_num | When performing recognition, the batchsize of forward images | 30 |
| max_text_length | The maximum text length that the recognition algorithm can recognize | 25 |
| rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt |
......
doc/joinus.PNG

201.0 KB | W: | H:

doc/joinus.PNG

201.3 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
......@@ -63,7 +63,7 @@ class FCEHead(nn.Layer):
weight_attr=ParamAttr(
name='cls_weights',
initializer=Normal(
mean=paddle.to_tensor(0.), std=paddle.to_tensor(0.01))),
mean=0., std=0.01)),
bias_attr=True)
self.out_conv_reg = nn.Conv2D(
in_channels=self.in_channels,
......@@ -75,7 +75,7 @@ class FCEHead(nn.Layer):
weight_attr=ParamAttr(
name='reg_weights',
initializer=Normal(
mean=paddle.to_tensor(0.), std=paddle.to_tensor(0.01))),
mean=0., std=0.01)),
bias_attr=True)
def forward(self, feats, targets=None):
......
......@@ -80,7 +80,6 @@ class CTCHead(nn.Layer):
result = (x, predicts)
else:
result = predicts
if not self.training:
predicts = F.softmax(predicts, axis=2)
result = predicts
......
......@@ -89,7 +89,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
......
......@@ -117,7 +117,7 @@ teds: 93.32
```python
cd PaddleOCR/ppstructure
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, the excel sheet of each picture will be saved in the directory specified by the output field
......
......@@ -128,7 +128,7 @@ teds: 93.32
```python
cd PaddleOCR/ppstructure
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
# Reference
......
......@@ -58,7 +58,6 @@ class TableStructurer(object):
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_type": args.table_char_type,
"character_dict_path": args.table_char_dict_path,
}
......@@ -104,7 +103,9 @@ class TableStructurer(object):
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + ['</table>', '</body>', '</html>']
structure_str_list = [
'<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse
......
......@@ -26,7 +26,6 @@ def init_args():
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--table_char_type", type=str, default='en')
parser.add_argument(
"--table_char_dict_path",
type=str,
......
# 文档视觉问答(DocVQA)
English | [简体中文](README_ch.md)
- [1. 简介](#1)
- [2. 性能](#2)
- [3. 效果演示](#3)
- [3.1 SER](#31)
- [3.2 RE](#32)
- [4. 安装](#4)
- [4.1 安装依赖](#41)
- [4.2 安装PaddleOCR](#42)
- [5. 使用](#5)
- [5.1 数据和预训练模型准备](#51)
- [5.2 SER](#52)
- [5.3 RE](#53)
- [6. 参考链接](#6)
- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering)
- [1. Introduction](#1-Introduction)
- [2. Performance](#2-performance)
- [3. Effect demo](#3-Effect-demo)
- [3.1 SER](#31-ser)
- [3.2 RE](#32-re)
- [4. Install](#4-Install)
- [4.1 Installation dependencies](#41-Install-dependencies)
- [4.2 Install PaddleOCR](#42-Install-PaddleOCR)
- [5. Usage](#5-Usage)
- [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
- [5.2 SER](#52-ser)
- [5.3 RE](#53-re)
- [6. Reference](#6-Reference-Links)
# Document Visual Question Answering
## 1 Introduction
<a name="1"></a>
## 1. 简介
VQA refers to visual question answering, which mainly asks and answers image content. DOC-VQA is one of the VQA tasks. DOC-VQA mainly asks questions about the text content of text images.
VQA指视觉问答,主要针对图像内容进行提问和回答,DocVQA是VQA任务中的一种,DocVQA主要针对文本图像的文字内容提出问题。
The DOC-VQA algorithm in PP-Structure is developed based on the PaddleNLP natural language processing algorithm library.
PP-Structure 里的DocVQA算法基于PaddleNLP自然语言处理算法库进行开发。
The main features are as follows:
主要特性如下:
- Integrate [LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf) model and PP-OCR prediction engine.
- Supports Semantic Entity Recognition (SER) and Relation Extraction (RE) tasks based on multimodal methods. Based on the SER task, the text recognition and classification in the image can be completed; based on the RE task, the relationship extraction of the text content in the image can be completed, such as judging the problem pair (pair).
- Supports custom training for SER tasks and RE tasks.
- Supports end-to-end system prediction and evaluation of OCR+SER.
- Supports end-to-end system prediction of OCR+SER+RE.
- 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。
- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图象中的文本内容的关系提取,如判断问题对(pair)。
- 支持SER任务和RE任务的自定义训练。
- 支持OCR+SER的端到端系统预测与评估。
- 支持OCR+SER+RE的端到端系统预测。
This project is an open source implementation of [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) on Paddle 2.2,
Included fine-tuning code on [XFUND dataset](https://github.com/doc-analysis/XFUND).
本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现,
包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。
## 2. Performance
<a name="2"></a>
## 2. 性能
We evaluate the algorithm on the Chinese dataset of [XFUND](https://github.com/doc-analysis/XFUND), and the performance is as follows
我们在 [XFUN](https://github.com/doc-analysis/XFUND) 的中文数据集上对算法进行了评估,性能如下
| 模型 | 任务 | hmean | 模型下载地址 |
| Model | Task | hmean | Model download address |
|:---:|:---:|:---:| :---:|
| LayoutXLM | SER | 0.9038 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
| LayoutXLM | RE | 0.7483 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
| LayoutLMv2 | SER | 0.8544 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar)
| LayoutLMv2 | RE | 0.6777 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
| LayoutLM | SER | 0.7731 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
| LayoutXLM | SER | 0.9038 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
| LayoutXLM | RE | 0.7483 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
| LayoutLMv2 | SER | 0.8544 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar)
| LayoutLMv2 | RE | 0.6777 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
| LayoutLM | SER | 0.7731 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="3"></a>
## 3. 效果演示
## 3. Effect demo
**注意:** 测试图片来源于XFUN数据集。
**Note:** The test images are from the XFUND dataset.
<a name="31"></a>
### 3.1 SER
......@@ -59,13 +57,13 @@ PP-Structure 里的DocVQA算法基于PaddleNLP自然语言处理算法库进行
![](../docs/vqa/result_ser/zh_val_0_ser.jpg) | ![](../docs/vqa/result_ser/zh_val_42_ser.jpg)
---|---
图中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别
Boxes with different colors in the figure represent different categories. For the XFUND dataset, there are 3 categories: `QUESTION`, `ANSWER`, `HEADER`
* 深紫色:HEADER
* 浅紫色:QUESTION
* 军绿色:ANSWER
* Dark purple: HEADER
* Light purple: QUESTION
* Army Green: ANSWER
在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
The corresponding categories and OCR recognition results are also marked on the upper left of the OCR detection frame.
<a name="32"></a>
### 3.2 RE
......@@ -74,183 +72,187 @@ PP-Structure 里的DocVQA算法基于PaddleNLP自然语言处理算法库进行
---|---
图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
The red box in the figure represents the question, the blue box represents the answer, and the question and the answer are connected by a green line. The corresponding categories and OCR recognition results are also marked on the upper left of the OCR detection frame.
<a name="4"></a>
## 4. 安装
## 4. Install
<a name="41"></a>
### 4.1 安装依赖
### 4.1 Install dependencies
- **(1) 安装PaddlePaddle**
- **(1) Install PaddlePaddle**
```bash
python3 -m pip install --upgrade pip
# GPU安装
# GPU installation
python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
# CPU安装
# CPU installation
python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
````
For more requirements, please refer to the instructions in [Installation Documentation](https://www.paddlepaddle.org.cn/install/quick).
<a name="42"></a>
### 4.2 安装PaddleOCR
### 4.2 Install PaddleOCR
- **(1)pip快速安装PaddleOCR whl包(仅预测)**
- **(1) pip install PaddleOCR whl package quickly (prediction only)**
```bash
python3 -m pip install paddleocr
```
````
- **(2)下载VQA源码(预测+训练)**
- **(2) Download VQA source code (prediction + training)**
```bash
【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
[Recommended] git clone https://github.com/PaddlePaddle/PaddleOCR
# 如果因为网络问题无法pull成功,也可选择使用码云上的托管:
# If the pull cannot be successful due to network problems, you can also choose to use the hosting on the code cloud:
git clone https://gitee.com/paddlepaddle/PaddleOCR
# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
```
# Note: Code cloud hosting code may not be able to synchronize the update of this github project in real time, there is a delay of 3 to 5 days, please use the recommended method first.
````
- **(3)安装VQA的`requirements`**
- **(3) Install VQA's `requirements`**
```bash
python3 -m pip install -r ppstructure/vqa/requirements.txt
```
````
<a name="5"></a>
## 5. 使用
## 5. Usage
<a name="51"></a>
### 5.1 数据和预训练模型准备
### 5.1 Data and Model Preparation
如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
If you want to experience the prediction process directly, you can download the pre-training model provided by us, skip the training process, and just predict directly.
* 下载处理好的数据集
* Download the processed dataset
处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
下载并解压该数据集,解压后将数据集放置在当前目录下。
Download and unzip the dataset, and place the dataset in the current directory after unzipping.
```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```
````
* Convert the dataset
* 转换数据集
If you need to train other XFUND datasets, you can use the following commands to convert the datasets
若需进行其他XFUN数据集的训练,可使用下面的命令进行数据集的转换
```bash
python3 ppstructure/vqa/tools/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
````
* Download the pretrained models
```bash
python3 ppstructure/vqa/helper/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
```
mkdir pretrain && cd pretrain
#download the SER model
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar && tar -xvf ser_LayoutXLM_xfun_zh.tar
#download the RE model
wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar && tar -xvf re_LayoutXLM_xfun_zh.tar
cd ../
````
<a name="52"></a>
### 5.2 SER
启动训练之前,需要修改下面的四个字段
Before starting training, you need to modify the following four fields
1. `Train.dataset.data_dir`:指向训练集图片存放目录
2. `Train.dataset.label_file_list`:指向训练集标注文件
3. `Eval.dataset.data_dir`:指指向验证集图片存放目录
4. `Eval.dataset.label_file_list`:指向验证集标注文件
1. `Train.dataset.data_dir`: point to the directory where the training set images are stored
2. `Train.dataset.label_file_list`: point to the training set label file
3. `Eval.dataset.data_dir`: refers to the directory where the validation set images are stored
4. `Eval.dataset.label_file_list`: point to the validation set label file
* 启动训练
* start training
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
```
````
最终会打印出`precision`, `recall`, `hmean`等指标。
`./output/ser_layoutxlm/`文件夹中会保存训练日志,最优的模型和最新epoch的模型。
Finally, `precision`, `recall`, `hmean` and other indicators will be printed.
In the `./output/ser_layoutxlm/` folder will save the training log, the optimal model and the model for the latest epoch.
* 恢复训练
* resume training
恢复训练需要将之前训练好的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
To resume training, assign the folder path of the previously trained model to the `Architecture.Backbone.checkpoints` field.
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
````
* 评估
* evaluate
评估需要将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
Evaluation requires assigning the folder path of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
最终会打印出`precision`, `recall`, `hmean`等指标
````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed
* 使用`OCR引擎 + SER`串联预测
* Use `OCR engine + SER` tandem prediction
使用如下命令即可完成`OCR引擎 + SER`的串联预测
Use the following command to complete the series prediction of `OCR engine + SER`, taking the pretrained SER model as an example:
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_42.jpg
````
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* `OCR引擎 + SER`预测系统进行端到端评估
* End-to-end evaluation of `OCR engine + SER` prediction system
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
```shell
export CUDA_VISIBLE_DEVICES=0
python3 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
````
<a name="53"></a>
### 5.3 RE
* 启动训练
* start training
启动训练之前,需要修改下面的四个字段
Before starting training, you need to modify the following four fields
1. `Train.dataset.data_dir`:指向训练集图片存放目录
2. `Train.dataset.label_file_list`:指向训练集标注文件
3. `Eval.dataset.data_dir`:指指向验证集图片存放目录
4. `Eval.dataset.label_file_list`:指向验证集标注文件
1. `Train.dataset.data_dir`: point to the directory where the training set images are stored
2. `Train.dataset.label_file_list`: point to the training set label file
3. `Eval.dataset.data_dir`: refers to the directory where the validation set images are stored
4. `Eval.dataset.label_file_list`: point to the validation set label file
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
```
````
最终会打印出`precision`, `recall`, `hmean`等指标。
`./output/re_layoutxlm/`文件夹中会保存训练日志,最优的模型和最新epoch的模型。
Finally, `precision`, `recall`, `hmean` and other indicators will be printed.
In the `./output/re_layoutxlm/` folder will save the training log, the optimal model and the model for the latest epoch.
* 恢复训练
* resume training
恢复训练需要将之前训练好的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
To resume training, assign the folder path of the previously trained model to the `Architecture.Backbone.checkpoints` field.
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
````
* 评估
* evaluate
评估需要将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
Evaluation requires assigning the folder path of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
最终会打印出`precision`, `recall`, `hmean`等指标
````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed
* 使用`OCR引擎 + SER + RE`串联预测
* Use `OCR engine + SER + RE` tandem prediction
使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测
Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
```shell
export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=ser_LayoutXLM_xfun_zh/
```
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm. yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
````
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
<a name="6"></a>
## 6. 参考链接
## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
......
[English](README.md) | 简体中文
- [文档视觉问答(DOC-VQA)](#文档视觉问答doc-vqa)
- [1. 简介](#1-简介)
- [2. 性能](#2-性能)
- [3. 效果演示](#3-效果演示)
- [3.1 SER](#31-ser)
- [3.2 RE](#32-re)
- [4. 安装](#4-安装)
- [4.1 安装依赖](#41-安装依赖)
- [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- [5. 使用](#5-使用)
- [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
- [5.2 SER](#52-ser)
- [5.3 RE](#53-re)
- [6. 参考链接](#6-参考链接)
# 文档视觉问答(DOC-VQA)
## 1. 简介
VQA指视觉问答,主要针对图像内容进行提问和回答,DOC-VQA是VQA任务中的一种,DOC-VQA主要针对文本图像的文字内容提出问题。
PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进行开发。
主要特性如下:
- 集成[LayoutXLM](https://arxiv.org/pdf/2104.08836.pdf)模型以及PP-OCR预测引擎。
- 支持基于多模态方法的语义实体识别 (Semantic Entity Recognition, SER) 以及关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图象中的文本内容的关系提取,如判断问题对(pair)。
- 支持SER任务和RE任务的自定义训练。
- 支持OCR+SER的端到端系统预测与评估。
- 支持OCR+SER+RE的端到端系统预测。
本项目是 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/pdf/2104.08836.pdf) 在 Paddle 2.2上的开源实现,
包含了在 [XFUND数据集](https://github.com/doc-analysis/XFUND) 上的微调代码。
## 2. 性能
我们在 [XFUND](https://github.com/doc-analysis/XFUND) 的中文数据集上对算法进行了评估,性能如下
| 模型 | 任务 | hmean | 模型下载地址 |
|:---:|:---:|:---:| :---:|
| LayoutXLM | SER | 0.9038 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
| LayoutXLM | RE | 0.7483 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
| LayoutLMv2 | SER | 0.8544 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar)
| LayoutLMv2 | RE | 0.6777 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
| LayoutLM | SER | 0.7731 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
## 3. 效果演示
**注意:** 测试图片来源于XFUND数据集。
### 3.1 SER
![](../../doc/vqa/result_ser/zh_val_0_ser.jpg) | ![](../../doc/vqa/result_ser/zh_val_42_ser.jpg)
---|---
图中不同颜色的框表示不同的类别,对于XFUND数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别
* 深紫色:HEADER
* 浅紫色:QUESTION
* 军绿色:ANSWER
在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
### 3.2 RE
![](../../doc/vqa/result_re/zh_val_21_re.jpg) | ![](../../doc/vqa/result_re/zh_val_40_re.jpg)
---|---
图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
## 4. 安装
### 4.1 安装依赖
- **(1) 安装PaddlePaddle**
```bash
python3 -m pip install --upgrade pip
# GPU安装
python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
# CPU安装
python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
### 4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)
- **(1)pip快速安装PaddleOCR whl包(仅预测)**
```bash
python3 -m pip install paddleocr
```
- **(2)下载VQA源码(预测+训练)**
```bash
【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
# 如果因为网络问题无法pull成功,也可选择使用码云上的托管:
git clone https://gitee.com/paddlepaddle/PaddleOCR
# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
```
- **(3)安装VQA的`requirements`**
```bash
python3 -m pip install -r ppstructure/vqa/requirements.txt
```
## 5. 使用
### 5.1 数据和预训练模型准备
如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
* 下载处理好的数据集
处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
下载并解压该数据集,解压后将数据集放置在当前目录下。
```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```
* 转换数据集
若需进行其他XFUND数据集的训练,可使用下面的命令进行数据集的转换
```bash
python3 ppstructure/vqa/tools/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
```
* 下载预训练模型
```bash
mkdir pretrain && cd pretrain
#下载SER模型
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar && tar -xvf ser_LayoutXLM_xfun_zh.tar
#下载RE模型
wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar && tar -xvf re_LayoutXLM_xfun_zh.tar
cd ../
```
### 5.2 SER
启动训练之前,需要修改下面的四个字段
1. `Train.dataset.data_dir`:指向训练集图片存放目录
2. `Train.dataset.label_file_list`:指向训练集标注文件
3. `Eval.dataset.data_dir`:指指向验证集图片存放目录
4. `Eval.dataset.label_file_list`:指向验证集标注文件
* 启动训练
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
```
最终会打印出`precision`, `recall`, `hmean`等指标。
`./output/ser_layoutxlm/`文件夹中会保存训练日志,最优的模型和最新epoch的模型。
* 恢复训练
恢复训练需要将之前训练好的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
* 评估
评估需要将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
最终会打印出`precision`, `recall`, `hmean`等指标
* 使用`OCR引擎 + SER`串联预测
使用如下命令即可完成`OCR引擎 + SER`的串联预测, 以SER预训练模型为例:
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
*`OCR引擎 + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
```shell
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
### 5.3 RE
* 启动训练
启动训练之前,需要修改下面的四个字段
1. `Train.dataset.data_dir`:指向训练集图片存放目录
2. `Train.dataset.label_file_list`:指向训练集标注文件
3. `Eval.dataset.data_dir`:指指向验证集图片存放目录
4. `Eval.dataset.label_file_list`:指向验证集标注文件
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
```
最终会打印出`precision`, `recall`, `hmean`等指标。
`./output/re_layoutxlm/`文件夹中会保存训练日志,最优的模型和最新epoch的模型。
* 恢复训练
恢复训练需要将之前训练好的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
* 评估
评估需要将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
最终会打印出`precision`, `recall`, `hmean`等指标
* 使用`OCR引擎 + SER + RE`串联预测
使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测, 以预训练SER和RE模型为例:
```shell
export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
## 6. 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
- microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
- XFUND dataset, https://github.com/doc-analysis/XFUND
## License
The content of this project itself is licensed under the [Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
......@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Train.loader.batch_size_per_card:lite_train_lite_infer=1|whole_train_whole_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
......
===========================train_params===========================
model_name:PPOCRv2_ocr_rec_pact
model_name:ch_PPOCRv2_rec_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=300
Global.epoch_num:lite_train_lite_infer=6|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Global.pretrained_model:null
Global.pretrained_model:pretrain_models/ch_PP-OCRv2_rec_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
......
......@@ -3,7 +3,7 @@ model_name:ocr_det
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False
inference:./deploy/cpp_infer/build/ppocr det
inference:./deploy/cpp_infer/build/ppocr
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
......@@ -13,4 +13,8 @@ inference:./deploy/cpp_infer/build/ppocr det
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
\ No newline at end of file
--benchmark:True
--det:True
--rec:False
--cls:False
--use_angle_cls:False
\ No newline at end of file
===========================train_params===========================
model_name:det_mv3_east_v2.0
python:python3.7
gpu_list:0
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
Global.pretrained_model:./pretrain_models/det_mv3_east_v2.0_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
......
......@@ -37,7 +37,7 @@ export2:null
train_model:./inference/rec_mv3_tps_bilstm_att_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE"
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" --min_subgraph_size=5
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
......
......@@ -50,4 +50,4 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.t
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,48,48,160]}]
random_infer_input:[{float32,[3,48,160]}]
......@@ -37,7 +37,7 @@ export2:null
train_model:./inference/rec_r34_vd_tps_bilstm_att_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE"
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" --min_subgraph_size=5
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
......
......@@ -37,7 +37,7 @@ export2:null
train_model:./inference/rec_r50_vd_srn_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r50_fpn_vd_none_srn/rec_r50_fpn_srn.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="1,64,256" --rec_algorithm="SRN" --use_space_char=False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="1,64,256" --rec_algorithm="SRN" --use_space_char=False --min_subgraph_size=3
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
......
......@@ -64,6 +64,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../
fi
if [ ${model_name} == "ch_PPOCRv2_rec" ] || [ ${model_name} == "ch_PPOCRv2_rec_PACT" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv2_rec_train.tar && cd ../
fi
if [ ${model_name} == "det_r18_db_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
fi
......@@ -91,6 +95,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi
if [ ${model_name} == "det_mv3_east_v2.0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
......
......@@ -2,7 +2,7 @@
source test_tipc/common_func.sh
FILENAME=$1
dataline=$(awk 'NR==1, NR==16{print}' $FILENAME)
dataline=$(awk 'NR==1, NR==20{print}' $FILENAME)
# parser params
IFS=$'\n'
......@@ -34,6 +34,14 @@ cpp_infer_key1=$(func_parser_key "${lines[14]}")
cpp_infer_value1=$(func_parser_value "${lines[14]}")
cpp_benchmark_key=$(func_parser_key "${lines[15]}")
cpp_benchmark_value=$(func_parser_value "${lines[15]}")
cpp_det_key=$(func_parser_key "${lines[16]}")
cpp_det_value=$(func_parser_value "${lines[16]}")
cpp_rec_key=$(func_parser_key "${lines[17]}")
cpp_rec_value=$(func_parser_value "${lines[17]}")
cpp_cls_key=$(func_parser_key "${lines[18]}")
cpp_cls_value=$(func_parser_value "${lines[18]}")
cpp_use_angle_cls_key=$(func_parser_key "${lines[19]}")
cpp_use_angle_cls_value=$(func_parser_value "${lines[19]}")
LOG_PATH="./test_tipc/output"
mkdir -p ${LOG_PATH}
......@@ -68,7 +76,11 @@ function func_cpp_inference(){
set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
set_det=$(func_set_params "${cpp_det_key}" "${cpp_det_value}")
set_rec=$(func_set_params "${cpp_rec_key}" "${cpp_rec_value}")
set_cls=$(func_set_params "${cpp_cls_key}" "${cpp_cls_value}")
set_use_angle_cls=$(func_set_params "${cpp_use_angle_cls_key}" "${cpp_use_angle_cls_value}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_det} ${set_rec} ${set_cls} ${set_use_angle_cls} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
......@@ -97,7 +109,11 @@ function func_cpp_inference(){
set_precision=$(func_set_params "${cpp_precision_key}" "${precision}")
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
set_det=$(func_set_params "${cpp_det_key}" "${cpp_det_value}")
set_rec=$(func_set_params "${cpp_rec_key}" "${cpp_rec_value}")
set_cls=$(func_set_params "${cpp_cls_key}" "${cpp_cls_value}")
set_use_angle_cls=$(func_set_params "${cpp_use_angle_cls_key}" "${cpp_use_angle_cls_value}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_det} ${set_rec} ${set_cls} ${set_use_angle_cls} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
......
......@@ -119,6 +119,10 @@ class TextRecognizer(object):
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
if self.rec_algorithm == 'RARE':
if resized_w > self.rec_image_shape[2]:
resized_w = self.rec_image_shape[2]
imgW = self.rec_image_shape[2]
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
......
......@@ -312,12 +312,26 @@ def create_predictor(args, mode, logger):
input_names = predictor.get_input_names()
for name in input_names:
input_tensor = predictor.get_input_handle(name)
output_names = predictor.get_output_names()
output_tensors = []
output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config
def get_output_tensors(args, mode, predictor):
output_names = predictor.get_output_names()
output_tensors = []
if mode == "rec" and args.rec_algorithm == "CRNN":
output_name = 'softmax_0.tmp_0'
if output_name in output_names:
return [predictor.get_output_handle(output_name)]
else:
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
else:
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors, config
return output_tensors
def get_infer_gpuid():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册