提交 eb3d3f56 编写于 作者: 文幕地方's avatar 文幕地方

rec return result

上级 97f0a2d5
......@@ -46,8 +46,7 @@ public:
const double &det_db_box_thresh,
const double &det_db_unclip_ratio,
const bool &use_polygon_score, const bool &use_dilation,
const bool &visualize, const bool &use_tensorrt,
const std::string &precision) {
const bool &use_tensorrt, const std::string &precision) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
......@@ -62,7 +61,6 @@ public:
this->use_polygon_score_ = use_polygon_score;
this->use_dilation_ = use_dilation;
this->visualize_ = visualize;
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
......
......@@ -44,7 +44,8 @@ public:
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt, const std::string &precision,
const bool &use_tensorrt,
const std::string &precision,
const int &rec_batch_num) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
......@@ -66,7 +67,8 @@ public:
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
void Run(std::vector<cv::Mat> img_list, std::vector<double> *times);
void Run(std::vector<cv::Mat> img_list, std::vector<std::string> &rec_texts,
std::vector<float> &rec_text_scores, std::vector<double> *times);
private:
std::shared_ptr<Predictor> predictor_;
......
......@@ -38,7 +38,8 @@ public:
static void
VisualizeBboxes(const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes);
const std::vector<std::vector<std::vector<int>>> &boxes,
const std::string &save_path);
template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
......@@ -51,8 +52,9 @@ public:
static cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
std::vector<std::vector<int>> box);
static std::vector<int> argsort(const std::vector<float>& array);
static std::vector<int> argsort(const std::vector<float> &array);
static std::string basename(const std::string &filename);
};
} // namespace PaddleOCR
\ No newline at end of file
......@@ -30,7 +30,7 @@ PaddleOCR模型部署。
### 1.0 运行准备
- Linux环境,推荐使用docker。
- Windows环境,目前支持基于`Visual Studio 2019 Community`进行编译
- Windows环境。
* 该文档主要介绍基于Linux环境的PaddleOCR C++预测流程,如果需要在Windows下基于预测库进行C++预测,具体编译方法请参考[Windows下编译教程](./docs/windows_vs2019_build.md)
......@@ -256,6 +256,7 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|gpu_mem|int|4000|申请的GPU内存|
|cpu_math_library_num_threads|int|10|CPU预测时的线程数,在机器核数充足的情况下,该值越大,预测速度越快|
|enable_mkldnn|bool|true|是否使用mkldnn库|
|output|str|./output|可视化结果保存的路径|
- 检测模型相关
......@@ -267,7 +268,7 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|det_db_box_thresh|float|0.5|DB后处理过滤box的阈值,如果检测存在漏框情况,可酌情减小|
|det_db_unclip_ratio|float|1.6|表示文本框的紧致程度,越小则文本框更靠近文本|
|use_polygon_score|bool|false|是否使用多边形框计算bbox score,false表示使用矩形框计算。矩形框计算速度更快,多边形框对弯曲文本区域计算更准确。|
|visualize|bool|true|是否对结果进行可视化,为1时,会在当前文件夹下保存文件名为`ocr_vis.png`的预测结果。|
|visualize|bool|true|是否对结果进行可视化,为1时,预测结果会保存在`output`字段指定的文件夹下和输入图像同名的图像上。|
- 方向分类器相关
......
......@@ -26,6 +26,7 @@ This section will introduce how to configure the C++ environment and deploy Padd
### Environment
- Linux, docker is recommended.
- Windows.
### 1.1 Compile OpenCV
......@@ -248,6 +249,7 @@ More parameters are as follows,
|gpu_mem|int|4000|GPU memory requested|
|cpu_math_library_num_threads|int|10|Number of threads when using CPU inference. When machine cores is enough, the large the value, the faster the inference speed|
|enable_mkldnn|bool|true|Whether to use mkdlnn library|
|output|str|./output|Path where visualization results are saved|
- Detection related parameters
......@@ -259,7 +261,7 @@ More parameters are as follows,
|det_db_box_thresh|float|0.5|DB post-processing filter box threshold, if there is a missing box detected, it can be reduced as appropriate|
|det_db_unclip_ratio|float|1.6|Indicates the compactness of the text box, the smaller the value, the closer the text box to the text|
|use_polygon_score|bool|false|Whether to use polygon box to calculate bbox score, false means to use rectangle box to calculate. Use rectangular box to calculate faster, and polygonal box more accurate for curved text area.|
|visualize|bool|true|Whether to visualize the results,when it is set as true, The prediction result will be save in the image file `./ocr_vis.png`.|
|visualize|bool|true|Whether to visualize the results,when it is set as true, the prediction results will be saved in the folder specified by the `output` field on an image with the same name as the input image.|
- Classifier related parameters
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "omp.h"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
......@@ -21,13 +20,13 @@
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sys/stat.h>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <glog/logging.h>
#include <include/ocr_cls.h>
#include <include/ocr_det.h>
#include <include/ocr_rec.h>
......@@ -45,7 +44,7 @@ DEFINE_bool(enable_mkldnn, false, "Whether use mkldnn with CPU.");
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
DEFINE_bool(benchmark, false, "Whether use benchmark.");
DEFINE_string(save_log_path, "./log_output/", "Save benchmark log path.");
DEFINE_string(output, "./output/", "Save benchmark log path.");
// detection related
DEFINE_string(image_dir, "", "Dir of input image.");
DEFINE_string(det_model_dir, "", "Path of det inference model.");
......@@ -86,11 +85,17 @@ int main_det(std::vector<cv::String> cv_all_img_names) {
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_max_side_len, FLAGS_det_db_thresh,
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_use_polygon_score, FLAGS_use_dilation, FLAGS_visualize,
FLAGS_use_polygon_score, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
if (!PathExists(FLAGS_output)) {
mkdir(FLAGS_output.c_str(), 0777);
}
for (int i = 0; i < cv_all_img_names.size(); ++i) {
// LOG(INFO) << "The predict img: " << cv_all_img_names[i];
if (!FLAGS_benchmark) {
cout << "The predict img: " << cv_all_img_names[i] << endl;
}
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
......@@ -102,7 +107,11 @@ int main_det(std::vector<cv::String> cv_all_img_names) {
std::vector<double> det_times;
det.Run(srcimg, boxes, &det_times);
//// visualization
if (FLAGS_visualize) {
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
}
time_info[0] += det_times[0];
time_info[1] += det_times[1];
time_info[2] += det_times[2];
......@@ -142,8 +151,6 @@ int main_rec(std::vector<cv::String> cv_all_img_names) {
std::vector<cv::Mat> img_list;
for (int i = 0; i < cv_all_img_names.size(); ++i) {
LOG(INFO) << "The predict img: " << cv_all_img_names[i];
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
std::cerr << "[ERROR] image read failed! image path: "
......@@ -152,8 +159,15 @@ int main_rec(std::vector<cv::String> cv_all_img_names) {
}
img_list.push_back(srcimg);
}
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
std::vector<double> rec_times;
rec.Run(img_list, &rec_times);
rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
cout << "The predict img: " << cv_all_img_names[i] << "\t" << rec_texts[i]
<< "\t" << rec_text_scores[i] << endl;
}
time_info[0] += rec_times[0];
time_info[1] += rec_times[1];
time_info[2] += rec_times[2];
......@@ -172,11 +186,15 @@ int main_system(std::vector<cv::String> cv_all_img_names) {
std::vector<double> time_info_det = {0, 0, 0};
std::vector<double> time_info_rec = {0, 0, 0};
if (!PathExists(FLAGS_output)) {
mkdir(FLAGS_output.c_str(), 0777);
}
DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn,
FLAGS_max_side_len, FLAGS_det_db_thresh,
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
FLAGS_use_polygon_score, FLAGS_use_dilation, FLAGS_visualize,
FLAGS_use_polygon_score, FLAGS_use_dilation,
FLAGS_use_tensorrt, FLAGS_precision);
Classifier *cls = nullptr;
......@@ -197,7 +215,7 @@ int main_system(std::vector<cv::String> cv_all_img_names) {
FLAGS_rec_batch_num);
for (int i = 0; i < cv_all_img_names.size(); ++i) {
LOG(INFO) << "The predict img: " << cv_all_img_names[i];
cout << "The predict img: " << cv_all_img_names[i] << endl;
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
if (!srcimg.data) {
......@@ -205,15 +223,21 @@ int main_system(std::vector<cv::String> cv_all_img_names) {
<< cv_all_img_names[i] << endl;
exit(1);
}
#det
std::vector<std::vector<std::vector<int>>> boxes;
std::vector<double> det_times;
std::vector<double> rec_times;
det.Run(srcimg, boxes, &det_times);
if (FLAGS_visualize) {
std::string file_name = Utility::basename(cv_all_img_names[i]);
Utility::VisualizeBboxes(srcimg, boxes, FLAGS_output + "/" + file_name);
}
time_info_det[0] += det_times[0];
time_info_det[1] += det_times[1];
time_info_det[2] += det_times[2];
#rec
std::vector<cv::Mat> img_list;
for (int j = 0; j < boxes.size(); j++) {
cv::Mat crop_img;
......@@ -223,8 +247,14 @@ int main_system(std::vector<cv::String> cv_all_img_names) {
}
img_list.push_back(crop_img);
}
rec.Run(img_list, &rec_times);
std::vector<std::string> rec_texts(img_list.size(), "");
std::vector<float> rec_text_scores(img_list.size(), 0);
rec.Run(img_list, rec_texts, rec_text_scores, &rec_times);
// output rec results
for (int i = 0; i < rec_texts.size(); i++) {
std::cout << i << "\t" << rec_texts[i] << "\t" << rec_text_scores[i]
<< std::endl;
}
time_info_rec[0] += rec_times[0];
time_info_rec[1] += rec_times[1];
time_info_rec[2] += rec_times[2];
......
......@@ -175,11 +175,6 @@ void DBDetector::Run(cv::Mat &img,
std::chrono::duration<float> postprocess_diff =
postprocess_end - postprocess_start;
times->push_back(double(postprocess_diff.count() * 1000));
//// visualization
if (this->visualize_) {
Utility::VisualizeBboxes(srcimg, boxes);
}
}
} // namespace PaddleOCR
......@@ -17,6 +17,8 @@
namespace PaddleOCR {
void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
std::vector<std::string> &rec_texts,
std::vector<float> &rec_text_scores,
std::vector<double> *times) {
std::chrono::duration<float> preprocess_diff =
std::chrono::steady_clock::now() - std::chrono::steady_clock::now();
......@@ -86,7 +88,7 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
// ctc decode
auto postprocess_start = std::chrono::steady_clock::now();
for (int m = 0; m < predict_shape[0]; m++) {
std::vector<std::string> str_res;
std::string str_res;
int argmax_idx;
int last_index = 0;
float score = 0.f;
......@@ -104,17 +106,16 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value;
count += 1;
str_res.push_back(label_list_[argmax_idx]);
str_res += label_list_[argmax_idx];
}
last_index = argmax_idx;
}
score /= count;
if (isnan(score))
if (isnan(score)) {
continue;
for (int i = 0; i < str_res.size(); i++) {
std::cout << str_res[i];
}
std::cout << "\tscore: " << score << std::endl;
rec_texts[indices[beg_img_no + m]] = str_res;
rec_text_scores[indices[beg_img_no + m]] = score;
}
auto postprocess_end = std::chrono::steady_clock::now();
postprocess_diff += postprocess_end - postprocess_start;
......
......@@ -40,7 +40,8 @@ std::vector<std::string> Utility::ReadDict(const std::string &path) {
void Utility::VisualizeBboxes(
const cv::Mat &srcimg,
const std::vector<std::vector<std::vector<int>>> &boxes) {
const std::vector<std::vector<std::vector<int>>> &boxes,
const std::string &save_path) {
cv::Mat img_vis;
srcimg.copyTo(img_vis);
for (int n = 0; n < boxes.size(); n++) {
......@@ -54,8 +55,8 @@ void Utility::VisualizeBboxes(
cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
}
cv::imwrite("./ocr_vis.png", img_vis);
std::cout << "The detection visualized image saved in ./ocr_vis.png"
cv::imwrite(save_path, img_vis);
std::cout << "The detection visualized image saved in " + save_path
<< std::endl;
}
......@@ -147,17 +148,52 @@ cv::Mat Utility::GetRotateCropImage(const cv::Mat &srcimage,
}
}
std::vector<int> Utility::argsort(const std::vector<float>& array)
{
std::vector<int> Utility::argsort(const std::vector<float> &array) {
const int array_len(array.size());
std::vector<int> array_index(array_len, 0);
for (int i = 0; i < array_len; ++i)
array_index[i] = i;
std::sort(array_index.begin(), array_index.end(),
[&array](int pos1, int pos2) {return (array[pos1] < array[pos2]); });
std::sort(
array_index.begin(), array_index.end(),
[&array](int pos1, int pos2) { return (array[pos1] < array[pos2]); });
return array_index;
}
std::string Utility::basename(const std::string &filename) {
if (filename.empty()) {
return "";
}
auto len = filename.length();
auto index = filename.find_last_of("/\\");
if (index == std::string::npos) {
return filename;
}
if (index + 1 >= len) {
len--;
index = filename.substr(0, len).find_last_of("/\\");
if (len == 0) {
return filename;
}
if (index == 0) {
return filename.substr(1, len - 1);
}
if (index == std::string::npos) {
return filename.substr(0, len);
}
return filename.substr(index + 1, len - index - 1);
}
return filename.substr(index + 1, len - index);
}
} // namespace PaddleOCR
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册