未验证 提交 7de0daed 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #5930 from WenmuZhou/cpp_infer

rebuild cpp infer code
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
// common args
DECLARE_bool(use_gpu);
DECLARE_bool(use_tensorrt);
DECLARE_int32(gpu_id);
DECLARE_int32(gpu_mem);
DECLARE_int32(cpu_threads);
DECLARE_bool(enable_mkldnn);
DECLARE_string(precision);
DECLARE_bool(benchmark);
DECLARE_string(output);
DECLARE_string(image_dir);
DECLARE_string(type);
// detection related
DECLARE_string(det_model_dir);
DECLARE_int32(max_side_len);
DECLARE_double(det_db_thresh);
DECLARE_double(det_db_box_thresh);
DECLARE_double(det_db_unclip_ratio);
DECLARE_bool(use_dilation);
DECLARE_string(det_db_score_mode);
DECLARE_bool(visualize);
// classification related
DECLARE_bool(use_angle_cls);
DECLARE_string(cls_model_dir);
DECLARE_double(cls_thresh);
DECLARE_int32(cls_batch_num);
// recognition related
DECLARE_string(rec_model_dir);
DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path);
// forward related
DECLARE_bool(det);
DECLARE_bool(rec);
DECLARE_bool(cls);
......@@ -42,7 +42,8 @@ public:
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const double &cls_thresh,
const bool &use_tensorrt, const std::string &precision) {
const bool &use_tensorrt, const std::string &precision,
const int &cls_batch_num) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
......@@ -52,14 +53,17 @@ public:
this->cls_thresh = cls_thresh;
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
this->cls_batch_num_ = cls_batch_num;
LoadModel(model_dir);
}
double cls_thresh = 0.9;
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
cv::Mat Run(cv::Mat &img);
void Run(std::vector<cv::Mat> img_list, std::vector<int> &cls_labels,
std::vector<float> &cls_scores, std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
......@@ -69,17 +73,17 @@ private:
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
double cls_thresh = 0.5;
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
bool is_scale_ = true;
bool use_tensorrt_ = false;
std::string precision_ = "fp32";
int cls_batch_num_ = 1;
// pre-process
ClsResizeImg resize_op_;
Normalize normalize_op_;
Permute permute_op_;
PermuteBatch permute_op_;
}; // class Classifier
......
......@@ -73,7 +73,7 @@ public:
// Run predictor
void Run(cv::Mat &img, std::vector<std::vector<std::vector<int>>> &boxes,
std::vector<double> *times);
std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
......
......@@ -30,7 +30,6 @@
#include <numeric>
#include <include/ocr_cls.h>
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
......@@ -68,7 +67,7 @@ public:
void LoadModel(const std::string &model_dir);
void Run(std::vector<cv::Mat> img_list, std::vector<std::string> &rec_texts,
std::vector<float> &rec_text_scores, std::vector<double> *times);
std::vector<float> &rec_text_scores, std::vector<double> &times);
private:
std::shared_ptr<Predictor> predictor_;
......@@ -93,9 +92,6 @@ private:
Normalize normalize_op_;
PermuteBatch permute_op_;
// post-process
PostProcessor post_processor_;
}; // class CrnnRecognizer
} // namespace PaddleOCR
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/ocr_cls.h>
#include <include/ocr_det.h>
#include <include/ocr_rec.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
using namespace paddle_infer;
namespace PaddleOCR {
class PaddleOCR {
public:
explicit PaddleOCR();
~PaddleOCR();
std::vector<std::vector<OCRPredictResult>>
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
bool rec = true, bool cls = true);
private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num);
};
} // namespace PaddleOCR
......@@ -32,14 +32,21 @@
namespace PaddleOCR {
struct OCRPredictResult {
std::vector<std::vector<int>> box;
std::string text;
float score = -1.0;
float cls_score;
int cls_label = -1;
};
class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
static void
VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes,
const std::string &save_path);
static void VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<OCRPredictResult> &ocr_result,
const std::string &save_path);
template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
......@@ -55,6 +62,10 @@ public:
static std::vector<int> argsort(const std::vector<float> &array);
static std::string basename(const std::string &filename);
static bool PathExists(const std::string &path);
static void print_result(const std::vector<OCRPredictResult> &ocr_result);
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -9,9 +9,12 @@
- [2.1 将模型导出为inference model](#21-将模型导出为inference-model)
- [2.2 编译PaddleOCR C++预测demo](#22-编译paddleocr-c预测demo)
- [2.3 运行demo](#23-运行demo)
- [1. 只调用检测:](#1-只调用检测)
- [2. 只调用识别:](#2-只调用识别)
- [3. 调用串联:](#3-调用串联)
- [1. 检测+分类+识别:](#1-检测分类识别)
- [2. 检测+识别:](#2-检测识别)
- [3. 检测:](#3-检测)
- [4. 分类+识别:](#4-分类识别)
- [5. 识别:](#5-识别)
- [6. 分类:](#6-分类)
- [3. FAQ](#3-faq)
# 服务器端C++预测
......@@ -181,6 +184,9 @@ inference/
|-- rec_rcnn
| |--inference.pdiparams
| |--inference.pdmodel
|-- cls
| |--inference.pdiparams
| |--inference.pdmodel
```
<a name="22"></a>
......@@ -213,36 +219,71 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
运行方式:
```shell
./build/ppocr <mode> [--param1] [--param2] [...]
./build/ppocr [--param1] [--param2] [...]
```
具体命令如下:
##### 1. 检测+分类+识别:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=true \
--det=true \
--rec=true \
--cls=true \
```
##### 2. 检测+识别:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=false \
--det=true \
--rec=true \
--cls=false \
```
##### 3. 检测:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--image_dir=../../doc/imgs/12.jpg \
--det=true \
--rec=false
```
其中,`mode`为必选参数,表示选择的功能,取值范围['det', 'rec', 'system'],分别表示调用检测、识别、检测识别串联(包括方向分类器)。具体命令如下:
##### 1. 只调用检测
##### 4. 分类+识别
```shell
./build/ppocr det \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--image_dir=../../doc/imgs/12.jpg
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--det=false \
--rec=true \
--cls=true \
```
##### 2. 只调用识别:
##### 5. 识别:
```shell
./build/ppocr rec \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs_words/ch/
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=false \
--det=false \
--rec=true \
--cls=false \
```
##### 3. 调用串联:
##### 6. 分类:
```shell
# 不使用方向分类器
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
# 使用方向分类器
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
./build/ppocr --cls_model_dir=inference/cls \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--cls_model_dir=inference/ch_ppocr_mobile_v2.0_cls_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
--det=false \
--rec=false \
--cls=true \
```
更多支持的可调节参数解释如下:
......@@ -258,6 +299,15 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|enable_mkldnn|bool|true|是否使用mkldnn库|
|output|str|./output|可视化结果保存的路径|
- 前向相关
|参数名称|类型|默认参数|意义|
| :---: | :---: | :---: | :---: |
|det|bool|true|前向是否执行文字检测|
|rec|bool|true|前向是否执行文字识别|
|cls|bool|false|前向是否执行文字方向分类|
- 检测模型相关
|参数名称|类型|默认参数|意义|
......@@ -277,6 +327,7 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|use_angle_cls|bool|false|是否使用方向分类器|
|cls_model_dir|string|-|方向分类器inference model地址|
|cls_thresh|float|0.9|方向分类器的得分阈值|
|cls_batch_num|int|1|方向分类器batchsize|
- 识别模型相关
......@@ -284,15 +335,22 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
| :---: | :---: | :---: | :---: |
|rec_model_dir|string|-|识别模型inference model地址|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|rec_batch_num|int|6|识别模型batchsize|
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
最终屏幕上会输出检测结果如下。
<div align="center">
<img src="./imgs/cpp_infer_pred_12.png" width="600">
</div>
```bash
predict img: ../../doc/imgs/12.jpg
../../doc/imgs/12.jpg
0 det boxes: [[79,553],[399,541],[400,573],[80,585]] rec text: 打浦路252935号 rec score: 0.933757
1 det boxes: [[31,509],[510,488],[511,529],[33,549]] rec text: 绿洲仕格维花园公寓 rec score: 0.951745
2 det boxes: [[181,456],[395,448],[396,480],[182,488]] rec text: 打浦路15号 rec score: 0.91956
3 det boxes: [[43,413],[480,391],[481,428],[45,450]] rec text: 上海斯格威铂尔多大酒店 rec score: 0.915914
The detection visualized image saved in ./output//12.jpg
```
## 3. FAQ
......
......@@ -9,9 +9,12 @@
- [2.1 Export the inference model](#21-export-the-inference-model)
- [2.2 Compile PaddleOCR C++ inference demo](#22-compile-paddleocr-c-inference-demo)
- [Run the demo](#run-the-demo)
- [1. run det demo:](#1-run-det-demo)
- [2. run rec demo:](#2-run-rec-demo)
- [3. run system demo:](#3-run-system-demo)
- [1. det+cls+rec:](#1-detclsrec)
- [2. det+rec:](#2-detrec)
- [3. det](#3-det)
- [4. cls+rec:](#4-clsrec)
- [5. rec](#5-rec)
- [6. cls](#6-cls)
- [3. FAQ](#3-faq)
# Server-side C++ Inference
......@@ -166,6 +169,9 @@ inference/
|-- rec_rcnn
| |--inference.pdiparams
| |--inference.pdmodel
|-- cls
| |--inference.pdiparams
| |--inference.pdmodel
```
......@@ -198,44 +204,72 @@ or the generated Paddle inference library path (`build/paddle_inference_install_
### Run the demo
Execute the built executable file:
```shell
./build/ppocr <mode> [--param1] [--param2] [...]
./build/ppocr [--param1] [--param2] [...]
```
`mode` is a required parameter,and the valid values are
mode value | Model used
-----|------
det | Detection only
rec | Recognition only
system | End-to-end system
Specifically,
##### 1. run det demo:
##### 1. det+cls+rec:
```shell
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=true \
--det=true \
--rec=true \
--cls=true \
```
##### 2. det+rec:
```shell
./build/ppocr det \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--image_dir=../../doc/imgs/12.jpg
./build/ppocr --det_model_dir=inference/det_db \
--rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs/12.jpg \
--use_angle_cls=false \
--det=true \
--rec=true \
--cls=false \
```
##### 2. run rec demo:
##### 3. det
```shell
./build/ppocr rec \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs_words/ch/
./build/ppocr --det_model_dir=inference/det_db \
--image_dir=../../doc/imgs/12.jpg \
--det=true \
--rec=false
```
##### 3. run system demo:
##### 4. cls+rec:
```shell
# without text direction classifier
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
# with text direction classifier
./build/ppocr system \
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--cls_model_dir=inference/ch_ppocr_mobile_v2.0_cls_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
--image_dir=../../doc/imgs/12.jpg
--det=false \
--rec=true \
--cls=true \
```
##### 5. rec
```shell
./build/ppocr --rec_model_dir=inference/rec_rcnn \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=false \
--det=false \
--rec=true \
--cls=false \
```
##### 6. cls
```shell
./build/ppocr --cls_model_dir=inference/cls \
--cls_model_dir=inference/cls \
--image_dir=../../doc/imgs_words/ch/word_1.jpg \
--use_angle_cls=true \
--det=false \
--rec=false \
--cls=true \
```
More parameters are as follows,
......@@ -251,6 +285,16 @@ More parameters are as follows,
|enable_mkldnn|bool|true|Whether to use mkdlnn library|
|output|str|./output|Path where visualization results are saved|
- forward
|parameter|data type|default|meaning|
| :---: | :---: | :---: | :---: |
|det|bool|true|前向是否执行文字检测|
|rec|bool|true|前向是否执行文字识别|
|cls|bool|false|前向是否执行文字方向分类|
- Detection related parameters
|parameter|data type|default|meaning|
......@@ -270,6 +314,7 @@ More parameters are as follows,
|use_angle_cls|bool|false|Whether to use the direction classifier|
|cls_model_dir|string|-|Address of direction classifier inference model|
|cls_thresh|float|0.9|Score threshold of the direction classifier|
|cls_batch_num|int|1|batch size of classifier|
- Recognition related parameters
......@@ -277,15 +322,22 @@ More parameters are as follows,
| --- | --- | --- | --- |
|rec_model_dir|string|-|Address of recognition inference model|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
|rec_batch_num|int|6|batch size of recognition|
* Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `rec_char_dict_path` and `rec_model_dir`.
The detection results will be shown on the screen, which is as follows.
<div align="center">
<img src="./imgs/cpp_infer_pred_12.png" width="600">
</div>
```bash
predict img: ../../doc/imgs/12.jpg
../../doc/imgs/12.jpg
0 det boxes: [[79,553],[399,541],[400,573],[80,585]] rec text: 打浦路252935号 rec score: 0.933757
1 det boxes: [[31,509],[510,488],[511,529],[33,549]] rec text: 绿洲仕格维花园公寓 rec score: 0.951745
2 det boxes: [[181,456],[395,448],[396,480],[182,488]] rec text: 打浦路15号 rec score: 0.91956
3 det boxes: [[43,413],[480,391],[481,428],[45,450]] rec text: 上海斯格威铂尔多大酒店 rec score: 0.915914
The detection visualized image saved in ./output//12.jpg
```
## 3. FAQ
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
// common args
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU.");
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU.");
DEFINE_int32(cpu_threads, 10, "Num of threads with CPU.");
DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
DEFINE_bool(benchmark, false, "Whether use benchmark.");
DEFINE_string(output, "./output/", "Save benchmark log path.");
DEFINE_string(image_dir, "", "Dir of input image.");
DEFINE_string(
type, "ocr",
"Perform ocr or structure, the value is selected in ['ocr','structure'].");
// detection related
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
DEFINE_bool(use_dilation, false, "Whether use the dilation on output map.");
DEFINE_string(det_db_score_mode, "slow", "Whether use polygon score.");
DEFINE_bool(visualize, true, "Whether show the detection results.");
// classification related
DEFINE_bool(use_angle_cls, false, "Whether use use_angle_cls.");
DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
DEFINE_int32(cls_batch_num, 1, "cls_batch_num.");
// recognition related
DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
"Path of dictionary.");
// ocr forward related
DEFINE_bool(det, true, "Whether use det in forward.");
DEFINE_bool(rec, true, "Whether use rec in forward.");
DEFINE_bool(cls, false, "Whether use cls in forward.");
\ No newline at end of file
......@@ -11,278 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "omp.h"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sys/stat.h>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/ocr_cls.h>
#include <include/ocr_det.h>
#include <include/ocr_rec.h>
#include <include/utility.h>
#include <sys/stat.h>
#include "auto_log/autolog.h"
#include <gflags/gflags.h>
#include <include/args.h>
#include <include/paddleocr.h>
// common args
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU.");
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU.");
DEFINE_int32(cpu_threads, 10, "Num of threads with CPU.");
DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
DEFINE_bool(benchmark, false, "Whether use benchmark.");
DEFINE_string(output, "./output/", "Save benchmark log path.");
DEFINE_string(image_dir, "", "Dir of input image.");
DEFINE_bool(visualize, true, "Whether show the detection results.");
// detection related
DEFINE_string(det_model_dir, "", "Path of det inference model.");
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
DEFINE_double(det_db_box_thresh, 0.6, "Threshold of det_db_box_thresh.");
DEFINE_double(det_db_unclip_ratio, 1.5, "Threshold of det_db_unclip_ratio.");
DEFINE_bool(use_dilation, false, "Whether use the dilation on output map.");
DEFINE_string(det_db_score_mode, "slow", "Whether use polygon score.");
// classification related
DEFINE_bool(use_angle_cls, false, "Whether use use_angle_cls.");
DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
// recognition related
DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
"Path of dictionary.");
using namespace std;
using namespace cv;
using namespace PaddleOCR;
static bool PathExists(const std::string &path) {
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
#else
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0);
#endif // !_WIN32
}
int main_det(std::vector<cv::String> cv_all_img_names) {
std::vector<double> time_info = {0, 0, 0};
DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_max_side_len, FLAGS_det_db_thresh,
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_det_db_score_mode, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
if (!PathExists(FLAGS_output)) {
mkdir(FLAGS_output.c_str(), 0777);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
if (!FLAGS_benchmark) {
cout << "The predict img: " << cv_all_img_names[i] << endl;
}
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
det.Run(srcimg, boxes, &det_times);
// visualization
if (FLAGS_visualize) {
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
}
time_info[0] += det_times[0];
time_info[1] += det_times[1];
time_info[2] += det_times[2];
if (FLAGS_benchmark) {
cout << cv_all_img_names[i] << "\t[";
for (int n = 0; n < boxes.size(); n++) {
cout << '[';
for (int m = 0; m < boxes[n].size(); m++) {
cout << '[' << boxes[n][m][0] << ',' << boxes[n][m][1] << "]";
if (m != boxes[n].size() - 1) {
cout << ',';
}
}
cout << ']';
if (n != boxes.size() - 1) {
cout << ',';
}
}
cout << ']' << endl;
}
}
if (FLAGS_benchmark) {
AutoLogger autolog("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, time_info, cv_all_img_names.size());
autolog.report();
}
return 0;
}
int main_rec(std::vector<cv::String> cv_all_img_names) {
std::vector<double> time_info = {0, 0, 0};
std::string rec_char_dict_path = FLAGS_rec_char_dict_path;
cout << "label file: " << rec_char_dict_path << endl;
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
rec_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision,
FLAGS_rec_batch_num);
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
img_list.push_back(srcimg);
}
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times;
rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
cout << "The predict img: " << cv_all_img_names[i] << "\t" << rec_texts[i]
<< "\t" << rec_text_scores[i] << endl;
}
time_info[0] += rec_times[0];
time_info[1] += rec_times[1];
time_info[2] += rec_times[2];
if (FLAGS_benchmark) {
AutoLogger autolog("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
time_info, cv_all_img_names.size());
autolog.report();
}
return 0;
}
int main_system(std::vector<cv::String> cv_all_img_names) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
if (!PathExists(FLAGS_output)) {
mkdir(FLAGS_output.c_str(), 0777);
}
DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_max_side_len, FLAGS_det_db_thresh,
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_det_db_score_mode, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
Classifier *cls = nullptr;
if (FLAGS_use_angle_cls) {
cls = new Classifier(FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_cls_thresh, FLAGS_use_tensorrt, FLAGS_precision);
}
std::string rec_char_dict_path = FLAGS_rec_char_dict_path;
cout << "label file: " << rec_char_dict_path << endl;
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
rec_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision,
FLAGS_rec_batch_num);
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cout << "The predict img: " << cv_all_img_names[i] << endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
// det
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
std::vector<double> rec_times;
det.Run(srcimg, boxes, &det_times);
if (FLAGS_visualize) {
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
}
time_info_det[0] += det_times[0];
time_info_det[1] += det_times[1];
time_info_det[2] += det_times[2];
// rec
std::vector<cv::Mat> img_list;
for (int j = 0; j < boxes.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]);
if (cls != nullptr) {
crop_img = cls->Run(crop_img);
}
img_list.push_back(crop_img);
}
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
std::cout << i << "\t" << rec_texts[i] << "\t" << rec_text_scores[i]
<< std::endl;
}
time_info_rec[0] += rec_times[0];
time_info_rec[1] += rec_times[1];
time_info_rec[2] += rec_times[2];
}
if (FLAGS_benchmark) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, time_info_det,
cv_all_img_names.size());
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
time_info_rec, cv_all_img_names.size());
autolog_det.report();
std::cout << endl;
autolog_rec.report();
}
return 0;
}
void check_params(char *mode) {
if (strcmp(mode, "det") == 0) {
void check_params() {
if (FLAGS_det) {
if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[det]: ./ppocr "
"--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
......@@ -290,7 +31,7 @@ void check_params(char *mode) {
exit(1);
}
}
if (strcmp(mode, "rec") == 0) {
if (FLAGS_rec) {
if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[rec]: ./ppocr "
"--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
......@@ -298,19 +39,10 @@ void check_params(char *mode) {
exit(1);
}
}
if (strcmp(mode, "system") == 0) {
if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() ||
FLAGS_image_dir.empty()) ||
(FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) {
std::cout << "Usage[system without angle cls]: ./ppocr "
"--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
std::cout << "Usage[system with angle cls]: ./ppocr "
"--det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
<< "--use_angle_cls=true "
<< "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ "
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
if (FLAGS_cls && FLAGS_use_angle_cls) {
if (FLAGS_cls_model_dir.empty() || FLAGS_image_dir.empty()) {
std::cout << "Usage[cls]: ./ppocr "
<< "--cls_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
exit(1);
}
......@@ -323,19 +55,11 @@ void check_params(char *mode) {
}
int main(int argc, char **argv) {
if (argc <= 1 ||
(strcmp(argv[1], "det") != 0 && strcmp(argv[1], "rec") != 0 &&
strcmp(argv[1], "system") != 0)) {
std::cout << "Please choose one mode of [det, rec, system] !" << std::endl;
return -1;
}
std::cout << "mode: " << argv[1] << endl;
// Parsing command-line
google::ParseCommandLineFlags(&argc, &argv, true);
check_params(argv[1]);
check_params();
if (!PathExists(FLAGS_image_dir)) {
if (!Utility::PathExists(FLAGS_image_dir)) {
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir
<< endl;
exit(1);
......@@ -345,13 +69,37 @@ int main(int argc, char **argv) {
cv::glob(FLAGS_image_dir, cv_all_img_names);
std::cout << "total images num: " << cv_all_img_names.size() << endl;
if (strcmp(argv[1], "det") == 0) {
return main_det(cv_all_img_names);
}
if (strcmp(argv[1], "rec") == 0) {
return main_rec(cv_all_img_names);
}
if (strcmp(argv[1], "system") == 0) {
return main_system(cv_all_img_names);
PaddleOCR::PaddleOCR ocr = PaddleOCR::PaddleOCR();
std::vector<std::vector<OCRPredictResult>> ocr_results =
ocr.ocr(cv_all_img_names, FLAGS_det, FLAGS_rec, FLAGS_cls);
for (int i = 0; i < cv_all_img_names.size(); ++i) {
if (FLAGS_benchmark) {
cout << cv_all_img_names[i] << '\t';
for (int n = 0; n < ocr_results[i].size(); n++) {
for (int m = 0; m < ocr_results[i][n].box.size(); m++) {
cout << ocr_results[i][n].box[m][0] << ' '
<< ocr_results[i][n].box[m][1] << ' ';
}
}
cout << endl;
} else {
cout << cv_all_img_names[i] << "\n";
Utility::print_result(ocr_results[i]);
if (FLAGS_visualize && FLAGS_det) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, ocr_results[i],
FLAGS_output + "/" + file_name);
}
cout << "***************************" << endl;
}
}
}
......@@ -16,57 +16,84 @@
namespace PaddleOCR {
cv::Mat Classifier::Run(cv::Mat &img) {
cv::Mat src_img;
img.copyTo(src_img);
cv::Mat resize_img;
void Classifier::Run(std::vector<cv::Mat> img_list,
std::vector<int> &cls_labels,
std::vector<float> &cls_scores,
std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> inference_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> postprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
int img_num = img_list.size();
std::vector<int> cls_image_shape = {3, 48, 192};
int index = 0;
float wh_ratio = float(img.cols) / float(img.rows);
this->resize_op_.Run(img, resize_img, this->use_tensorrt_, cls_image_shape);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
// Inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
std::vector<float> softmax_out;
std::vector<int64_t> label_out;
auto output_names = this->predictor_->GetOutputNames();
auto softmax_out_t = this->predictor_->GetOutputHandle(output_names[0]);
auto softmax_shape_out = softmax_out_t->shape();
int softmax_out_num =
std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
std::multiplies<int>());
softmax_out.resize(softmax_out_num);
softmax_out_t->CopyToCpu(softmax_out.data());
float score = 0;
int label = 0;
for (int i = 0; i < softmax_out_num; i++) {
if (softmax_out[i] > score) {
score = softmax_out[i];
label = i;
for (int beg_img_no = 0; beg_img_no < img_num;
beg_img_no += this->cls_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->cls_batch_num_);
int batch_num = end_img_no - beg_img_no;
// preprocess
std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
cv::Mat srcimg;
img_list[ino].copyTo(srcimg);
cv::Mat resize_img;
this->resize_op_.Run(srcimg, resize_img, this->use_tensorrt_,
cls_image_shape);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
norm_img_batch.push_back(resize_img);
}
std::vector<float> input(batch_num * cls_image_shape[0] *
cls_image_shape[1] * cls_image_shape[2],
0.0f);
this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({batch_num, cls_image_shape[0], cls_image_shape[1],
cls_image_shape[2]});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
std::vector<float> predict_batch;
auto output_names = this->predictor_->GetOutputNames();
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
auto predict_shape = output_t->shape();
int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1,
std::multiplies<int>());
predict_batch.resize(out_num);
output_t->CopyToCpu(predict_batch.data());
auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// postprocess
auto postprocess_start = std::chrono::steady_clock::now();
for (int batch_idx = 0; batch_idx < predict_shape[0]; batch_idx++) {
int label = int(
Utility::argmax(&predict_batch[batch_idx * predict_shape[1]],
&predict_batch[(batch_idx + 1) * predict_shape[1]]));
float score = float(*std::max_element(
&predict_batch[batch_idx * predict_shape[1]],
&predict_batch[(batch_idx + 1) * predict_shape[1]]));
cls_labels[beg_img_no + batch_idx] = label;
cls_scores[beg_img_no + batch_idx] = score;
}
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
}
if (label % 2 == 1 && score > this->cls_thresh) {
cv::rotate(src_img, src_img, 1);
}
return src_img;
times.push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
void Classifier::LoadModel(const std::string &model_dir) {
......@@ -81,13 +108,10 @@ void Classifier::LoadModel(const std::string &model_dir) {
if (this->precision_ == "fp16") {
precision = paddle_infer::Config::Precision::kHalf;
}
if (this->precision_ == "int8") {
if (this->precision_ == "int8") {
precision = paddle_infer::Config::Precision::kInt8;
}
config.EnableTensorRtEngine(
1 << 20, 10, 3,
precision,
false, false);
}
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
}
} else {
config.DisableGpu();
......
......@@ -94,7 +94,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
void DBDetector::Run(cv::Mat &img,
std::vector<std::vector<std::vector<int>>> &boxes,
std::vector<double> *times) {
std::vector<double> &times) {
float ratio_h{};
float ratio_w{};
......@@ -165,16 +165,15 @@ void DBDetector::Run(cv::Mat &img,
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
auto postprocess_end = std::chrono::steady_clock::now();
std::cout << "Detected boxes num: " << boxes.size() << endl;
std::chrono::duration<float> preprocess_diff =
preprocess_end - preprocess_start;
times->push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(preprocess_diff.count() * 1000));
std::chrono::duration<float> inference_diff = inference_end - inference_start;
times->push_back(double(inference_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
} // namespace PaddleOCR
......@@ -19,7 +19,7 @@ namespace PaddleOCR {
void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
std::vector<std::string> &rec_texts,
std::vector<float> &rec_text_scores,
std::vector<double> *times) {
std::vector<double> &times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
std::chrono::duration<float> inference_diff =
......@@ -38,6 +38,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
beg_img_no += this->rec_batch_num_) {
auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
int batch_num = end_img_no - beg_img_no;
float max_wh_ratio = 0;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
int h = img_list[indices[ino]].rows;
......@@ -45,6 +46,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
float wh_ratio = w * 1.0 / h;
max_wh_ratio = max(max_wh_ratio, wh_ratio);
}
int batch_width = 0;
std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) {
......@@ -59,15 +61,14 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
batch_width = max(resize_img.cols, batch_width);
}
std::vector<float> input(this->rec_batch_num_ * 3 * 32 * batch_width, 0.0f);
std::vector<float> input(batch_num * 3 * 32 * batch_width, 0.0f);
this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start;
// Inference.
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({this->rec_batch_num_, 3, 32, batch_width});
input_t->Reshape({batch_num, 3, 32, batch_width});
auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data());
this->predictor_->Run();
......@@ -84,7 +85,6 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
output_t->CopyToCpu(predict_batch.data());
auto inference_end = std::chrono::steady_clock::now();
inference_diff += inference_end - inference_start;
// ctc decode
auto postprocess_start = std::chrono::steady_clock::now();
for (int m = 0; m < predict_shape[0]; m++) {
......@@ -120,9 +120,9 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
}
times->push_back(double(preprocess_diff.count() * 1000));
times->push_back(double(inference_diff.count() * 1000));
times->push_back(double(postprocess_diff.count() * 1000));
times.push_back(double(preprocess_diff.count() * 1000));
times.push_back(double(inference_diff.count() * 1000));
times.push_back(double(postprocess_diff.count() * 1000));
}
void CRNNRecognizer::LoadModel(const std::string &model_dir) {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/args.h>
#include <include/paddleocr.h>
#include "auto_log/autolog.h"
#include <numeric>
#include <sys/stat.h>
namespace PaddleOCR {
PaddleOCR::PaddleOCR() {
if (FLAGS_det) {
this->detector_ = new DBDetector(
FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_max_side_len,
FLAGS_det_db_thresh, FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_det_db_score_mode, FLAGS_use_dilation, FLAGS_use_tensorrt,
FLAGS_precision);
}
if (FLAGS_cls && FLAGS_use_angle_cls) {
this->classifier_ = new Classifier(
FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_cls_thresh,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_cls_batch_num);
}
if (FLAGS_rec) {
this->recognizer_ = new CRNNRecognizer(
FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_rec_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num);
}
};
void PaddleOCR::det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
this->detector_->Run(img, boxes, det_times);
for (int i = 0; i < boxes.size(); i++) {
OCRPredictResult res;
res.box = boxes[i];
ocr_results.push_back(res);
}
times[0] += det_times[0];
times[1] += det_times[1];
times[2] += det_times[2];
}
void PaddleOCR::rec(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times;
this->recognizer_->Run(img_list, rec_texts, rec_text_scores, rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
ocr_results[i].text = rec_texts[i];
ocr_results[i].score = rec_text_scores[i];
}
times[0] += rec_times[0];
times[1] += rec_times[1];
times[2] += rec_times[2];
}
void PaddleOCR::cls(std::vector<cv::Mat> img_list,
std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times) {
std::vector<int> cls_labels(img_list.size(), 0);
std::vector<float> cls_scores(img_list.size(), 0);
std::vector<double> cls_times;
this->classifier_->Run(img_list, cls_labels, cls_scores, cls_times);
// output cls results
for (int i = 0; i < cls_labels.size(); i++) {
ocr_results[i].cls_label = cls_labels[i];
ocr_results[i].cls_score = cls_scores[i];
}
times[0] += cls_times[0];
times[1] += cls_times[1];
times[2] += cls_times[2];
}
std::vector<std::vector<OCRPredictResult>>
PaddleOCR::ocr(std::vector<cv::String> cv_all_img_names, bool det, bool rec,
bool cls) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
std::vector<double> time_info_cls = {0, 0, 0};
std::vector<std::vector<OCRPredictResult>> ocr_results;
if (!det) {
std::vector<OCRPredictResult> ocr_result;
// read image
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
img_list.push_back(srcimg);
OCRPredictResult res;
ocr_result.push_back(res);
}
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result_tmp;
ocr_result_tmp.push_back(ocr_result[i]);
ocr_results.push_back(ocr_result_tmp);
}
} else {
if (!Utility::PathExists(FLAGS_output) && FLAGS_det) {
mkdir(FLAGS_output.c_str(), 0777);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
std::vector<OCRPredictResult> ocr_result;
if (!FLAGS_benchmark) {
cout << "predict img: " << cv_all_img_names[i] << endl;
}
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
<< cv_all_img_names[i] << endl;
exit(1);
}
// det
this->det(srcimg, ocr_result, time_info_det);
// crop image
std::vector<cv::Mat> img_list;
for (int j = 0; j < ocr_result.size(); j++) {
cv::Mat crop_img;
crop_img = Utility::GetRotateCropImage(srcimg, ocr_result[j].box);
img_list.push_back(crop_img);
}
// cls
if (cls && this->classifier_ != nullptr) {
this->cls(img_list, ocr_result, time_info_cls);
for (int i = 0; i < img_list.size(); i++) {
if (ocr_result[i].cls_label % 2 == 1 &&
ocr_result[i].cls_score > this->classifier_->cls_thresh) {
cv::rotate(img_list[i], img_list[i], 1);
}
}
}
// rec
if (rec) {
this->rec(img_list, ocr_result, time_info_rec);
}
ocr_results.push_back(ocr_result);
}
}
if (FLAGS_benchmark) {
this->log(time_info_det, time_info_rec, time_info_cls,
cv_all_img_names.size());
}
return ocr_results;
} // namespace PaddleOCR
void PaddleOCR::log(std::vector<double> &det_times,
std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num) {
if (det_times[0] + det_times[1] + det_times[2] > 0) {
AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic",
FLAGS_precision, det_times, img_num);
autolog_det.report();
}
if (rec_times[0] + rec_times[1] + rec_times[2] > 0) {
AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_rec_batch_num, "dynamic", FLAGS_precision,
rec_times, img_num);
autolog_rec.report();
}
if (cls_times[0] + cls_times[1] + cls_times[2] > 0) {
AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt,
FLAGS_enable_mkldnn, FLAGS_cpu_threads,
FLAGS_cls_batch_num, "dynamic", FLAGS_precision,
cls_times, img_num);
autolog_cls.report();
}
}
PaddleOCR::~PaddleOCR() {
if (this->detector_ != nullptr) {
delete this->detector_;
}
if (this->classifier_ != nullptr) {
delete this->classifier_;
}
if (this->recognizer_ != nullptr) {
delete this->recognizer_;
}
};
} // namespace PaddleOCR
......@@ -38,16 +38,16 @@ std::vector<std::string> Utility::ReadDict(const std::string &path) {
return m_vec;
}
void Utility::VisualizeBboxes(
const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes,
const std::string &save_path) {
void Utility::VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<OCRPredictResult> &ocr_result,
const std::string &save_path) {
cv::Mat img_vis;
srcimg.copyTo(img_vis);
for (int n = 0; n < boxes.size(); n++) {
for (int n = 0; n < ocr_result.size(); n++) {
cv::Point rook_points[4];
for (int m = 0; m < boxes[n].size(); m++) {
rook_points[m] = cv::Point(int(boxes[n][m][0]), int(boxes[n][m][1]));
for (int m = 0; m < ocr_result[n].box.size(); m++) {
rook_points[m] =
cv::Point(int(ocr_result[n].box[m][0]), int(ocr_result[n].box[m][1]));
}
const cv::Point *ppt[1] = {rook_points};
......@@ -196,4 +196,43 @@ std::string Utility::basename(const std::string &filename) {
return filename.substr(index + 1, len - index);
}
bool Utility::PathExists(const std::string &path) {
#ifdef _WIN32
struct _stat buffer;
return (_stat(path.c_str(), &buffer) == 0);
#else
struct stat buffer;
return (stat(path.c_str(), &buffer) == 0);
#endif // !_WIN32
}
void Utility::print_result(const std::vector<OCRPredictResult> &ocr_result) {
for (int i = 0; i < ocr_result.size(); i++) {
std::cout << i << "\t";
// det
std::vector<std::vector<int>> boxes = ocr_result[i].box;
if (boxes.size() > 0) {
std::cout << "det boxes: [";
for (int n = 0; n < boxes.size(); n++) {
std::cout << '[' << boxes[n][0] << ',' << boxes[n][1] << "]";
if (n != boxes.size() - 1) {
std::cout << ',';
}
}
std::cout << "] ";
}
// rec
if (ocr_result[i].score != -1.0) {
std::cout << "rec text: " << ocr_result[i].text
<< " rec score: " << ocr_result[i].score << " ";
}
// cls
if (ocr_result[i].cls_label != -1) {
std::cout << "cls label: " << ocr_result[i].cls_label
<< " cls score: " << ocr_result[i].cls_score;
}
std::cout << std::endl;
}
}
} // namespace PaddleOCR
\ No newline at end of file
......@@ -3,7 +3,7 @@ model_name:ocr_det
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False
inference:./deploy/cpp_infer/build/ppocr det
inference:./deploy/cpp_infer/build/ppocr
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
......@@ -13,4 +13,8 @@ inference:./deploy/cpp_infer/build/ppocr det
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
\ No newline at end of file
--benchmark:True
--det:True
--rec:False
--cls:False
--use_angle_cls:False
\ No newline at end of file
......@@ -2,7 +2,7 @@
source test_tipc/common_func.sh
FILENAME=$1
dataline=$(awk 'NR==1, NR==16{print}' $FILENAME)
dataline=$(awk 'NR==1, NR==20{print}' $FILENAME)
# parser params
IFS=$'\n'
......@@ -34,6 +34,14 @@ cpp_infer_key1=$(func_parser_key "${lines[14]}")
cpp_infer_value1=$(func_parser_value "${lines[14]}")
cpp_benchmark_key=$(func_parser_key "${lines[15]}")
cpp_benchmark_value=$(func_parser_value "${lines[15]}")
cpp_det_key=$(func_parser_key "${lines[16]}")
cpp_det_value=$(func_parser_value "${lines[16]}")
cpp_rec_key=$(func_parser_key "${lines[17]}")
cpp_rec_value=$(func_parser_value "${lines[17]}")
cpp_cls_key=$(func_parser_key "${lines[18]}")
cpp_cls_value=$(func_parser_value "${lines[18]}")
cpp_use_angle_cls_key=$(func_parser_key "${lines[19]}")
cpp_use_angle_cls_value=$(func_parser_value "${lines[19]}")
LOG_PATH="./test_tipc/output"
mkdir -p ${LOG_PATH}
......@@ -68,7 +76,11 @@ function func_cpp_inference(){
set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}")
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
set_det=$(func_set_params "${cpp_det_key}" "${cpp_det_value}")
set_rec=$(func_set_params "${cpp_rec_key}" "${cpp_rec_value}")
set_cls=$(func_set_params "${cpp_cls_key}" "${cpp_cls_value}")
set_use_angle_cls=$(func_set_params "${cpp_use_angle_cls_key}" "${cpp_use_angle_cls_value}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_det} ${set_rec} ${set_cls} ${set_use_angle_cls} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
......@@ -97,7 +109,11 @@ function func_cpp_inference(){
set_precision=$(func_set_params "${cpp_precision_key}" "${precision}")
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
set_det=$(func_set_params "${cpp_det_key}" "${cpp_det_value}")
set_rec=$(func_set_params "${cpp_rec_key}" "${cpp_rec_value}")
set_cls=$(func_set_params "${cpp_cls_key}" "${cpp_cls_value}")
set_use_angle_cls=$(func_set_params "${cpp_use_angle_cls_key}" "${cpp_use_angle_cls_value}")
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_det} ${set_rec} ${set_cls} ${set_use_angle_cls} ${set_infer_params1} > ${_save_log_path} 2>&1 "
eval $command
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册