未验证 提交 594d34f5 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #708 from WenmuZhou/angle_class

添加分类模型
Global:
algorithm: CLS
use_gpu: False
epoch_num: 100
log_smooth_window: 20
print_batch_step: 100
save_model_dir: output/cls_mv3
save_epoch_step: 3
eval_batch_step: 500
train_batch_size_per_card: 512
test_batch_size_per_card: 512
image_shape: [3, 48, 192]
label_list: ['0','180']
distort: True
reader_yml: ./configs/cls/cls_reader.yml
pretrain_weights:
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.cls_model,ClsModel
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.35
model_name: small
Head:
function: ppocr.modeling.heads.cls_head,ClsHead
class_dim: 2
Loss:
function: ppocr.modeling.losses.cls_loss,ClsLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
decay:
function: cosine_decay
step_each_epoch: 1169
total_epoch: 100
\ No newline at end of file
TrainReader:
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
num_workers: 8
img_set_dir: ./train_data/cls
label_file_path: ./train_data/cls/train.txt
EvalReader:
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
img_set_dir: ./train_data/cls
label_file_path: ./train_data/cls/test.txt
TestReader:
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
...@@ -4,29 +4,29 @@ ...@@ -4,29 +4,29 @@
#include "native.h" #include "native.h"
#include "ocr_ppredictor.h" #include "ocr_ppredictor.h"
#include <string>
#include <algorithm> #include <algorithm>
#include <paddle_api.h> #include <paddle_api.h>
#include <string>
static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode); static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
extern "C" extern "C" JNIEXPORT jlong JNICALL
JNIEXPORT jlong JNICALL Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject thiz, JNIEnv *env, jobject thiz, jstring j_det_model_path,
jstring j_det_model_path, jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num,
jstring j_rec_model_path, jstring j_cpu_mode) {
jint j_thread_num, std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path);
jstring j_cpu_mode) { std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path);
std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path); std::string cls_model_path = jstring_to_cpp_string(env, j_cls_model_path);
std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path); int thread_num = j_thread_num;
int thread_num = j_thread_num; std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode);
std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode); ppredictor::OCR_Config conf;
ppredictor::OCR_Config conf; conf.thread_num = thread_num;
conf.thread_num = thread_num; conf.mode = str_to_cpu_mode(cpu_mode);
conf.mode = str_to_cpu_mode(cpu_mode); ppredictor::OCR_PPredictor *orc_predictor =
ppredictor::OCR_PPredictor *orc_predictor = new ppredictor::OCR_PPredictor{conf}; new ppredictor::OCR_PPredictor{conf};
orc_predictor->init_from_file(det_model_path, rec_model_path); orc_predictor->init_from_file(det_model_path, rec_model_path, cls_model_path);
return reinterpret_cast<jlong>(orc_predictor); return reinterpret_cast<jlong>(orc_predictor);
} }
/** /**
...@@ -34,82 +34,81 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject ...@@ -34,82 +34,81 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject
* @param cpu_mode * @param cpu_mode
* @return * @return
*/ */
static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode) { static paddle::lite_api::PowerMode
static std::map<std::string, paddle::lite_api::PowerMode> cpu_mode_map{ str_to_cpu_mode(const std::string &cpu_mode) {
{"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH}, static std::map<std::string, paddle::lite_api::PowerMode> cpu_mode_map{
{"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH}, {"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH},
{"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL}, {"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH},
{"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND}, {"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL},
{"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH}, {"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND},
{"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW} {"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH},
}; {"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW}};
std::string upper_key; std::string upper_key;
std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(), ::toupper); std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(),
auto index = cpu_mode_map.find(upper_key); ::toupper);
if (index == cpu_mode_map.end()) { auto index = cpu_mode_map.find(upper_key);
LOGE("cpu_mode not found %s", upper_key.c_str()); if (index == cpu_mode_map.end()) {
return paddle::lite_api::LITE_POWER_HIGH; LOGE("cpu_mode not found %s", upper_key.c_str());
} else { return paddle::lite_api::LITE_POWER_HIGH;
return index->second; } else {
} return index->second;
}
} }
extern "C" extern "C" JNIEXPORT jfloatArray JNICALL
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(JNIEnv *env, jobject thiz, JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf,
jlong java_pointer, jfloatArray buf, jfloatArray ddims, jobject original_image) {
jfloatArray ddims, LOGI("begin to run native forward");
jobject original_image) { if (java_pointer == 0) {
LOGI("begin to run native forward"); LOGE("JAVA pointer is NULL");
if (java_pointer == 0) { return cpp_array_to_jfloatarray(env, nullptr, 0);
LOGE("JAVA pointer is NULL"); }
return cpp_array_to_jfloatarray(env, nullptr, 0); cv::Mat origin = bitmap_to_cv_mat(env, original_image);
} if (origin.size == 0) {
cv::Mat origin = bitmap_to_cv_mat(env, original_image); LOGE("origin bitmap cannot convert to CV Mat");
if (origin.size == 0) { return cpp_array_to_jfloatarray(env, nullptr, 0);
LOGE("origin bitmap cannot convert to CV Mat"); }
return cpp_array_to_jfloatarray(env, nullptr, 0); ppredictor::OCR_PPredictor *ppredictor =
} (ppredictor::OCR_PPredictor *)java_pointer;
ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *) java_pointer; std::vector<float> dims_float_arr = jfloatarray_to_float_vector(env, ddims);
std::vector<float> dims_float_arr = jfloatarray_to_float_vector(env, ddims); std::vector<int64_t> dims_arr;
std::vector<int64_t> dims_arr; dims_arr.resize(dims_float_arr.size());
dims_arr.resize(dims_float_arr.size()); std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin());
std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin());
// 这里值有点大,就不调用jfloatarray_to_float_vector了 // 这里值有点大,就不调用jfloatarray_to_float_vector了
int64_t buf_len = (int64_t) env->GetArrayLength(buf); int64_t buf_len = (int64_t)env->GetArrayLength(buf);
jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE); jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE);
float *data = (jfloat *) buf_data; float *data = (jfloat *)buf_data;
std::vector<ppredictor::OCRPredictResult> results = ppredictor->infer_ocr(dims_arr, data, std::vector<ppredictor::OCRPredictResult> results =
buf_len, ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin);
NET_OCR, origin); LOGI("infer_ocr finished with boxes %ld", results.size());
LOGI("infer_ocr finished with boxes %ld", results.size()); // 这里将std::vector<ppredictor::OCRPredictResult> 序列化成
// 这里将std::vector<ppredictor::OCRPredictResult> 序列化成 float数组,传输到java层再反序列化 // float数组,传输到java层再反序列化
std::vector<float> float_arr; std::vector<float> float_arr;
for (const ppredictor::OCRPredictResult &r :results) { for (const ppredictor::OCRPredictResult &r : results) {
float_arr.push_back(r.points.size()); float_arr.push_back(r.points.size());
float_arr.push_back(r.word_index.size()); float_arr.push_back(r.word_index.size());
float_arr.push_back(r.score); float_arr.push_back(r.score);
for (const std::vector<int> &point : r.points) { for (const std::vector<int> &point : r.points) {
float_arr.push_back(point.at(0)); float_arr.push_back(point.at(0));
float_arr.push_back(point.at(1)); float_arr.push_back(point.at(1));
}
for (int index: r.word_index) {
float_arr.push_back(index);
}
} }
return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size()); for (int index : r.word_index) {
float_arr.push_back(index);
}
}
return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size());
} }
extern "C" extern "C" JNIEXPORT void JNICALL
JNIEXPORT void JNICALL Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(JNIEnv *env, jobject thiz, JNIEnv *env, jobject thiz, jlong java_pointer) {
jlong java_pointer){ if (java_pointer == 0) {
if (java_pointer == 0) { LOGE("JAVA pointer is NULL");
LOGE("JAVA pointer is NULL"); return;
return; }
} ppredictor::OCR_PPredictor *ppredictor =
ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *) java_pointer; (ppredictor::OCR_PPredictor *)java_pointer;
delete ppredictor; delete ppredictor;
} }
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ocr_cls_process.h"
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iostream>
#include <vector>
const std::vector<int> CLS_IMAGE_SHAPE = {3, 32, 100};
cv::Mat cls_resize_img(const cv::Mat &img) {
int imgC = CLS_IMAGE_SHAPE[0];
int imgW = CLS_IMAGE_SHAPE[2];
int imgH = CLS_IMAGE_SHAPE[1];
float ratio = float(img.cols) / float(img.rows);
int resize_w = 0;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_CUBIC);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(imgW - resize_w),
cv::BORDER_CONSTANT, {0, 0, 0});
}
return resize_img;
}
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common.h"
#include <opencv2/opencv.hpp>
#include <vector>
extern const std::vector<int> CLS_IMAGE_SHAPE;
cv::Mat cls_resize_img(const cv::Mat &img);
\ No newline at end of file
...@@ -3,38 +3,48 @@ ...@@ -3,38 +3,48 @@
// //
#include "ocr_ppredictor.h" #include "ocr_ppredictor.h"
#include "preprocess.h"
#include "common.h" #include "common.h"
#include "ocr_db_post_process.h" #include "ocr_cls_process.h"
#include "ocr_crnn_process.h" #include "ocr_crnn_process.h"
#include "ocr_db_post_process.h"
#include "preprocess.h"
namespace ppredictor { namespace ppredictor {
OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) { OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) {}
} int OCR_PPredictor::init(const std::string &det_model_content,
const std::string &rec_model_content,
const std::string &cls_model_content) {
_det_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
_det_predictor->init_nb(det_model_content);
int _rec_predictor = std::unique_ptr<PPredictor>(
OCR_PPredictor::init(const std::string &det_model_content, const std::string &rec_model_content) { new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_det_predictor = std::unique_ptr<PPredictor>( _rec_predictor->init_nb(rec_model_content);
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
_det_predictor->init_nb(det_model_content);
_rec_predictor = std::unique_ptr<PPredictor>( _cls_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_rec_predictor->init_nb(rec_model_content); _cls_predictor->init_nb(cls_model_content);
return RETURN_OK; return RETURN_OK;
} }
int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std::string &rec_model_path){ int OCR_PPredictor::init_from_file(const std::string &det_model_path,
_det_predictor = std::unique_ptr<PPredictor>( const std::string &rec_model_path,
new PPredictor{_config.thread_num, NET_OCR, _config.mode}); const std::string &cls_model_path) {
_det_predictor->init_from_file(det_model_path); _det_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
_rec_predictor = std::unique_ptr<PPredictor>( _det_predictor->init_from_file(det_model_path);
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_rec_predictor->init_from_file(rec_model_path); _rec_predictor = std::unique_ptr<PPredictor>(
return RETURN_OK; new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_rec_predictor->init_from_file(rec_model_path);
_cls_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_cls_predictor->init_from_file(cls_model_path);
return RETURN_OK;
} }
/** /**
* for debug use, show result of First Step * for debug use, show result of First Step
...@@ -42,145 +52,188 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std: ...@@ -42,145 +52,188 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std:
* @param boxes * @param boxes
* @param srcimg * @param srcimg
*/ */
static void visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes, static void
const std::vector<std::vector<std::vector<int>>> &boxes, visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
const cv::Mat &srcimg) { const std::vector<std::vector<std::vector<int>>> &boxes,
// visualization const cv::Mat &srcimg) {
cv::Point rook_points[filter_boxes.size()][4]; // visualization
for (int n = 0; n < filter_boxes.size(); n++) { cv::Point rook_points[filter_boxes.size()][4];
for (int m = 0; m < filter_boxes[0].size(); m++) { for (int n = 0; n < filter_boxes.size(); n++) {
rook_points[n][m] = cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1])); for (int m = 0; m < filter_boxes[0].size(); m++) {
} rook_points[n][m] =
cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1]));
} }
}
cv::Mat img_vis;
srcimg.copyTo(img_vis); cv::Mat img_vis;
for (int n = 0; n < boxes.size(); n++) { srcimg.copyTo(img_vis);
const cv::Point *ppt[1] = {rook_points[n]}; for (int n = 0; n < boxes.size(); n++) {
int npt[] = {4}; const cv::Point *ppt[1] = {rook_points[n]};
cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0); int npt[] = {4};
} cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
// 调试用,自行替换需要修改的路径 }
cv::imwrite("/sdcard/1/vis.png", img_vis); // 调试用,自行替换需要修改的路径
cv::imwrite("/sdcard/1/vis.png", img_vis);
} }
std::vector<OCRPredictResult> std::vector<OCRPredictResult>
OCR_PPredictor::infer_ocr(const std::vector<int64_t> &dims, const float *input_data, int input_len, OCR_PPredictor::infer_ocr(const std::vector<int64_t> &dims,
int net_flag, cv::Mat &origin) { const float *input_data, int input_len, int net_flag,
cv::Mat &origin) {
PredictorInput input = _det_predictor->get_first_input();
input.set_dims(dims);
input.set_data(input_data, input_len);
std::vector<PredictorOutput> results = _det_predictor->infer();
PredictorOutput &res = results.at(0);
std::vector<std::vector<std::vector<int>>> filtered_box = calc_filtered_boxes(
res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin);
LOGI("Filter_box size %ld", filtered_box.size());
return infer_rec(filtered_box, origin);
}
PredictorInput input = _det_predictor->get_first_input(); std::vector<OCRPredictResult> OCR_PPredictor::infer_rec(
const std::vector<std::vector<std::vector<int>>> &boxes,
const cv::Mat &origin_img) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
std::vector<int64_t> dims = {1, 3, 0, 0};
std::vector<OCRPredictResult> ocr_results;
PredictorInput input = _rec_predictor->get_first_input();
for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) {
const std::vector<std::vector<int>> &box = *bp;
cv::Mat crop_img = get_rotate_crop_image(origin_img, box);
crop_img = infer_cls(crop_img);
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
const float *dimg = reinterpret_cast<const float *>(input_image.data);
int input_size = input_image.rows * input_image.cols;
dims[2] = input_image.rows;
dims[3] = input_image.cols;
input.set_dims(dims); input.set_dims(dims);
input.set_data(input_data, input_len);
std::vector<PredictorOutput> results = _det_predictor->infer();
PredictorOutput &res = results.at(0);
std::vector<std::vector<std::vector<int>>> filtered_box
= calc_filtered_boxes(res.get_float_data(), res.get_size(), (int) dims[2], (int) dims[3],
origin);
LOGI("Filter_box size %ld", filtered_box.size());
return infer_rec(filtered_box, origin);
}
std::vector<OCRPredictResult> neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
OCR_PPredictor::infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes, scale);
const cv::Mat &origin_img) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f}; std::vector<PredictorOutput> results = _rec_predictor->infer();
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
std::vector<int64_t> dims = {1, 3, 0, 0}; OCRPredictResult res;
std::vector<OCRPredictResult> ocr_results; res.word_index = postprocess_rec_word_index(results.at(0));
if (res.word_index.empty()) {
PredictorInput input = _rec_predictor->get_first_input(); continue;
for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) {
const std::vector<std::vector<int>> &box = *bp;
cv::Mat crop_img = get_rotate_crop_image(origin_img, box);
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
const float *dimg = reinterpret_cast<const float *>(input_image.data);
int input_size = input_image.rows * input_image.cols;
dims[2] = input_image.rows;
dims[3] = input_image.cols;
input.set_dims(dims);
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, scale);
std::vector<PredictorOutput> results = _rec_predictor->infer();
OCRPredictResult res;
res.word_index = postprocess_rec_word_index(results.at(0));
if (res.word_index.empty()) {
continue;
}
res.score = postprocess_rec_score(results.at(1));
res.points = box;
ocr_results.emplace_back(std::move(res));
} }
LOGI("ocr_results finished %lu", ocr_results.size()); res.score = postprocess_rec_score(results.at(1));
return ocr_results; res.points = box;
ocr_results.emplace_back(std::move(res));
}
LOGI("ocr_results finished %lu", ocr_results.size());
return ocr_results;
}
cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
std::vector<int64_t> dims = {1, 3, 0, 0};
std::vector<OCRPredictResult> ocr_results;
PredictorInput input = _cls_predictor->get_first_input();
cv::Mat input_image = cls_resize_img(img);
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
const float *dimg = reinterpret_cast<const float *>(input_image.data);
int input_size = input_image.rows * input_image.cols;
dims[2] = input_image.rows;
dims[3] = input_image.cols;
input.set_dims(dims);
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
scale);
std::vector<PredictorOutput> results = _cls_predictor->infer();
const float *scores = results.at(0).get_float_data();
const int *labels = results.at(1).get_int_data();
for (int64_t i = 0; i < results.at(0).get_size(); i++) {
LOGI("output scores [%f]", scores[i]);
}
for (int64_t i = 0; i < results.at(1).get_size(); i++) {
LOGI("output label [%d]", labels[i]);
}
int label_idx = labels[0];
float score = scores[label_idx];
cv::Mat srcimg;
img.copyTo(srcimg);
if (label_idx % 2 == 1 && score > thresh) {
cv::rotate(srcimg, srcimg, 1);
}
return srcimg;
} }
std::vector<std::vector<std::vector<int>>> std::vector<std::vector<std::vector<int>>>
OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size, int output_height, OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size,
int output_width, const cv::Mat &origin) { int output_height, int output_width,
const double threshold = 0.3; const cv::Mat &origin) {
const double maxvalue = 1; const double threshold = 0.3;
const double maxvalue = 1;
cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F);
memcpy(pred_map.data, pred, pred_size * sizeof(float)); cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F);
cv::Mat cbuf_map; memcpy(pred_map.data, pred, pred_size * sizeof(float));
pred_map.convertTo(cbuf_map, CV_8UC1); cv::Mat cbuf_map;
pred_map.convertTo(cbuf_map, CV_8UC1);
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
std::vector<std::vector<std::vector<int>>> boxes = boxes_from_bitmap(pred_map, bit_map);
float ratio_h = output_height * 1.0f / origin.rows; std::vector<std::vector<std::vector<int>>> boxes =
float ratio_w = output_width * 1.0f / origin.cols; boxes_from_bitmap(pred_map, bit_map);
std::vector<std::vector<std::vector<int>>> filter_boxes = filter_tag_det_res(boxes, ratio_h, float ratio_h = output_height * 1.0f / origin.rows;
ratio_w, origin); float ratio_w = output_width * 1.0f / origin.cols;
return filter_boxes; std::vector<std::vector<std::vector<int>>> filter_boxes =
filter_tag_det_res(boxes, ratio_h, ratio_w, origin);
return filter_boxes;
} }
std::vector<int> OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) { std::vector<int>
const int *rec_idx = res.get_int_data(); OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) {
const std::vector<std::vector<uint64_t>> rec_idx_lod = res.get_lod(); const int *rec_idx = res.get_int_data();
const std::vector<std::vector<uint64_t>> rec_idx_lod = res.get_lod();
std::vector<int> pred_idx; std::vector<int> pred_idx;
for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) { for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) {
pred_idx.emplace_back(rec_idx[n]); pred_idx.emplace_back(rec_idx[n]);
} }
return pred_idx; return pred_idx;
} }
float OCR_PPredictor::postprocess_rec_score(const PredictorOutput &res) { float OCR_PPredictor::postprocess_rec_score(const PredictorOutput &res) {
const float *predict_batch = res.get_float_data(); const float *predict_batch = res.get_float_data();
const std::vector<int64_t> predict_shape = res.get_shape(); const std::vector<int64_t> predict_shape = res.get_shape();
const std::vector<std::vector<uint64_t>> predict_lod = res.get_lod(); const std::vector<std::vector<uint64_t>> predict_lod = res.get_lod();
int blank = predict_shape[1]; int blank = predict_shape[1];
float score = 0.f; float score = 0.f;
int count = 0; int count = 0;
for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) { for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
int argmax_idx = argmax(predict_batch + n * predict_shape[1], int argmax_idx = argmax(predict_batch + n * predict_shape[1],
predict_batch + (n + 1) * predict_shape[1]); predict_batch + (n + 1) * predict_shape[1]);
float max_value = predict_batch[n * predict_shape[1] + argmax_idx]; float max_value = predict_batch[n * predict_shape[1] + argmax_idx];
if (blank - 1 - argmax_idx > 1e-5) { if (blank - 1 - argmax_idx > 1e-5) {
score += max_value; score += max_value;
count += 1; count += 1;
}
}
if (count == 0) {
LOGE("calc score count 0");
} else {
score /= count;
} }
LOGI("calc score: %f", score); }
return score; if (count == 0) {
LOGE("calc score count 0");
} else {
score /= count;
}
LOGI("calc score: %f", score);
return score;
} }
NET_TYPE OCR_PPredictor::get_net_flag() const { return NET_OCR; }
NET_TYPE OCR_PPredictor::get_net_flag() const {
return NET_OCR;
}
} }
\ No newline at end of file
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
#pragma once #pragma once
#include <string> #include "ppredictor.h"
#include <opencv2/opencv.hpp> #include <opencv2/opencv.hpp>
#include <paddle_api.h> #include <paddle_api.h>
#include "ppredictor.h" #include <string>
namespace ppredictor { namespace ppredictor {
...@@ -15,17 +15,18 @@ namespace ppredictor { ...@@ -15,17 +15,18 @@ namespace ppredictor {
* Config * Config
*/ */
struct OCR_Config { struct OCR_Config {
int thread_num = 4; // Thread num int thread_num = 4; // Thread num
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode paddle::lite_api::PowerMode mode =
paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
}; };
/** /**
* PolyGone Result * PolyGone Result
*/ */
struct OCRPredictResult { struct OCRPredictResult {
std::vector<int> word_index; std::vector<int> word_index;
std::vector<std::vector<int>> points; std::vector<std::vector<int>> points;
float score; float score;
}; };
/** /**
...@@ -35,78 +36,87 @@ struct OCRPredictResult { ...@@ -35,78 +36,87 @@ struct OCRPredictResult {
*/ */
class OCR_PPredictor : public PPredictor_Interface { class OCR_PPredictor : public PPredictor_Interface {
public: public:
OCR_PPredictor(const OCR_Config &config); OCR_PPredictor(const OCR_Config &config);
virtual ~OCR_PPredictor() { virtual ~OCR_PPredictor() {}
} /**
* 初始化二个模型的Predictor
/** * @param det_model_content
* 初始化二个模型的Predictor * @param rec_model_content
* @param det_model_content * @return
* @param rec_model_content */
* @return int init(const std::string &det_model_content,
*/ const std::string &rec_model_content,
int init(const std::string &det_model_content, const std::string &rec_model_content); const std::string &cls_model_content);
int init_from_file(const std::string &det_model_path, const std::string &rec_model_path); int init_from_file(const std::string &det_model_path,
/** const std::string &rec_model_path,
* Return OCR result const std::string &cls_model_path);
* @param dims /**
* @param input_data * Return OCR result
* @param input_len * @param dims
* @param net_flag * @param input_data
* @param origin * @param input_len
* @return * @param net_flag
*/ * @param origin
virtual std::vector<OCRPredictResult> * @return
infer_ocr(const std::vector<int64_t> &dims, const float *input_data, int input_len, */
int net_flag, cv::Mat &origin); virtual std::vector<OCRPredictResult>
infer_ocr(const std::vector<int64_t> &dims, const float *input_data,
int input_len, int net_flag, cv::Mat &origin);
virtual NET_TYPE get_net_flag() const;
virtual NET_TYPE get_net_flag() const;
private: private:
/**
/** * calcul Polygone from the result image of first model
* calcul Polygone from the result image of first model * @param pred
* @param pred * @param output_height
* @param output_height * @param output_width
* @param output_width * @param origin
* @param origin * @return
* @return */
*/ std::vector<std::vector<std::vector<int>>>
std::vector<std::vector<std::vector<int>>> calc_filtered_boxes(const float *pred, int pred_size, int output_height,
calc_filtered_boxes(const float *pred, int pred_size, int output_height, int output_width, int output_width, const cv::Mat &origin);
const cv::Mat &origin);
/**
/** * infer for second model
* infer for second model *
* * @param boxes
* @param boxes * @param origin
* @param origin * @return
* @return */
*/ std::vector<OCRPredictResult>
std::vector<OCRPredictResult> infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes,
infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes, const cv::Mat &origin); const cv::Mat &origin);
/** /**
* Postprocess or sencod model to extract text * infer for cls model
* @param res *
* @return * @param boxes
*/ * @param origin
std::vector<int> postprocess_rec_word_index(const PredictorOutput &res); * @return
*/
/** cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.5);
* calculate confidence of second model text result
* @param res /**
* @return * Postprocess or sencod model to extract text
*/ * @param res
float postprocess_rec_score(const PredictorOutput &res); * @return
*/
std::unique_ptr<PPredictor> _det_predictor; std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);
std::unique_ptr<PPredictor> _rec_predictor;
OCR_Config _config; /**
* calculate confidence of second model text result
* @param res
* @return
*/
float postprocess_rec_score(const PredictorOutput &res);
std::unique_ptr<PPredictor> _det_predictor;
std::unique_ptr<PPredictor> _rec_predictor;
std::unique_ptr<PPredictor> _cls_predictor;
OCR_Config _config;
}; };
} }
...@@ -29,7 +29,7 @@ public class OCRPredictorNative { ...@@ -29,7 +29,7 @@ public class OCRPredictorNative {
public OCRPredictorNative(Config config) { public OCRPredictorNative(Config config) {
this.config = config; this.config = config;
loadLibrary(); loadLibrary();
nativePointer = init(config.detModelFilename, config.recModelFilename, nativePointer = init(config.detModelFilename, config.recModelFilename,config.clsModelFilename,
config.cpuThreadNum, config.cpuPower); config.cpuThreadNum, config.cpuPower);
Log.i("OCRPredictorNative", "load success " + nativePointer); Log.i("OCRPredictorNative", "load success " + nativePointer);
...@@ -38,7 +38,7 @@ public class OCRPredictorNative { ...@@ -38,7 +38,7 @@ public class OCRPredictorNative {
public void release() { public void release() {
if (nativePointer != 0) { if (nativePointer != 0) {
nativePointer = 0; nativePointer = 0;
destory(nativePointer); // destory(nativePointer);
} }
} }
...@@ -55,10 +55,11 @@ public class OCRPredictorNative { ...@@ -55,10 +55,11 @@ public class OCRPredictorNative {
public String cpuPower; public String cpuPower;
public String detModelFilename; public String detModelFilename;
public String recModelFilename; public String recModelFilename;
public String clsModelFilename;
} }
protected native long init(String detModelPath, String recModelPath, int threadNum, String cpuMode); protected native long init(String detModelPath, String recModelPath,String clsModelPath, int threadNum, String cpuMode);
protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage); protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage);
......
...@@ -121,7 +121,8 @@ public class Predictor { ...@@ -121,7 +121,8 @@ public class Predictor {
config.cpuThreadNum = cpuThreadNum; config.cpuThreadNum = cpuThreadNum;
config.detModelFilename = realPath + File.separator + "ch_det_mv3_db_opt.nb"; config.detModelFilename = realPath + File.separator + "ch_det_mv3_db_opt.nb";
config.recModelFilename = realPath + File.separator + "ch_rec_mv3_crnn_opt.nb"; config.recModelFilename = realPath + File.separator + "ch_rec_mv3_crnn_opt.nb";
Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename); config.clsModelFilename = realPath + File.separator + "cls_opt_arm.nb";
Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename);
config.cpuPower = cpuPowerMode; config.cpuPower = cpuPowerMode;
paddlePredictor = new OCRPredictorNative(config); paddlePredictor = new OCRPredictorNative(config);
......
...@@ -57,6 +57,12 @@ public: ...@@ -57,6 +57,12 @@ public:
this->char_list_file.assign(config_map_["char_list_file"]); this->char_list_file.assign(config_map_["char_list_file"]);
this->use_angle_cls = bool(stoi(config_map_["use_angle_cls"]));
this->cls_model_dir.assign(config_map_["cls_model_dir"]);
this->cls_thresh = stod(config_map_["cls_thresh"]);
this->visualize = bool(stoi(config_map_["visualize"])); this->visualize = bool(stoi(config_map_["visualize"]));
} }
...@@ -84,8 +90,14 @@ public: ...@@ -84,8 +90,14 @@ public:
std::string rec_model_dir; std::string rec_model_dir;
bool use_angle_cls;
std::string char_list_file; std::string char_list_file;
std::string cls_model_dir;
double cls_thresh;
bool visualize = true; bool visualize = true;
void PrintConfigInfo(); void PrintConfigInfo();
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/preprocess_op.h>
#include <include/utility.h>
namespace PaddleOCR {
class Classifier {
public:
explicit Classifier(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const bool &use_zero_copy_run,
const double &cls_thresh) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->use_zero_copy_run_ = use_zero_copy_run;
this->cls_thresh = cls_thresh;
LoadModel(model_dir);
}
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
cv::Mat Run(cv::Mat &img);
private:
std::shared_ptr<PaddlePredictor> predictor_;
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
bool use_zero_copy_run_ = false;
double cls_thresh = 0.5;
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
bool is_scale_ = true;
// pre-process
ClsResizeImg resize_op_;
Normalize normalize_op_;
Permute permute_op_;
}; // class Classifier
} // namespace PaddleOCR
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <fstream> #include <fstream>
#include <numeric> #include <numeric>
#include <include/ocr_cls.h>
#include <include/postprocess_op.h> #include <include/postprocess_op.h>
#include <include/preprocess_op.h> #include <include/preprocess_op.h>
#include <include/utility.h> #include <include/utility.h>
...@@ -56,7 +57,8 @@ public: ...@@ -56,7 +57,8 @@ public:
// Load Paddle inference model // Load Paddle inference model
void LoadModel(const std::string &model_dir); void LoadModel(const std::string &model_dir);
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img); void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img,
Classifier *cls);
private: private:
std::shared_ptr<PaddlePredictor> predictor_; std::shared_ptr<PaddlePredictor> predictor_;
......
...@@ -56,4 +56,10 @@ public: ...@@ -56,4 +56,10 @@ public:
const std::vector<int> &rec_image_shape = {3, 32, 320}); const std::vector<int> &rec_image_shape = {3, 32, 320});
}; };
class ClsResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const std::vector<int> &rec_image_shape = {3, 32, 320});
};
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
...@@ -53,6 +53,15 @@ int main(int argc, char **argv) { ...@@ -53,6 +53,15 @@ int main(int argc, char **argv) {
config.cpu_math_library_num_threads, config.use_mkldnn, config.cpu_math_library_num_threads, config.use_mkldnn,
config.use_zero_copy_run, config.max_side_len, config.det_db_thresh, config.use_zero_copy_run, config.max_side_len, config.det_db_thresh,
config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize); config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize);
Classifier *cls = nullptr;
if (config.use_angle_cls == true) {
cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id,
config.gpu_mem, config.cpu_math_library_num_threads,
config.use_mkldnn, config.use_zero_copy_run,
config.cls_thresh);
}
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id, CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
config.gpu_mem, config.cpu_math_library_num_threads, config.gpu_mem, config.cpu_math_library_num_threads,
config.use_mkldnn, config.use_zero_copy_run, config.use_mkldnn, config.use_zero_copy_run,
...@@ -62,7 +71,7 @@ int main(int argc, char **argv) { ...@@ -62,7 +71,7 @@ int main(int argc, char **argv) {
std::vector<std::vector<std::vector<int>>> boxes; std::vector<std::vector<std::vector<int>>> boxes;
det.Run(srcimg, boxes); det.Run(srcimg, boxes);
rec.Run(boxes, srcimg); rec.Run(boxes, srcimg, cls);
auto end = std::chrono::system_clock::now(); auto end = std::chrono::system_clock::now();
auto duration = auto duration =
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/ocr_cls.h>
namespace PaddleOCR {
cv::Mat Classifier::Run(cv::Mat &img) {
cv::Mat src_img;
img.copyTo(src_img);
cv::Mat resize_img;
std::vector<int> rec_image_shape = {3, 32, 100};
int index = 0;
float wh_ratio = float(img.cols) / float(img.rows);
this->resize_op_.Run(img, resize_img, rec_image_shape);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
// Inference.
if (this->use_zero_copy_run_) {
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
input_t->copy_from_cpu(input.data());
this->predictor_->ZeroCopyRun();
} else {
paddle::PaddleTensor input_t;
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
input_t.data =
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
input_t.dtype = PaddleDType::FLOAT32;
std::vector<paddle::PaddleTensor> outputs;
this->predictor_->Run({input_t}, &outputs, 1);
}
std::vector<float> softmax_out;
std::vector<int64_t> label_out;
auto output_names = this->predictor_->GetOutputNames();
auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
auto label_out_t = this->predictor_->GetOutputTensor(output_names[1]);
auto softmax_shape_out = softmax_out_t->shape();
auto label_shape_out = label_out_t->shape();
int softmax_out_num =
std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
std::multiplies<int>());
int label_out_num =
std::accumulate(label_shape_out.begin(), label_shape_out.end(), 1,
std::multiplies<int>());
softmax_out.resize(softmax_out_num);
label_out.resize(label_out_num);
softmax_out_t->copy_to_cpu(softmax_out.data());
label_out_t->copy_to_cpu(label_out.data());
int label = label_out[0];
float score = softmax_out[label];
// std::cout << "\nlabel "<<label<<" score: "<<score;
if (label % 2 == 1 && score > this->cls_thresh) {
cv::rotate(src_img, src_img, 1);
}
return src_img;
}
void Classifier::LoadModel(const std::string &model_dir) {
AnalysisConfig config;
config.SetModel(model_dir + "/model", model_dir + "/params");
if (this->use_gpu_) {
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
// false for zero copy tensor
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
// true for multiple input
config.SwitchSpecifyInputNames(true);
config.SwitchIrOptim(true);
config.EnableMemoryOptim();
config.DisableGlogInfo();
this->predictor_ = CreatePaddlePredictor(config);
}
} // namespace PaddleOCR
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace PaddleOCR { namespace PaddleOCR {
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes, void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
cv::Mat &img) { cv::Mat &img, Classifier *cls) {
cv::Mat srcimg; cv::Mat srcimg;
img.copyTo(srcimg); img.copyTo(srcimg);
cv::Mat crop_img; cv::Mat crop_img;
...@@ -27,6 +27,9 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes, ...@@ -27,6 +27,9 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
int index = 0; int index = 0;
for (int i = boxes.size() - 1; i >= 0; i--) { for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(srcimg, boxes[i]); crop_img = GetRotateCropImage(srcimg, boxes[i]);
if (cls != nullptr) {
crop_img = cls->Run(crop_img);
}
float wh_ratio = float(crop_img.cols) / float(crop_img.rows); float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
......
...@@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, ...@@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
cv::INTER_LINEAR); cv::INTER_LINEAR);
} }
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
const std::vector<int> &rec_image_shape) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
float ratio = float(img.cols) / float(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
}
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
...@@ -13,6 +13,11 @@ det_db_box_thresh 0.5 ...@@ -13,6 +13,11 @@ det_db_box_thresh 0.5
det_db_unclip_ratio 2.0 det_db_unclip_ratio 2.0
det_model_dir ./inference/det_db det_model_dir ./inference/det_db
# cls config
use_angle_cls 0
cls_model_dir ../inference/cls
cls_thresh 0.9
# rec config # rec config
rec_model_dir ./inference/rec_crnn rec_model_dir ./inference/rec_crnn
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
......
...@@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY ...@@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) #CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS) $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
ocr_db_crnn.o: ocr_db_crnn.cc ocr_db_crnn.o: ocr_db_crnn.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
...@@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc ...@@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc
crnn_process.o: fetch_opencv crnn_process.cc crnn_process.o: fetch_opencv crnn_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc
cls_process.o: fetch_opencv cls_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc
db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc
...@@ -73,5 +76,5 @@ fetch_opencv: ...@@ -73,5 +76,5 @@ fetch_opencv:
.PHONY: clean .PHONY: clean
clean: clean:
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o
rm -f ocr_db_crnn rm -f ocr_db_crnn
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cls_process.h" //NOLINT
#include <algorithm>
#include <memory>
#include <string>
const std::vector<int> rec_image_shape{3, 32, 100};
cv::Mat ClsResizeImg(cv::Mat img) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
float ratio = static_cast<float>(img.cols) / static_cast<float>(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
return resize_img;
}
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "math.h" //NOLINT
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
cv::Mat ClsResizeImg(cv::Mat img);
\ No newline at end of file
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle_api.h" // NOLINT #include "paddle_api.h" // NOLINT
#include <chrono> #include <chrono>
#include "cls_process.h"
#include "crnn_process.h" #include "crnn_process.h"
#include "db_post_process.h" #include "db_post_process.h"
...@@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len, ...@@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
return resize_img; return resize_img;
} }
cv::Mat RunClsModel(cv::Mat img, std::shared_ptr<PaddlePredictor> predictor_cls,
const float thresh = 0.5) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat crop_img;
cv::Mat resize_img;
int index = 0;
float wh_ratio =
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
resize_img = ClsResizeImg(crop_img);
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
std::unique_ptr<Tensor> input_tensor0(std::move(predictor_cls->GetInput(0)));
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
auto *data0 = input_tensor0->mutable_data<float>();
NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
// Run CLS predictor
predictor_cls->Run();
// Get output and run postprocess
std::unique_ptr<const Tensor> softmax_out(
std::move(predictor_cls->GetOutput(0)));
std::unique_ptr<const Tensor> label_out(
std::move(predictor_cls->GetOutput(1)));
auto *softmax_scores = softmax_out->mutable_data<float>();
auto *label_idxs = label_out->data<int64>();
int label_idx = label_idxs[0];
float score = softmax_scores[label_idx];
if (label_idx % 2 == 1 && score > thresh) {
cv::rotate(srcimg, srcimg, 1);
}
return srcimg;
}
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img, void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
std::shared_ptr<PaddlePredictor> predictor_crnn, std::shared_ptr<PaddlePredictor> predictor_crnn,
std::vector<std::string> &rec_text, std::vector<std::string> &rec_text,
std::vector<float> &rec_text_score, std::vector<float> &rec_text_score,
std::vector<std::string> charactor_dict) { std::vector<std::string> charactor_dict,
std::shared_ptr<PaddlePredictor> predictor_cls) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f}; std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
...@@ -121,6 +166,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img, ...@@ -121,6 +166,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
int index = 0; int index = 0;
for (int i = boxes.size() - 1; i >= 0; i--) { for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(srcimg, boxes[i]); crop_img = GetRotateCropImage(srcimg, boxes[i]);
crop_img = RunClsModel(crop_img, predictor_cls);
float wh_ratio = float wh_ratio =
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows); static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
...@@ -323,8 +369,9 @@ int main(int argc, char **argv) { ...@@ -323,8 +369,9 @@ int main(int argc, char **argv) {
} }
std::string det_model_file = argv[1]; std::string det_model_file = argv[1];
std::string rec_model_file = argv[2]; std::string rec_model_file = argv[2];
std::string img_path = argv[3]; std::string cls_model_file = argv[3];
std::string dict_path = argv[4]; std::string img_path = argv[4];
std::string dict_path = argv[5];
//// load config from txt file //// load config from txt file
auto Config = LoadConfigTxt("./config.txt"); auto Config = LoadConfigTxt("./config.txt");
...@@ -333,6 +380,7 @@ int main(int argc, char **argv) { ...@@ -333,6 +380,7 @@ int main(int argc, char **argv) {
auto det_predictor = loadModel(det_model_file); auto det_predictor = loadModel(det_model_file);
auto rec_predictor = loadModel(rec_model_file); auto rec_predictor = loadModel(rec_model_file);
auto cls_predictor = loadModel(cls_model_file);
auto charactor_dict = ReadDict(dict_path); auto charactor_dict = ReadDict(dict_path);
charactor_dict.push_back(" "); charactor_dict.push_back(" ");
...@@ -343,7 +391,7 @@ int main(int argc, char **argv) { ...@@ -343,7 +391,7 @@ int main(int argc, char **argv) {
std::vector<std::string> rec_text; std::vector<std::string> rec_text;
std::vector<float> rec_text_score; std::vector<float> rec_text_score;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score, RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict); charactor_dict, cls_predictor);
auto end = std::chrono::system_clock::now(); auto end = std::chrono::system_clock::now();
auto duration = auto duration =
......
## 文字角度分类
### 数据准备
请按如下步骤设置数据集:
训练数据的默认存储路径是 `PaddleOCR/train_data/cls`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录:
```
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
```
请参考下文组织您的数据。
- 训练集
首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。
**注意:** 默认请将图片路径和图片标签用 `\t` 分割,如用其他方式分割将造成训练报错
0和180分别表示图片的角度为0度和180度
```
" 图像文件名 图像标注信息 "
train_data/cls/word_001.jpg 0
train_data/cls/word_002.jpg 180
```
最终训练集应有如下文件结构:
```
|-train_data
|-cls
|- cls_gt_train.txt
|- train
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
- 测试集
同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个cls_gt_test.txt,测试集的结构如下所示:
```
|-train_data
|-cls
|- 和一个cls_gt_test.txt
|- test
|- word_001.jpg
|- word_002.jpg
|- word_003.jpg
| ...
```
### 启动训练
PaddleOCR提供了训练脚本、评估脚本和预测脚本。
开始训练:
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
```
# 设置PYTHONPATH路径
export PYTHONPATH=$PYTHONPATH:.
# GPU训练 支持单卡,多卡训练,通过CUDA_VISIBLE_DEVICES指定卡号
export CUDA_VISIBLE_DEVICES=0,1,2,3
# 启动训练
python3 tools/train.py -c configs/cls/cls_mv3.yml
```
- 数据增强
PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入扰动,请在配置文件中设置 `distort: true`
默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gasuss noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse),随机数据增强(RandAugment)。
训练过程中除随机数据增强外每种扰动方式以50%的概率被选择,具体代码实现请参考:
[randaugment.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py)
[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
*由于OpenCV的兼容性问题,扰动操作暂时只支持linux*
### 训练
PaddleOCR支持训练和评估交替进行, 可以在 `configs/cls/cls_mv3.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/cls_mv3/best_accuracy`
如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。
**注意,预测/评估时的配置文件请务必与训练一致。**
### 评估
评估数据集可以通过`configs/cls/cls_reader.yml` 修改EvalReader中的 `label_file_path` 设置。
*注意* 评估时必须确保配置文件中 infer_img 字段为空
```
export CUDA_VISIBLE_DEVICES=0
# GPU 评估, Global.checkpoints 为待测权重
python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
### 预测
* 训练引擎的预测
使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。
默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 指定权重:
```
# 预测分类结果
python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
预测图片:
![](../imgs_words/en/word_1.png)
得到输入图像的预测结果:
```
infer_img: doc/imgs_words/en/word_1.png
scores: [[0.93161047 0.06838956]]
label: [0]
```
...@@ -11,24 +11,28 @@ inference 模型(`fluid.io.save_inference_model`保存的模型) ...@@ -11,24 +11,28 @@ inference 模型(`fluid.io.save_inference_model`保存的模型)
- [一、训练模型转inference模型](#训练模型转inference模型) - [一、训练模型转inference模型](#训练模型转inference模型)
- [检测模型转inference模型](#检测模型转inference模型) - [检测模型转inference模型](#检测模型转inference模型)
- [识别模型转inference模型](#识别模型转inference模型) - [识别模型转inference模型](#识别模型转inference模型)
- [方向分类模型转inference模型](#方向模型转inference模型)
- [二、文本检测模型推理](#文本检测模型推理) - [二、文本检测模型推理](#文本检测模型推理)
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理) - [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
- [2. DB文本检测模型推理](#DB文本检测模型推理) - [2. DB文本检测模型推理](#DB文本检测模型推理)
- [3. EAST文本检测模型推理](#EAST文本检测模型推理) - [3. EAST文本检测模型推理](#EAST文本检测模型推理)
- [4. SAST文本检测模型推理](#SAST文本检测模型推理) - [4. SAST文本检测模型推理](#SAST文本检测模型推理)
- [三、文本识别模型推理](#文本识别模型推理) - [三、文本识别模型推理](#文本识别模型推理)
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理) - [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理) - [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
- [3. 基于Attention损失的识别模型推理](#基于Attention损失的识别模型推理) - [3. 基于Attention损失的识别模型推理](#基于Attention损失的识别模型推理)
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理) - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- [四、文本检测、识别串联推理](#文本检测、识别串联推理) - [四、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理)
- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理) - [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
- [2. 其他模型推理](#其他模型推理) - [2. 其他模型推理](#其他模型推理)
<a name="训练模型转inference模型"></a> <a name="训练模型转inference模型"></a>
## 一、训练模型转inference模型 ## 一、训练模型转inference模型
<a name="检测模型转inference模型"></a> <a name="检测模型转inference模型"></a>
...@@ -84,6 +88,32 @@ python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Globa ...@@ -84,6 +88,32 @@ python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Globa
└─ params 识别inference模型的参数文件 └─ params 识别inference模型的参数文件
``` ```
<a name="方向分类模型转inference模型"></a>
### 方向分类模型转inference模型
下载方向分类模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
```
方向分类模型转inference模型与检测的方式相同,如下:
```
# -c后面设置训练算法的yml配置文件
# -o配置可选参数
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
Global.save_inference_dir=./inference/cls/
```
转换成功后,在目录下有两个文件:
```
/inference/cls/
└─ model 识别inference模型的program文件
└─ params 识别inference模型的参数文件
```
<a name="文本检测模型推理"></a> <a name="文本检测模型推理"></a>
## 二、文本检测模型推理 ## 二、文本检测模型推理
...@@ -275,15 +305,36 @@ dict_character = list(self.character_str) ...@@ -275,15 +305,36 @@ dict_character = list(self.character_str)
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path" python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
``` ```
<a name="文本检测、识别串联推理"></a>
## 四、文本检测、识别串联推理 <a name="方向分类模型推理"></a>
## 四、方向分类模型推理
下面将介绍方向分类模型推理。
<a name="方向分类模型推理"></a>
### 1. 方向分类模型推理
方向分类模型推理,可以执行如下命令:
```
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/"
```
![](../imgs_words/ch/word_4.jpg)
执行命令后,上面图像的预测结果(分类的方向和得分)会打印到屏幕上,示例如下:
Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999963]
<a name="文本检测、方向分类和文字识别串联推理"></a>
## 五、文本检测、方向分类和文字识别串联推理
<a name="超轻量中文OCR模型推理"></a> <a name="超轻量中文OCR模型推理"></a>
### 1. 超轻量中文OCR模型推理 ### 1. 超轻量中文OCR模型推理
在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。 在执行预测时,需要通过参数`image_dir`指定单张图像或者图像集合的路径、参数`det_model_dir`,`cls_model_dir``rec_model_dir`分别指定检测,方向分类和识别的inference模型路径。参数`use_angle_cls`用于控制是否启用方向分类模型。可视化识别结果默认保存到 ./inference_results 文件夹里面。
``` ```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls true
``` ```
执行命令后,识别结果图像如下: 执行命令后,识别结果图像如下:
......
## TEXT ANGLE CLASSIFICATION
### DATA PREPARATION
Please organize the dataset as follows:
The default storage path for training data is `PaddleOCR/train_data/cls`, if you already have a dataset on your disk, just create a soft link to the dataset directory:
```
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
```
please refer to the following to organize your data.
- Training set
First put the training images in the same folder (train_images), and use a txt file (cls_gt_train.txt) to store the image path and label.
* Note: by default, the image path and image label are split with `\t`, if you use other methods to split, it will cause training error
0 and 180 indicate that the angle of the image is 0 degrees and 180 degrees, respectively.
```
" Image file name Image annotation "
train_data/word_001.jpg 0
train_data/word_002.jpg 180
```
The final training set should have the following file structure:
```
|-train_data
|-cls
|- cls_gt_train.txt
|- train
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
```
- Test set
Similar to the training set, the test set also needs to be provided a folder
containing all images (test) and a cls_gt_test.txt. The structure of the test set is as follows:
```
|-train_data
|-cls
|- cls_gt_test.txt
|- test
|- word_001.jpg
|- word_002.jpg
|- word_003.jpg
| ...
```
### TRAINING
PaddleOCR provides training scripts, evaluation scripts, and prediction scripts.
Start training:
```
# Set PYTHONPATH path
export PYTHONPATH=$PYTHONPATH:.
# GPU training Support single card and multi-card training, specify the card number through CUDA_VISIBLE_DEVICES
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Training icdar15 English data
python3 tools/train.py -c configs/cls/cls_mv3.yml
```
- Data Augmentation
PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, please set `distort: true` in the configuration file.
The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse, RandAugment.
Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to:
[randaugment.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py)
[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
- Training
PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/cls/cls_mv3.yml` to set the evaluation frequency. By default, it is evaluated every 500 iter and the best acc model is saved under `output/cls_mv3/best_accuracy` during the evaluation process.
If the evaluation set is large, the test will be time-consuming. It is recommended to reduce the number of evaluations, or evaluate after training.
**Note that the configuration file for prediction/evaluation must be consistent with the training.**
### EVALUATION
The evaluation data set can be modified via `configs/cls/cls_reader.yml` setting of `label_file_path` in EvalReader.
```
export CUDA_VISIBLE_DEVICES=0
# GPU evaluation, Global.checkpoints is the weight to be tested
python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
### PREDICTION
* Training engine prediction
Using the model trained by paddleocr, you can quickly get prediction through the following script.
The default prediction picture is stored in `infer_img`, and the weight is specified via `-o Global.checkpoints`:
```
# Predict English results
python3 tools/infer_rec.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg
```
Input image:
![](../imgs_words/en/word_1.png)
Get the prediction result of the input image:
```
infer_img: doc/imgs_words/en/word_1.png
scores: [[0.93161047 0.06838956]]
label: [0]
```
...@@ -12,25 +12,28 @@ Next, we first introduce how to convert a trained model into an inference model, ...@@ -12,25 +12,28 @@ Next, we first introduce how to convert a trained model into an inference model,
- [CONVERT TRAINING MODEL TO INFERENCE MODEL](#CONVERT) - [CONVERT TRAINING MODEL TO INFERENCE MODEL](#CONVERT)
- [Convert detection model to inference model](#Convert_detection_model) - [Convert detection model to inference model](#Convert_detection_model)
- [Convert recognition model to inference model](#Convert_recognition_model) - [Convert recognition model to inference model](#Convert_recognition_model)
- [Convert angle classification model to inference model](#Convert_angle_class_model)
- [TEXT DETECTION MODEL INFERENCE](#DETECTION_MODEL_INFERENCE) - [TEXT DETECTION MODEL INFERENCE](#DETECTION_MODEL_INFERENCE)
- [1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE](#LIGHTWEIGHT_DETECTION) - [1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE](#LIGHTWEIGHT_DETECTION)
- [2. DB TEXT DETECTION MODEL INFERENCE](#DB_DETECTION) - [2. DB TEXT DETECTION MODEL INFERENCE](#DB_DETECTION)
- [3. EAST TEXT DETECTION MODEL INFERENCE](#EAST_DETECTION) - [3. EAST TEXT DETECTION MODEL INFERENCE](#EAST_DETECTION)
- [4. SAST TEXT DETECTION MODEL INFERENCE](#SAST_DETECTION) - [4. SAST TEXT DETECTION MODEL INFERENCE](#SAST_DETECTION)
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE) - [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION) - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION) - [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
- [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION) - [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION)
- [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS) - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
- [ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
- [TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION) - [1. ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
- [TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION)
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_CHINESE_MODEL) - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_CHINESE_MODEL)
- [2. OTHER MODELS](#OTHER_MODELS) - [2. OTHER MODELS](#OTHER_MODELS)
<a name="CONVERT"></a> <a name="CONVERT"></a>
## CONVERT TRAINING MODEL TO INFERENCE MODEL ## CONVERT TRAINING MODEL TO INFERENCE MODEL
<a name="Convert_detection_model"></a> <a name="Convert_detection_model"></a>
...@@ -87,6 +90,33 @@ After the conversion is successful, there are two files in the directory: ...@@ -87,6 +90,33 @@ After the conversion is successful, there are two files in the directory:
└─ params Identify the parameter files of the inference model └─ params Identify the parameter files of the inference model
``` ```
<a name="Convert_angle_class_model"></a>
### Convert angle classification model to inference model
Download the angle classification model:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
```
The angle classification model is converted to the inference model in the same way as the detection, as follows:
```
# -c Set the training algorithm yml configuration file
# -o Set optional parameters
# Global.checkpoints parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir Set the address where the converted model will be saved.
python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
Global.save_inference_dir=./inference/cls/
```
After the conversion is successful, there are two files in the directory:
```
/inference/cls/
└─ model Identify the saved model files
└─ params Identify the parameter files of the inference model
```
<a name="DETECTION_MODEL_INFERENCE"></a> <a name="DETECTION_MODEL_INFERENCE"></a>
## TEXT DETECTION MODEL INFERENCE ## TEXT DETECTION MODEL INFERENCE
...@@ -276,16 +306,39 @@ If the chars dictionary is modified during training, you need to specify the new ...@@ -276,16 +306,39 @@ If the chars dictionary is modified during training, you need to specify the new
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path" python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
``` ```
<a name="ANGLE_CLASSIFICATION_MODEL_INFERENCE"></a>
## ANGLE CLASSIFICATION MODEL INFERENCE
The following will introduce the angle classification model inference.
<a name="ANGLE_CLASS_MODEL_INFERENCE"></a>
### 1.ANGLE CLASSIFICATION MODEL INFERENCE
For angle classification model inference, you can execute the following commands:
```
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/"
```
![](../imgs_words/ch/word_4.jpg)
After executing the command, the prediction results (classification angle and score) of the above image will be printed on the screen.
Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999963]
<a name="CONCATENATION"></a> <a name="CONCATENATION"></a>
## TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION ## TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION
<a name="LIGHTWEIGHT_CHINESE_MODEL"></a> <a name="LIGHTWEIGHT_CHINESE_MODEL"></a>
### 1. LIGHTWEIGHT CHINESE MODEL ### 1. LIGHTWEIGHT CHINESE MODEL
When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, and the parameter `rec_model_dir` specifies the path to identify the inference model. The visualized recognition results are saved to the `./inference_results` folder by default. When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, the parameter `cls_model_dir` specifies the path to angle classification inference model and the parameter `rec_model_dir` specifies the path to identify the inference model. The parameter `use_angle_cls` is used to control whether to enable the angle classification model.The visualized recognition results are saved to the `./inference_results` folder by default.
``` ```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls true
``` ```
After executing the command, the recognition result image is as follows: After executing the command, the recognition result image is as follows:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import random
import numpy as np
import cv2
from ppocr.utils.utility import initial_logger
from ppocr.utils.utility import get_image_file_list
logger = initial_logger()
from ppocr.data.rec.img_tools import resize_norm_img, warp
from ppocr.data.cls.randaugment import RandAugment
def random_crop(img):
img_h, img_w = img.shape[:2]
if img_w > img_h * 4:
w = random.randint(img_h * 2, img_w)
i = random.randint(0, img_w - w)
img = img[:, i:i + w, :]
return img
class SimpleReader(object):
def __init__(self, params):
if params['mode'] != 'train':
self.num_workers = 1
else:
self.num_workers = params['num_workers']
if params['mode'] != 'test':
self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
self.use_gpu = params['use_gpu']
self.image_shape = params['image_shape']
self.mode = params['mode']
self.infer_img = params['infer_img']
self.use_distort = params['mode'] == 'train' and params['distort']
self.randaug = RandAugment()
self.label_list = params['label_list']
if "distort" in params:
self.use_distort = params['distort'] and params['use_gpu']
if not params['use_gpu']:
logger.info(
"Distort operation can only support in GPU.Distort will be set to False."
)
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card']
self.drop_last = False
self.use_distort = False
def __call__(self, process_id):
if self.mode != 'train':
process_id = 0
def get_device_num():
if self.use_gpu:
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1)
gpu_num = len(gpus.split(','))
return gpu_num
else:
cpu_num = os.environ.get("CPU_NUM", 1)
return int(cpu_num)
def sample_iter_reader():
if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = resize_norm_img(img, self.image_shape)
norm_img = norm_img[np.newaxis, :]
yield norm_img
else:
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
if sys.platform == "win32" and self.num_workers != 1:
print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.")
self.num_workers = 1
if self.batch_size * get_device_num(
) * self.num_workers > img_num:
raise Exception(
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
format(img_num, self.batch_size * get_device_num() *
self.num_workers))
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
substr = label_infor.decode('utf-8').strip("\n").split("\t")
label = self.label_list.index(substr[1])
img_path = self.img_set_dir + "/" + substr[0]
img = cv2.imread(img_path)
if img is None:
logger.info("{} does not exist!".format(img_path))
continue
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if self.use_distort:
img = warp(img, 10)
img = self.randaug(img)
norm_img = resize_norm_img(img, self.image_shape)
norm_img = norm_img[np.newaxis, :]
yield (norm_img, label)
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0:
yield batch_outs
if self.infer_img is None:
return batch_iter_reader
return sample_iter_reader
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
import six
class RawRandAugment(object):
def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
self.num_layers = num_layers
self.magnitude = magnitude
self.max_level = 10
abso_level = self.magnitude / self.max_level
self.level_map = {
"shearX": 0.3 * abso_level,
"shearY": 0.3 * abso_level,
"translateX": 150.0 / 331 * abso_level,
"translateY": 150.0 / 331 * abso_level,
"rotate": 30 * abso_level,
"color": 0.9 * abso_level,
"posterize": int(4.0 * abso_level),
"solarize": 256.0 * abso_level,
"contrast": 0.9 * abso_level,
"sharpness": 0.9 * abso_level,
"brightness": 0.9 * abso_level,
"autocontrast": 0,
"equalize": 0,
"invert": 0
}
# from https://stackoverflow.com/questions/5252170/
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
Image.new("RGBA", rot.size, (128, ) * 4),
rot).convert(img.mode)
rnd_ch_op = random.choice
self.func = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"posterize": lambda img, magnitude:
ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude:
ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude:
ImageEnhance.Contrast(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"sharpness": lambda img, magnitude:
ImageEnhance.Sharpness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"brightness": lambda img, magnitude:
ImageEnhance.Brightness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"autocontrast": lambda img, magnitude:
ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
def __call__(self, img):
avaiable_op_names = list(self.level_map.keys())
for layer_num in range(self.num_layers):
op_name = np.random.choice(avaiable_op_names)
img = self.func[op_name](img, self.level_map[op_name])
return img
class RandAugment(RawRandAugment):
""" RandAugment wrapper to auto fit different img types """
def __init__(self, *args, **kwargs):
if six.PY2:
super(RandAugment, self).__init__(*args, **kwargs)
else:
super().__init__(*args, **kwargs)
def __call__(self, img):
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
if six.PY2:
img = super(RandAugment, self).__call__(img)
else:
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
return img
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from ppocr.utils.utility import create_module
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from copy import deepcopy
class ClsModel(object):
def __init__(self, params):
super(ClsModel, self).__init__()
global_params = params['Global']
self.infer_img = global_params['infer_img']
backbone_params = deepcopy(params["Backbone"])
backbone_params.update(global_params)
self.backbone = create_module(backbone_params['function']) \
(params=backbone_params)
head_params = deepcopy(params["Head"])
head_params.update(global_params)
self.head = create_module(head_params['function']) \
(params=head_params)
loss_params = deepcopy(params["Loss"])
loss_params.update(global_params)
self.loss = create_module(loss_params['function']) \
(params=loss_params)
self.image_shape = global_params['image_shape']
def create_feed(self, mode):
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
feed_list = [image, label]
labels = {'label': label}
loader = fluid.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=64,
use_double_buffer=True,
iterable=False)
else:
labels = None
loader = None
image = fluid.data(name='image', shape=image_shape, dtype='float32')
return image, labels, loader
def __call__(self, mode):
image, labels, loader = self.create_feed(mode)
inputs = image
conv_feas = self.backbone(inputs)
predicts = self.head(conv_feas, labels, mode)
if mode == "train":
loss = self.loss(predicts, labels)
label = labels['label']
acc = fluid.layers.accuracy(predicts['predict'], label, k=1)
outputs = {'total_loss': loss, 'decoded_out': \
predicts['decoded_out'], 'label': label, 'acc': acc}
return loader, outputs
elif mode == "export":
return [image, predicts]
else:
return loader, predicts
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
class ClsHead(object):
def __init__(self, params):
super(ClsHead, self).__init__()
self.class_dim = params['class_dim']
def __call__(self, inputs, labels=None, mode=None):
pool = fluid.layers.pool2d(
input=inputs, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=self.class_dim,
param_attr=fluid.param_attr.ParamAttr(
name="fc_0.w_0",
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=fluid.param_attr.ParamAttr(name="fc_0.b_0"))
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
out_label = fluid.layers.argmax(out, axis=1)
predicts = {'predict': softmax_out, 'decoded_out': out_label}
return predicts
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
class ClsLoss(object):
def __init__(self, params):
super(ClsLoss, self).__init__()
self.loss_func = fluid.layers.cross_entropy
def __call__(self, predicts, labels):
predict = predicts['predict']
label = labels['label']
# softmax_out = fluid.layers.softmax(predict, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=predict, label=label)
sum_cost = fluid.layers.mean(cost)
return sum_cost
...@@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model ...@@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model
from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import test_rec_benchmark from eval_utils.eval_rec_utils import test_rec_benchmark
from eval_utils.eval_rec_utils import eval_rec_run from eval_utils.eval_rec_utils import eval_rec_run
from eval_utils.eval_cls_utils import eval_cls_run
def main(): def main():
startup_prog, eval_program, place, config, train_alg_type = program.preprocess() startup_prog, eval_program, place, config, train_alg_type = program.preprocess(
)
eval_build_outputs = program.build( eval_build_outputs = program.build(
config, eval_program, startup_prog, mode='test') config, eval_program, startup_prog, mode='test')
eval_fetch_name_list = eval_build_outputs[1] eval_fetch_name_list = eval_build_outputs[1]
...@@ -67,6 +69,14 @@ def main(): ...@@ -67,6 +69,14 @@ def main():
'fetch_varname_list':eval_fetch_varname_list} 'fetch_varname_list':eval_fetch_varname_list}
metrics = eval_det_run(exe, config, eval_info_dict, "eval") metrics = eval_det_run(exe, config, eval_info_dict, "eval")
logger.info("Eval result: {}".format(metrics)) logger.info("Eval result: {}".format(metrics))
elif train_alg_type == 'cls':
eval_reader = reader_main(config=config, mode="eval")
eval_info_dict = {'program': eval_program, \
'reader': eval_reader, \
'fetch_name_list': eval_fetch_name_list, \
'fetch_varname_list': eval_fetch_varname_list}
metrics = eval_cls_run(exe, eval_info_dict)
logger.info("Eval result: {}".format(metrics))
else: else:
reader_type = config['Global']['reader_yml'] reader_type = config['Global']['reader_yml']
if "benchmark" not in reader_type: if "benchmark" not in reader_type:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
__all__ = ['eval_cls_run']
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def eval_cls_run(exe, eval_info_dict):
"""
Run evaluation program, return program outputs.
"""
total_sample_num = 0
total_acc_num = 0
total_batch_num = 0
for data in eval_info_dict['reader']():
img_num = len(data)
img_list = []
label_list = []
for ino in range(img_num):
img_list.append(data[ino][0])
label_list.append(data[ino][1])
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list}, \
fetch_list=eval_info_dict['fetch_varname_list'], \
return_numpy=False)
softmax_outs = np.array(outs[1])
if len(softmax_outs.shape) != 1:
softmax_outs = np.array(outs[0])
acc, acc_num = cal_cls_acc(softmax_outs, label_list)
total_acc_num += acc_num
total_sample_num += len(label_list)
# logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num += 1
avg_acc = total_acc_num * 1.0 / total_sample_num
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
"total_sample_num": total_sample_num}
return metrics
def cal_cls_acc(preds, labels):
acc_num = 0
for pred, label in zip(preds, labels):
if pred == label:
acc_num += 1
return acc_num / len(preds), acc_num
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
import cv2
import copy
import numpy as np
import math
import time
from paddle import fluid
class TextClassifier(object):
def __init__(self, args):
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, mode="cls")
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
self.cls_batch_num = args.rec_batch_num
self.label_list = args.label_list
self.use_zero_copy_run = args.use_zero_copy_run
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if self.cls_image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def __call__(self, img_list):
img_list = copy.deepcopy(img_list)
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
predict_time = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
prob_out = self.output_tensors[0].copy_to_cpu()
label_out = self.output_tensors[1].copy_to_cpu()
if len(label_out.shape) != 1:
prob_out, label_out = label_out, prob_out
elapse = time.time() - starttime
predict_time += elapse
for rno in range(len(label_out)):
label_idx = label_out[rno]
score = prob_out[rno][label_idx]
label = self.label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > 0.9999:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res, predict_time
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_classifier = TextClassifier(args)
valid_image_file_list = []
img_list = []
for image_file in image_file_list[:10]:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
except Exception as e:
print(e)
exit()
for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
print("Total predict time for %d images:%.3f" %
(len(img_list), predict_time))
if __name__ == "__main__":
main(utility.parse_args())
...@@ -13,16 +13,19 @@ ...@@ -13,16 +13,19 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
import cv2 import cv2
import tools.infer.predict_det as predict_det import tools.infer.predict_det as predict_det
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
import tools.infer.predict_cls as predict_cls
import copy import copy
import numpy as np import numpy as np
import math import math
...@@ -37,6 +40,9 @@ class TextSystem(object): ...@@ -37,6 +40,9 @@ class TextSystem(object):
def __init__(self, args): def __init__(self, args):
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls
if self.use_angle_cls:
self.text_classifier = predict_cls.TextClassifier(args)
def get_rotate_crop_image(self, img, points): def get_rotate_crop_image(self, img, points):
''' '''
...@@ -91,6 +97,11 @@ class TextSystem(object): ...@@ -91,6 +97,11 @@ class TextSystem(object):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
if self.use_angle_cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
print("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
...@@ -110,8 +121,8 @@ def sorted_boxes(dt_boxes): ...@@ -110,8 +121,8 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes) _boxes = list(sorted_boxes)
for i in range(num_boxes - 1): for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]): (_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i] tmp = _boxes[i]
_boxes[i] = _boxes[i + 1] _boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp _boxes[i + 1] = tmp
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import argparse import argparse
import os, sys import os, sys
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
from paddle.fluid.core import PaddleTensor from paddle.fluid.core import PaddleTensor
from paddle.fluid.core import AnalysisConfig from paddle.fluid.core import AnalysisConfig
...@@ -31,34 +32,34 @@ def parse_args(): ...@@ -31,34 +32,34 @@ def parse_args():
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
#params for prediction engine # params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000) parser.add_argument("--gpu_mem", type=int, default=8000)
#params for text detector # params for text detector
parser.add_argument("--image_dir", type=str) parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_max_side_len", type=float, default=960) parser.add_argument("--det_max_side_len", type=float, default=960)
#DB parmas # DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
#EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
#SAST parmas # SAST parmas
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=bool, default=False) parser.add_argument("--det_sast_polygon", type=bool, default=False)
#params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
...@@ -70,14 +71,24 @@ def parse_args(): ...@@ -70,14 +71,24 @@ def parse_args():
type=str, type=str,
default="./ppocr/utils/ppocr_keys_v1.txt") default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False) # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=30)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
return parser.parse_args() return parser.parse_args()
def create_predictor(args, mode): def create_predictor(args, mode):
if mode == "det": if mode == "det":
model_dir = args.det_model_dir model_dir = args.det_model_dir
elif mode == 'cls':
model_dir = args.cls_model_dir
else: else:
model_dir = args.rec_model_dir model_dir = args.rec_model_dir
...@@ -105,7 +116,7 @@ def create_predictor(args, mode): ...@@ -105,7 +116,7 @@ def create_predictor(args, mode):
config.set_mkldnn_cache_capacity(10) config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn() config.enable_mkldnn()
#config.enable_memory_optim() # config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
if args.use_zero_copy_run: if args.use_zero_copy_run:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
import tools.program as program
from paddle import fluid
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.data.reader_main import reader_main
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import create_module
from ppocr.utils.utility import get_image_file_list
def main():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
logger.info(config)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
# check_gpu(use_gpu)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
rec_model = create_module(config['Architecture']['function'])(params=config)
startup_prog = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
_, outputs = rec_model(mode="test")
fetch_name_list = list(outputs.keys())
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
eval_prog = eval_prog.clone(for_test=True)
exe.run(startup_prog)
init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')()
infer_img = config['Global']['infer_img']
infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
logger.info("infer_img:%s" % infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog,
feed={"image": img},
fetch_list=fetch_varname_list,
return_numpy=False)
scores = np.array(predict[0])
label = np.array(predict[1])
if len(label.shape) != 1:
label, scores = scores, label
logger.info('\t scores: {}'.format(scores))
logger.info('\t label: {}'.format(label))
# save for inference model
target_var = []
for key, values in outputs.items():
target_var.append(values)
fluid.io.save_inference_model(
"./output",
feeded_var_names=['image'],
target_vars=target_var,
executor=exe,
main_program=eval_prog,
model_filename="model",
params_filename="params")
if __name__ == '__main__':
parser = program.ArgsParser()
FLAGS = parser.parse_args()
main()
...@@ -30,6 +30,7 @@ import time ...@@ -30,6 +30,7 @@ import time
from ppocr.utils.stats import TrainingStats from ppocr.utils.stats import TrainingStats
from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import eval_rec_run from eval_utils.eval_rec_utils import eval_rec_run
from eval_utils.eval_cls_utils import eval_cls_run
from ppocr.utils.save_load import save_model from ppocr.utils.save_load import save_model
import numpy as np import numpy as np
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
...@@ -409,6 +410,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): ...@@ -409,6 +410,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
return return
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
logger.info(
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step))
save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir']
if not os.path.exists(save_model_dir):
os.makedirs(save_model_dir)
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
best_eval_acc = -1
best_batch_id = 0
best_epoch = 0
train_loader = train_info_dict['reader']
for epoch in range(epoch_num):
train_loader.start()
try:
while True:
t1 = time.time()
train_outs = exe.run(
program=train_info_dict['compile_program'],
fetch_list=train_info_dict['fetch_varname_list'],
return_numpy=False)
fetch_map = dict(
zip(train_info_dict['fetch_name_list'],
range(len(train_outs))))
loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
acc = np.mean(np.array(train_outs[fetch_map['acc']]))
t2 = time.time()
train_batch_elapse = t2 - t1
stats = {'loss': loss, 'acc': acc}
train_stats.update(stats)
if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \
% print_batch_step == 0:
logs = train_stats.log()
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
epoch, train_batch_id, lr, logs, train_batch_elapse)
logger.info(strs)
if train_batch_id > 0 and\
train_batch_id % eval_batch_step == 0:
model_average = train_info_dict['model_average']
if model_average != None:
model_average.apply(exe)
metrics = eval_cls_run(exe, eval_info_dict)
eval_acc = metrics['avg_acc']
eval_sample_num = metrics['total_sample_num']
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
best_batch_id = train_batch_id
best_epoch = epoch
save_path = save_model_dir + "/best_accuracy"
save_model(train_info_dict['train_program'], save_path)
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
train_batch_id, eval_acc, best_eval_acc, best_epoch,
best_batch_id, eval_sample_num)
logger.info(strs)
train_batch_id += 1
except fluid.core.EOFException:
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)
return
def preprocess(): def preprocess():
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
...@@ -421,7 +503,7 @@ def preprocess(): ...@@ -421,7 +503,7 @@ def preprocess():
alg = config['Global']['algorithm'] alg = config['Global']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN' 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
] ]
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global']) config['Global']['char_ops'] = CharacterOps(config['Global'])
...@@ -432,7 +514,9 @@ def preprocess(): ...@@ -432,7 +514,9 @@ def preprocess():
if alg in ['EAST', 'DB', 'SAST']: if alg in ['EAST', 'DB', 'SAST']:
train_alg_type = 'det' train_alg_type = 'det'
else: elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
train_alg_type = 'rec' train_alg_type = 'rec'
else:
train_alg_type = 'cls'
return startup_program, train_program, place, config, train_alg_type return startup_program, train_program, place, config, train_alg_type
...@@ -75,7 +75,8 @@ def main(): ...@@ -75,7 +75,8 @@ def main():
# dump mode structure # dump mode structure
if config['Global']['debug']: if config['Global']['debug']:
if train_alg_type == 'rec' and 'attention' in config['Global']['loss_type']: if train_alg_type == 'rec' and 'attention' in config['Global'][
'loss_type']:
logger.warning('Does not suport dump attention...') logger.warning('Does not suport dump attention...')
else: else:
summary(train_program) summary(train_program)
...@@ -96,8 +97,10 @@ def main(): ...@@ -96,8 +97,10 @@ def main():
if train_alg_type == 'det': if train_alg_type == 'det':
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict) program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
else: elif train_alg_type == 'rec':
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
else:
program.train_eval_cls_run(config, exe, train_info_dict, eval_info_dict)
def test_reader(): def test_reader():
...@@ -119,6 +122,7 @@ def test_reader(): ...@@ -119,6 +122,7 @@ def test_reader():
if __name__ == '__main__': if __name__ == '__main__':
startup_program, train_program, place, config, train_alg_type = program.preprocess() startup_program, train_program, place, config, train_alg_type = program.preprocess(
)
main() main()
# test_reader() # test_reader()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册