未验证 提交 a9c0cab7 编写于 作者: A andyjpaddle 提交者: GitHub

Merge pull request #6042 from andyjpaddle/dygraph

update cpp infer for rec
...@@ -46,6 +46,8 @@ DECLARE_int32(cls_batch_num); ...@@ -46,6 +46,8 @@ DECLARE_int32(cls_batch_num);
DECLARE_string(rec_model_dir); DECLARE_string(rec_model_dir);
DECLARE_int32(rec_batch_num); DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path); DECLARE_string(rec_char_dict_path);
DECLARE_int32(rec_img_h);
DECLARE_int32(rec_img_w);
// forward related // forward related
DECLARE_bool(det); DECLARE_bool(det);
DECLARE_bool(rec); DECLARE_bool(rec);
......
...@@ -45,7 +45,8 @@ public: ...@@ -45,7 +45,8 @@ public:
const bool &use_mkldnn, const string &label_path, const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt, const bool &use_tensorrt,
const std::string &precision, const std::string &precision,
const int &rec_batch_num) { const int &rec_batch_num, const int &rec_img_h,
const int &rec_img_w) {
this->use_gpu_ = use_gpu; this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id; this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem; this->gpu_mem_ = gpu_mem;
...@@ -54,6 +55,10 @@ public: ...@@ -54,6 +55,10 @@ public:
this->use_tensorrt_ = use_tensorrt; this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision; this->precision_ = precision;
this->rec_batch_num_ = rec_batch_num; this->rec_batch_num_ = rec_batch_num;
this->rec_img_h_ = rec_img_h;
this->rec_img_w_ = rec_img_w;
std::vector<int> rec_image_shape = {3, rec_img_h, rec_img_w};
this->rec_image_shape_ = rec_image_shape;
this->label_list_ = Utility::ReadDict(label_path); this->label_list_ = Utility::ReadDict(label_path);
this->label_list_.insert(this->label_list_.begin(), this->label_list_.insert(this->label_list_.begin(),
...@@ -86,7 +91,9 @@ private: ...@@ -86,7 +91,9 @@ private:
bool use_tensorrt_ = false; bool use_tensorrt_ = false;
std::string precision_ = "fp32"; std::string precision_ = "fp32";
int rec_batch_num_ = 6; int rec_batch_num_ = 6;
int rec_img_h_ = 32;
int rec_img_w_ = 320;
std::vector<int> rec_image_shape_ = {3, rec_img_h_, rec_img_w_};
// pre-process // pre-process
CrnnResizeImg resize_op_; CrnnResizeImg resize_op_;
Normalize normalize_op_; Normalize normalize_op_;
......
...@@ -323,6 +323,8 @@ More parameters are as follows, ...@@ -323,6 +323,8 @@ More parameters are as follows,
|rec_model_dir|string|-|Address of recognition inference model| |rec_model_dir|string|-|Address of recognition inference model|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file| |rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
|rec_batch_num|int|6|batch size of recognition| |rec_batch_num|int|6|batch size of recognition|
|rec_img_h|int|32|input image height of the recognition model|
|rec_img_w|int|320|input image width of the recognition model|
* PaddleOCR also supports multi-language inference; refer to the [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models. To run inference with a multi-language model, you only need to change the values of `rec_char_dict_path` and `rec_model_dir`. * PaddleOCR also supports multi-language inference; refer to the [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models. To run inference with a multi-language model, you only need to change the values of `rec_char_dict_path` and `rec_model_dir`.
......
...@@ -336,6 +336,8 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir ...@@ -336,6 +336,8 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir
|rec_model_dir|string|-|识别模型inference model地址| |rec_model_dir|string|-|识别模型inference model地址|
|rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件| |rec_char_dict_path|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|rec_batch_num|int|6|识别模型batchsize| |rec_batch_num|int|6|识别模型batchsize|
|rec_img_h|int|32|识别模型输入图像高度|
|rec_img_w|int|320|识别模型输入图像宽度|
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。 * PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
......
...@@ -47,6 +47,8 @@ DEFINE_string(rec_model_dir, "", "Path of rec inference model."); ...@@ -47,6 +47,8 @@ DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
DEFINE_int32(rec_batch_num, 6, "rec_batch_num."); DEFINE_int32(rec_batch_num, 6, "rec_batch_num.");
DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt", DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt",
"Path of dictionary."); "Path of dictionary.");
DEFINE_int32(rec_img_h, 32, "rec image height");
DEFINE_int32(rec_img_w, 320, "rec image width");
// ocr forward related // ocr forward related
DEFINE_bool(det, true, "Whether use det in forward."); DEFINE_bool(det, true, "Whether use det in forward.");
......
...@@ -39,7 +39,9 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, ...@@ -39,7 +39,9 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
auto preprocess_start = std::chrono::steady_clock::now(); auto preprocess_start = std::chrono::steady_clock::now();
int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_); int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_);
int batch_num = end_img_no - beg_img_no; int batch_num = end_img_no - beg_img_no;
float max_wh_ratio = 0; int imgH = this->rec_image_shape_[1];
int imgW = this->rec_image_shape_[2];
float max_wh_ratio = imgW * 1.0 / imgH;
for (int ino = beg_img_no; ino < end_img_no; ino++) { for (int ino = beg_img_no; ino < end_img_no; ino++) {
int h = img_list[indices[ino]].rows; int h = img_list[indices[ino]].rows;
int w = img_list[indices[ino]].cols; int w = img_list[indices[ino]].cols;
...@@ -47,28 +49,28 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list, ...@@ -47,28 +49,28 @@ void CRNNRecognizer::Run(std::vector<cv::Mat> img_list,
max_wh_ratio = max(max_wh_ratio, wh_ratio); max_wh_ratio = max(max_wh_ratio, wh_ratio);
} }
int batch_width = 0; int batch_width = imgW;
std::vector<cv::Mat> norm_img_batch; std::vector<cv::Mat> norm_img_batch;
for (int ino = beg_img_no; ino < end_img_no; ino++) { for (int ino = beg_img_no; ino < end_img_no; ino++) {
cv::Mat srcimg; cv::Mat srcimg;
img_list[indices[ino]].copyTo(srcimg); img_list[indices[ino]].copyTo(srcimg);
cv::Mat resize_img; cv::Mat resize_img;
this->resize_op_.Run(srcimg, resize_img, max_wh_ratio, this->resize_op_.Run(srcimg, resize_img, max_wh_ratio,
this->use_tensorrt_); this->use_tensorrt_, this->rec_image_shape_);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_); this->is_scale_);
norm_img_batch.push_back(resize_img); norm_img_batch.push_back(resize_img);
batch_width = max(resize_img.cols, batch_width); batch_width = max(resize_img.cols, batch_width);
} }
std::vector<float> input(batch_num * 3 * 32 * batch_width, 0.0f); std::vector<float> input(batch_num * 3 * imgH * batch_width, 0.0f);
this->permute_op_.Run(norm_img_batch, input.data()); this->permute_op_.Run(norm_img_batch, input.data());
auto preprocess_end = std::chrono::steady_clock::now(); auto preprocess_end = std::chrono::steady_clock::now();
preprocess_diff += preprocess_end - preprocess_start; preprocess_diff += preprocess_end - preprocess_start;
// Inference. // Inference.
auto input_names = this->predictor_->GetInputNames(); auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputHandle(input_names[0]); auto input_t = this->predictor_->GetInputHandle(input_names[0]);
input_t->Reshape({batch_num, 3, 32, batch_width}); input_t->Reshape({batch_num, 3, imgH, batch_width});
auto inference_start = std::chrono::steady_clock::now(); auto inference_start = std::chrono::steady_clock::now();
input_t->CopyFromCpu(input.data()); input_t->CopyFromCpu(input.data());
this->predictor_->Run(); this->predictor_->Run();
...@@ -142,13 +144,14 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { ...@@ -142,13 +144,14 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
precision = paddle_infer::Config::Precision::kInt8; precision = paddle_infer::Config::Precision::kInt8;
} }
config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false);
int imgH = this->rec_image_shape_[1];
int imgW = this->rec_image_shape_[2];
std::map<std::string, std::vector<int>> min_input_shape = { std::map<std::string, std::vector<int>> min_input_shape = {
{"x", {1, 3, 32, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}}; {"x", {1, 3, imgH, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}};
std::map<std::string, std::vector<int>> max_input_shape = { std::map<std::string, std::vector<int>> max_input_shape = {
{"x", {1, 3, 32, 2000}}, {"lstm_0.tmp_0", {1000, 1, 96}}}; {"x", {1, 3, imgH, 2000}}, {"lstm_0.tmp_0", {1000, 1, 96}}};
std::map<std::string, std::vector<int>> opt_input_shape = { std::map<std::string, std::vector<int>> opt_input_shape = {
{"x", {1, 3, 32, 320}}, {"lstm_0.tmp_0", {25, 1, 96}}}; {"x", {1, 3, imgH, imgW}}, {"lstm_0.tmp_0", {25, 1, 96}}};
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape); opt_input_shape);
......
...@@ -39,7 +39,8 @@ PPOCR::PPOCR() { ...@@ -39,7 +39,8 @@ PPOCR::PPOCR() {
this->recognizer_ = new CRNNRecognizer( this->recognizer_ = new CRNNRecognizer(
FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem,
FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_rec_char_dict_path, FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_rec_char_dict_path,
FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num); FLAGS_use_tensorrt, FLAGS_precision, FLAGS_rec_batch_num,
FLAGS_rec_img_h, FLAGS_rec_img_w);
} }
}; };
......
...@@ -41,12 +41,13 @@ void Permute::Run(const cv::Mat *im, float *data) { ...@@ -41,12 +41,13 @@ void Permute::Run(const cv::Mat *im, float *data) {
} }
void PermuteBatch::Run(const std::vector<cv::Mat> imgs, float *data) { void PermuteBatch::Run(const std::vector<cv::Mat> imgs, float *data) {
for (int j = 0; j < imgs.size(); j ++){ for (int j = 0; j < imgs.size(); j++) {
int rh = imgs[j].rows; int rh = imgs[j].rows;
int rw = imgs[j].cols; int rw = imgs[j].cols;
int rc = imgs[j].channels(); int rc = imgs[j].channels();
for (int i = 0; i < rc; ++i) { for (int i = 0; i < rc; ++i) {
cv::extractChannel(imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i); cv::extractChannel(
imgs[j], cv::Mat(rh, rw, CV_32FC1, data + (j * rc + i) * rh * rw), i);
} }
} }
} }
...@@ -102,7 +103,7 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, ...@@ -102,7 +103,7 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
imgH = rec_image_shape[1]; imgH = rec_image_shape[1];
imgW = rec_image_shape[2]; imgW = rec_image_shape[2];
imgW = int(32 * wh_ratio); imgW = int(imgH * wh_ratio);
float ratio = float(img.cols) / float(img.rows); float ratio = float(img.cols) / float(img.rows);
int resize_w, resize_h; int resize_w, resize_h;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册