diff --git a/configs/cls/cls_mv3.yml b/configs/cls/cls_mv3.yml new file mode 100755 index 0000000000000000000000000000000000000000..57afab507c03c2a32f1665f908170de05d91143a --- /dev/null +++ b/configs/cls/cls_mv3.yml @@ -0,0 +1,44 @@ +Global: + algorithm: CLS + use_gpu: False + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 100 + save_model_dir: output/cls_mv3 + save_epoch_step: 3 + eval_batch_step: 500 + train_batch_size_per_card: 512 + test_batch_size_per_card: 512 + image_shape: [3, 48, 192] + label_list: ['0','180'] + distort: True + reader_yml: ./configs/cls/cls_reader.yml + pretrain_weights: + checkpoints: + save_inference_dir: + infer_img: + +Architecture: + function: ppocr.modeling.architectures.cls_model,ClsModel + +Backbone: + function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3 + scale: 0.35 + model_name: small + +Head: + function: ppocr.modeling.heads.cls_head,ClsHead + class_dim: 2 + +Loss: + function: ppocr.modeling.losses.cls_loss,ClsLoss + +Optimizer: + function: ppocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + decay: + function: cosine_decay + step_each_epoch: 1169 + total_epoch: 100 \ No newline at end of file diff --git a/configs/cls/cls_reader.yml b/configs/cls/cls_reader.yml new file mode 100755 index 0000000000000000000000000000000000000000..2b1d4c4e75217998f2c489bcd3bfbbb8b8b7f415 --- /dev/null +++ b/configs/cls/cls_reader.yml @@ -0,0 +1,13 @@ +TrainReader: + reader_function: ppocr.data.cls.dataset_traversal,SimpleReader + num_workers: 8 + img_set_dir: ./train_data/cls + label_file_path: ./train_data/cls/train.txt + +EvalReader: + reader_function: ppocr.data.cls.dataset_traversal,SimpleReader + img_set_dir: ./train_data/cls + label_file_path: ./train_data/cls/test.txt + +TestReader: + reader_function: ppocr.data.cls.dataset_traversal,SimpleReader diff --git a/deploy/android_demo/app/src/main/assets/images/180.jpg b/deploy/android_demo/app/src/main/assets/images/180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..84cf4c79ef14769d01b0b0e9667387bd16b3e6e7 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/180.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/270.jpg b/deploy/android_demo/app/src/main/assets/images/270.jpg new file mode 100644 index 0000000000000000000000000000000000000000..568739043b7779425b0abeb4459dbb485caed847 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/270.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/90.jpg b/deploy/android_demo/app/src/main/assets/images/90.jpg new file mode 100644 index 0000000000000000000000000000000000000000..49e949aa9cc14e3afc507c5806c87d9894c2dcb9 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/90.jpg differ diff --git a/deploy/android_demo/app/src/main/cpp/native.cpp b/deploy/android_demo/app/src/main/cpp/native.cpp index 390c594deb02a8f82693f2c83741a4750fe7cb25..963c5246d5b7b50720f92705d288526ae2cc6a73 100644 --- a/deploy/android_demo/app/src/main/cpp/native.cpp +++ b/deploy/android_demo/app/src/main/cpp/native.cpp @@ -4,29 +4,29 @@ #include "native.h" #include "ocr_ppredictor.h" -#include #include #include +#include static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode); -extern "C" -JNIEXPORT jlong JNICALL -Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject thiz, - jstring j_det_model_path, - jstring j_rec_model_path, - jint j_thread_num, - jstring j_cpu_mode) { - 
std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path); - std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path); - int thread_num = j_thread_num; - std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode); - ppredictor::OCR_Config conf; - conf.thread_num = thread_num; - conf.mode = str_to_cpu_mode(cpu_mode); - ppredictor::OCR_PPredictor *orc_predictor = new ppredictor::OCR_PPredictor{conf}; - orc_predictor->init_from_file(det_model_path, rec_model_path); - return reinterpret_cast(orc_predictor); +extern "C" JNIEXPORT jlong JNICALL +Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( + JNIEnv *env, jobject thiz, jstring j_det_model_path, + jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num, + jstring j_cpu_mode) { + std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path); + std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path); + std::string cls_model_path = jstring_to_cpp_string(env, j_cls_model_path); + int thread_num = j_thread_num; + std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode); + ppredictor::OCR_Config conf; + conf.thread_num = thread_num; + conf.mode = str_to_cpu_mode(cpu_mode); + ppredictor::OCR_PPredictor *orc_predictor = + new ppredictor::OCR_PPredictor{conf}; + orc_predictor->init_from_file(det_model_path, rec_model_path, cls_model_path); + return reinterpret_cast(orc_predictor); } /** @@ -34,82 +34,81 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject * @param cpu_mode * @return */ -static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode) { - static std::map cpu_mode_map{ - {"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH}, - {"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH}, - {"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL}, - {"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND}, - {"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH}, - {"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW} - }; - std::string upper_key; - std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(), ::toupper); - auto index = cpu_mode_map.find(upper_key); - if (index == cpu_mode_map.end()) { - LOGE("cpu_mode not found %s", upper_key.c_str()); - return paddle::lite_api::LITE_POWER_HIGH; - } else { - return index->second; - } - +static paddle::lite_api::PowerMode +str_to_cpu_mode(const std::string &cpu_mode) { + static std::map cpu_mode_map{ + {"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH}, + {"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH}, + {"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL}, + {"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND}, + {"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH}, + {"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW}}; + std::string upper_key; + std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(), + ::toupper); + auto index = cpu_mode_map.find(upper_key); + if (index == cpu_mode_map.end()) { + LOGE("cpu_mode not found %s", upper_key.c_str()); + return paddle::lite_api::LITE_POWER_HIGH; + } else { + return index->second; + } } -extern "C" -JNIEXPORT jfloatArray JNICALL -Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(JNIEnv *env, jobject thiz, - jlong java_pointer, jfloatArray buf, - jfloatArray ddims, - jobject original_image) { - LOGI("begin to run native forward"); - if (java_pointer == 0) { - LOGE("JAVA pointer is NULL"); - return 
cpp_array_to_jfloatarray(env, nullptr, 0); - } - cv::Mat origin = bitmap_to_cv_mat(env, original_image); - if (origin.size == 0) { - LOGE("origin bitmap cannot convert to CV Mat"); - return cpp_array_to_jfloatarray(env, nullptr, 0); - } - ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *) java_pointer; - std::vector dims_float_arr = jfloatarray_to_float_vector(env, ddims); - std::vector dims_arr; - dims_arr.resize(dims_float_arr.size()); - std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin()); +extern "C" JNIEXPORT jfloatArray JNICALL +Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( + JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf, + jfloatArray ddims, jobject original_image) { + LOGI("begin to run native forward"); + if (java_pointer == 0) { + LOGE("JAVA pointer is NULL"); + return cpp_array_to_jfloatarray(env, nullptr, 0); + } + cv::Mat origin = bitmap_to_cv_mat(env, original_image); + if (origin.size == 0) { + LOGE("origin bitmap cannot convert to CV Mat"); + return cpp_array_to_jfloatarray(env, nullptr, 0); + } + ppredictor::OCR_PPredictor *ppredictor = + (ppredictor::OCR_PPredictor *)java_pointer; + std::vector dims_float_arr = jfloatarray_to_float_vector(env, ddims); + std::vector dims_arr; + dims_arr.resize(dims_float_arr.size()); + std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin()); - // 这里值有点大,就不调用jfloatarray_to_float_vector了 - int64_t buf_len = (int64_t) env->GetArrayLength(buf); - jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE); - float *data = (jfloat *) buf_data; - std::vector results = ppredictor->infer_ocr(dims_arr, data, - buf_len, - NET_OCR, origin); - LOGI("infer_ocr finished with boxes %ld", results.size()); - // 这里将std::vector 序列化成 float数组,传输到java层再反序列化 - std::vector float_arr; - for (const ppredictor::OCRPredictResult &r :results) { - float_arr.push_back(r.points.size()); - float_arr.push_back(r.word_index.size()); - float_arr.push_back(r.score); - for (const std::vector &point : r.points) { - float_arr.push_back(point.at(0)); - float_arr.push_back(point.at(1)); - } - for (int index: r.word_index) { - float_arr.push_back(index); - } + // 这里值有点大,就不调用jfloatarray_to_float_vector了 + int64_t buf_len = (int64_t)env->GetArrayLength(buf); + jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE); + float *data = (jfloat *)buf_data; + std::vector results = + ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin); + LOGI("infer_ocr finished with boxes %ld", results.size()); + // 这里将std::vector 序列化成 + // float数组,传输到java层再反序列化 + std::vector float_arr; + for (const ppredictor::OCRPredictResult &r : results) { + float_arr.push_back(r.points.size()); + float_arr.push_back(r.word_index.size()); + float_arr.push_back(r.score); + for (const std::vector &point : r.points) { + float_arr.push_back(point.at(0)); + float_arr.push_back(point.at(1)); } - return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size()); + for (int index : r.word_index) { + float_arr.push_back(index); + } + } + return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size()); } -extern "C" -JNIEXPORT void JNICALL -Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(JNIEnv *env, jobject thiz, - jlong java_pointer){ - if (java_pointer == 0) { - LOGE("JAVA pointer is NULL"); - return; - } - ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *) java_pointer; - delete ppredictor; +extern "C" JNIEXPORT void JNICALL 
+Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release( + JNIEnv *env, jobject thiz, jlong java_pointer) { + if (java_pointer == 0) { + LOGE("JAVA pointer is NULL"); + return; + } + ppredictor::OCR_PPredictor *ppredictor = + (ppredictor::OCR_PPredictor *)java_pointer; + delete ppredictor; } \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d720066667b60ee87bc1a1227ad720074254074e --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.cpp @@ -0,0 +1,46 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ocr_cls_process.h" +#include +#include +#include +#include +#include +#include + +const std::vector CLS_IMAGE_SHAPE = {3, 32, 100}; + +cv::Mat cls_resize_img(const cv::Mat &img) { + int imgC = CLS_IMAGE_SHAPE[0]; + int imgW = CLS_IMAGE_SHAPE[2]; + int imgH = CLS_IMAGE_SHAPE[1]; + + float ratio = float(img.cols) / float(img.rows); + int resize_w = 0; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_CUBIC); + + if (resize_w < imgW) { + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(imgW - resize_w), + cv::BORDER_CONSTANT, {0, 0, 0}); + } + return resize_img; +} \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_cls_process.h b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.h new file mode 100644 index 0000000000000000000000000000000000000000..1c30ee1071e647ce1ab7050ac0641d0eff7c62ad --- /dev/null +++ b/deploy/android_demo/app/src/main/cpp/ocr_cls_process.h @@ -0,0 +1,23 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
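+
+// cls_resize_img keeps the crop's aspect ratio: it scales the crop to height
+// CLS_IMAGE_SHAPE[1] = 32 and right-pads with black pixels up to width
+// CLS_IMAGE_SHAPE[2] = 100. For example, a 16x64 crop (ratio 4.0) would need
+// width 128 > 100 and is therefore squeezed to 32x100, while a 32x80 crop
+// keeps width 80 and is padded to 32x100.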
+ +#pragma once + +#include "common.h" +#include +#include + +extern const std::vector CLS_IMAGE_SHAPE; + +cv::Mat cls_resize_img(const cv::Mat &img); \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp index 3d0147715519c195fd48f7f84b7a28a5a82f5363..f0d855e83f010ef762cb4b01086e41a0f64fb4cb 100644 --- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp +++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp @@ -3,38 +3,48 @@ // #include "ocr_ppredictor.h" -#include "preprocess.h" #include "common.h" -#include "ocr_db_post_process.h" +#include "ocr_cls_process.h" #include "ocr_crnn_process.h" +#include "ocr_db_post_process.h" +#include "preprocess.h" namespace ppredictor { -OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) { +OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) {} -} +int OCR_PPredictor::init(const std::string &det_model_content, + const std::string &rec_model_content, + const std::string &cls_model_content) { + _det_predictor = std::unique_ptr( + new PPredictor{_config.thread_num, NET_OCR, _config.mode}); + _det_predictor->init_nb(det_model_content); -int -OCR_PPredictor::init(const std::string &det_model_content, const std::string &rec_model_content) { - _det_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR, _config.mode}); - _det_predictor->init_nb(det_model_content); + _rec_predictor = std::unique_ptr( + new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _rec_predictor->init_nb(rec_model_content); - _rec_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); - _rec_predictor->init_nb(rec_model_content); - return RETURN_OK; + _cls_predictor = std::unique_ptr( + new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _cls_predictor->init_nb(cls_model_content); + return RETURN_OK; } -int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std::string &rec_model_path){ - _det_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR, _config.mode}); - _det_predictor->init_from_file(det_model_path); - - _rec_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); - _rec_predictor->init_from_file(rec_model_path); - return RETURN_OK; +int OCR_PPredictor::init_from_file(const std::string &det_model_path, + const std::string &rec_model_path, + const std::string &cls_model_path) { + _det_predictor = std::unique_ptr( + new PPredictor{_config.thread_num, NET_OCR, _config.mode}); + _det_predictor->init_from_file(det_model_path); + + _rec_predictor = std::unique_ptr( + new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _rec_predictor->init_from_file(rec_model_path); + + _cls_predictor = std::unique_ptr( + new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + _cls_predictor->init_from_file(cls_model_path); + return RETURN_OK; } /** * for debug use, show result of First Step @@ -42,145 +52,188 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std: * @param boxes * @param srcimg */ -static void visual_img(const std::vector>> &filter_boxes, - const std::vector>> &boxes, - const cv::Mat &srcimg) { - // visualization - cv::Point rook_points[filter_boxes.size()][4]; - for (int n = 0; n < filter_boxes.size(); n++) { - for (int m = 0; m < filter_boxes[0].size(); m++) { - 
rook_points[n][m] = cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1])); - } +static void +visual_img(const std::vector>> &filter_boxes, + const std::vector>> &boxes, + const cv::Mat &srcimg) { + // visualization + cv::Point rook_points[filter_boxes.size()][4]; + for (int n = 0; n < filter_boxes.size(); n++) { + for (int m = 0; m < filter_boxes[0].size(); m++) { + rook_points[n][m] = + cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1])); } - - cv::Mat img_vis; - srcimg.copyTo(img_vis); - for (int n = 0; n < boxes.size(); n++) { - const cv::Point *ppt[1] = {rook_points[n]}; - int npt[] = {4}; - cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0); - } - // 调试用,自行替换需要修改的路径 - cv::imwrite("/sdcard/1/vis.png", img_vis); + } + + cv::Mat img_vis; + srcimg.copyTo(img_vis); + for (int n = 0; n < boxes.size(); n++) { + const cv::Point *ppt[1] = {rook_points[n]}; + int npt[] = {4}; + cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0); + } + // 调试用,自行替换需要修改的路径 + cv::imwrite("/sdcard/1/vis.png", img_vis); } std::vector -OCR_PPredictor::infer_ocr(const std::vector &dims, const float *input_data, int input_len, - int net_flag, cv::Mat &origin) { +OCR_PPredictor::infer_ocr(const std::vector &dims, + const float *input_data, int input_len, int net_flag, + cv::Mat &origin) { + PredictorInput input = _det_predictor->get_first_input(); + input.set_dims(dims); + input.set_data(input_data, input_len); + std::vector results = _det_predictor->infer(); + PredictorOutput &res = results.at(0); + std::vector>> filtered_box = calc_filtered_boxes( + res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin); + LOGI("Filter_box size %ld", filtered_box.size()); + return infer_rec(filtered_box, origin); +} - PredictorInput input = _det_predictor->get_first_input(); +std::vector OCR_PPredictor::infer_rec( + const std::vector>> &boxes, + const cv::Mat &origin_img) { + std::vector mean = {0.5f, 0.5f, 0.5f}; + std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; + std::vector dims = {1, 3, 0, 0}; + std::vector ocr_results; + + PredictorInput input = _rec_predictor->get_first_input(); + for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) { + const std::vector> &box = *bp; + cv::Mat crop_img = get_rotate_crop_image(origin_img, box); + crop_img = infer_cls(crop_img); + + float wh_ratio = float(crop_img.cols) / float(crop_img.rows); + cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio); + input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); + const float *dimg = reinterpret_cast(input_image.data); + int input_size = input_image.rows * input_image.cols; + + dims[2] = input_image.rows; + dims[3] = input_image.cols; input.set_dims(dims); - input.set_data(input_data, input_len); - std::vector results = _det_predictor->infer(); - PredictorOutput &res = results.at(0); - std::vector>> filtered_box - = calc_filtered_boxes(res.get_float_data(), res.get_size(), (int) dims[2], (int) dims[3], - origin); - LOGI("Filter_box size %ld", filtered_box.size()); - return infer_rec(filtered_box, origin); -} -std::vector -OCR_PPredictor::infer_rec(const std::vector>> &boxes, - const cv::Mat &origin_img) { - std::vector mean = {0.5f, 0.5f, 0.5f}; - std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; - std::vector dims = {1, 3, 0, 0}; - std::vector ocr_results; - - PredictorInput input = _rec_predictor->get_first_input(); - for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) { - const std::vector> &box = *bp; - cv::Mat crop_img = 
get_rotate_crop_image(origin_img, box); - float wh_ratio = float(crop_img.cols) / float(crop_img.rows); - cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio); - input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); - const float *dimg = reinterpret_cast(input_image.data); - int input_size = input_image.rows * input_image.cols; - - dims[2] = input_image.rows; - dims[3] = input_image.cols; - input.set_dims(dims); - - neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, scale); - - std::vector results = _rec_predictor->infer(); - - OCRPredictResult res; - res.word_index = postprocess_rec_word_index(results.at(0)); - if (res.word_index.empty()) { - continue; - } - res.score = postprocess_rec_score(results.at(1)); - res.points = box; - ocr_results.emplace_back(std::move(res)); + neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, + scale); + + std::vector results = _rec_predictor->infer(); + + OCRPredictResult res; + res.word_index = postprocess_rec_word_index(results.at(0)); + if (res.word_index.empty()) { + continue; } - LOGI("ocr_results finished %lu", ocr_results.size()); - return ocr_results; + res.score = postprocess_rec_score(results.at(1)); + res.points = box; + ocr_results.emplace_back(std::move(res)); + } + LOGI("ocr_results finished %lu", ocr_results.size()); + return ocr_results; +} + +cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { + std::vector mean = {0.5f, 0.5f, 0.5f}; + std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; + std::vector dims = {1, 3, 0, 0}; + std::vector ocr_results; + + PredictorInput input = _cls_predictor->get_first_input(); + + cv::Mat input_image = cls_resize_img(img); + input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); + const float *dimg = reinterpret_cast(input_image.data); + int input_size = input_image.rows * input_image.cols; + + dims[2] = input_image.rows; + dims[3] = input_image.cols; + input.set_dims(dims); + + neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, + scale); + + std::vector results = _cls_predictor->infer(); + + const float *scores = results.at(0).get_float_data(); + const int *labels = results.at(1).get_int_data(); + for (int64_t i = 0; i < results.at(0).get_size(); i++) { + LOGI("output scores [%f]", scores[i]); + } + for (int64_t i = 0; i < results.at(1).get_size(); i++) { + LOGI("output label [%d]", labels[i]); + } + int label_idx = labels[0]; + float score = scores[label_idx]; + + cv::Mat srcimg; + img.copyTo(srcimg); + if (label_idx % 2 == 1 && score > thresh) { + cv::rotate(srcimg, srcimg, 1); + } + return srcimg; } std::vector>> -OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size, int output_height, - int output_width, const cv::Mat &origin) { - const double threshold = 0.3; - const double maxvalue = 1; - - cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F); - memcpy(pred_map.data, pred, pred_size * sizeof(float)); - cv::Mat cbuf_map; - pred_map.convertTo(cbuf_map, CV_8UC1); - - cv::Mat bit_map; - cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); - - std::vector>> boxes = boxes_from_bitmap(pred_map, bit_map); - float ratio_h = output_height * 1.0f / origin.rows; - float ratio_w = output_width * 1.0f / origin.cols; - std::vector>> filter_boxes = filter_tag_det_res(boxes, ratio_h, - ratio_w, origin); - return filter_boxes; +OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size, + int output_height, int output_width, + const cv::Mat &origin) { + const double 
threshold = 0.3; + const double maxvalue = 1; + + cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F); + memcpy(pred_map.data, pred, pred_size * sizeof(float)); + cv::Mat cbuf_map; + pred_map.convertTo(cbuf_map, CV_8UC1); + + cv::Mat bit_map; + cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY); + + std::vector>> boxes = + boxes_from_bitmap(pred_map, bit_map); + float ratio_h = output_height * 1.0f / origin.rows; + float ratio_w = output_width * 1.0f / origin.cols; + std::vector>> filter_boxes = + filter_tag_det_res(boxes, ratio_h, ratio_w, origin); + return filter_boxes; } -std::vector OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) { - const int *rec_idx = res.get_int_data(); - const std::vector> rec_idx_lod = res.get_lod(); +std::vector +OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) { + const int *rec_idx = res.get_int_data(); + const std::vector> rec_idx_lod = res.get_lod(); - std::vector pred_idx; - for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) { - pred_idx.emplace_back(rec_idx[n]); - } - return pred_idx; + std::vector pred_idx; + for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) { + pred_idx.emplace_back(rec_idx[n]); + } + return pred_idx; } float OCR_PPredictor::postprocess_rec_score(const PredictorOutput &res) { - const float *predict_batch = res.get_float_data(); - const std::vector predict_shape = res.get_shape(); - const std::vector> predict_lod = res.get_lod(); - int blank = predict_shape[1]; - float score = 0.f; - int count = 0; - for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) { - int argmax_idx = argmax(predict_batch + n * predict_shape[1], - predict_batch + (n + 1) * predict_shape[1]); - float max_value = predict_batch[n * predict_shape[1] + argmax_idx]; - if (blank - 1 - argmax_idx > 1e-5) { - score += max_value; - count += 1; - } - - } - if (count == 0) { - LOGE("calc score count 0"); - } else { - score /= count; + const float *predict_batch = res.get_float_data(); + const std::vector predict_shape = res.get_shape(); + const std::vector> predict_lod = res.get_lod(); + int blank = predict_shape[1]; + float score = 0.f; + int count = 0; + for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) { + int argmax_idx = argmax(predict_batch + n * predict_shape[1], + predict_batch + (n + 1) * predict_shape[1]); + float max_value = predict_batch[n * predict_shape[1] + argmax_idx]; + if (blank - 1 - argmax_idx > 1e-5) { + score += max_value; + count += 1; } - LOGI("calc score: %f", score); - return score; - + } + if (count == 0) { + LOGE("calc score count 0"); + } else { + score /= count; + } + LOGI("calc score: %f", score); + return score; } - -NET_TYPE OCR_PPredictor::get_net_flag() const { - return NET_OCR; -} +NET_TYPE OCR_PPredictor::get_net_flag() const { return NET_OCR; } } \ No newline at end of file diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h index eb2bc3bc989c5dd9a2c5a8aae3508ca733602bd7..0ec458a4952cbc605e9979ce7850bdeab36c4629 100644 --- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h +++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h @@ -4,10 +4,10 @@ #pragma once -#include +#include "ppredictor.h" #include #include -#include "ppredictor.h" +#include namespace ppredictor { @@ -15,17 +15,18 @@ namespace ppredictor { * Config */ struct OCR_Config { - int thread_num = 4; // Thread num - 
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode + int thread_num = 4; // Thread num + paddle::lite_api::PowerMode mode = + paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode }; /** * PolyGone Result */ struct OCRPredictResult { - std::vector word_index; - std::vector> points; - float score; + std::vector word_index; + std::vector> points; + float score; }; /** @@ -35,78 +36,87 @@ struct OCRPredictResult { */ class OCR_PPredictor : public PPredictor_Interface { public: - OCR_PPredictor(const OCR_Config &config); - - virtual ~OCR_PPredictor() { - - } - - /** - * 初始化二个模型的Predictor - * @param det_model_content - * @param rec_model_content - * @return - */ - int init(const std::string &det_model_content, const std::string &rec_model_content); - int init_from_file(const std::string &det_model_path, const std::string &rec_model_path); - /** - * Return OCR result - * @param dims - * @param input_data - * @param input_len - * @param net_flag - * @param origin - * @return - */ - virtual std::vector - infer_ocr(const std::vector &dims, const float *input_data, int input_len, - int net_flag, cv::Mat &origin); - - - virtual NET_TYPE get_net_flag() const; - + OCR_PPredictor(const OCR_Config &config); + + virtual ~OCR_PPredictor() {} + + /** + * 初始化二个模型的Predictor + * @param det_model_content + * @param rec_model_content + * @return + */ + int init(const std::string &det_model_content, + const std::string &rec_model_content, + const std::string &cls_model_content); + int init_from_file(const std::string &det_model_path, + const std::string &rec_model_path, + const std::string &cls_model_path); + /** + * Return OCR result + * @param dims + * @param input_data + * @param input_len + * @param net_flag + * @param origin + * @return + */ + virtual std::vector + infer_ocr(const std::vector &dims, const float *input_data, + int input_len, int net_flag, cv::Mat &origin); + + virtual NET_TYPE get_net_flag() const; private: - - /** - * calcul Polygone from the result image of first model - * @param pred - * @param output_height - * @param output_width - * @param origin - * @return - */ - std::vector>> - calc_filtered_boxes(const float *pred, int pred_size, int output_height, int output_width, - const cv::Mat &origin); - - /** - * infer for second model - * - * @param boxes - * @param origin - * @return - */ - std::vector - infer_rec(const std::vector>> &boxes, const cv::Mat &origin); - - /** - * Postprocess or sencod model to extract text - * @param res - * @return - */ - std::vector postprocess_rec_word_index(const PredictorOutput &res); - - /** - * calculate confidence of second model text result - * @param res - * @return - */ - float postprocess_rec_score(const PredictorOutput &res); - - std::unique_ptr _det_predictor; - std::unique_ptr _rec_predictor; - OCR_Config _config; - + /** + * calcul Polygone from the result image of first model + * @param pred + * @param output_height + * @param output_width + * @param origin + * @return + */ + std::vector>> + calc_filtered_boxes(const float *pred, int pred_size, int output_height, + int output_width, const cv::Mat &origin); + + /** + * infer for second model + * + * @param boxes + * @param origin + * @return + */ + std::vector + infer_rec(const std::vector>> &boxes, + const cv::Mat &origin); + + /** + * infer for cls model + * + * @param boxes + * @param origin + * @return + */ + cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.5); + + /** + * Postprocess or sencod model to extract text + * @param res + * @return + 
*/ + std::vector postprocess_rec_word_index(const PredictorOutput &res); + + /** + * calculate confidence of second model text result + * @param res + * @return + */ + float postprocess_rec_score(const PredictorOutput &res); + + std::unique_ptr _det_predictor; + std::unique_ptr _rec_predictor; + std::unique_ptr _cls_predictor; + OCR_Config _config; }; } diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java index 2e78a3ece96bb5e37bebcdda7ebc77060686b710..7499d4b92689645c0b1009256884733d392ff68d 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java @@ -29,7 +29,7 @@ public class OCRPredictorNative { public OCRPredictorNative(Config config) { this.config = config; loadLibrary(); - nativePointer = init(config.detModelFilename, config.recModelFilename, + nativePointer = init(config.detModelFilename, config.recModelFilename,config.clsModelFilename, config.cpuThreadNum, config.cpuPower); Log.i("OCRPredictorNative", "load success " + nativePointer); @@ -38,7 +38,7 @@ public class OCRPredictorNative { public void release() { if (nativePointer != 0) { nativePointer = 0; - destory(nativePointer); +// destory(nativePointer); } } @@ -55,10 +55,11 @@ public class OCRPredictorNative { public String cpuPower; public String detModelFilename; public String recModelFilename; + public String clsModelFilename; } - protected native long init(String detModelPath, String recModelPath, int threadNum, String cpuMode); + protected native long init(String detModelPath, String recModelPath,String clsModelPath, int threadNum, String cpuMode); protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage); diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java index 078bba286cc9cd5f9904e0594b5608c755a2b131..ddf69ab481618696189a7d0d45264791267e5631 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java @@ -121,7 +121,8 @@ public class Predictor { config.cpuThreadNum = cpuThreadNum; config.detModelFilename = realPath + File.separator + "ch_det_mv3_db_opt.nb"; config.recModelFilename = realPath + File.separator + "ch_rec_mv3_crnn_opt.nb"; - Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename); + config.clsModelFilename = realPath + File.separator + "cls_opt_arm.nb"; + Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename); config.cpuPower = cpuPowerMode; paddlePredictor = new OCRPredictorNative(config); diff --git a/deploy/cpp_infer/include/config.h b/deploy/cpp_infer/include/config.h index 8db693b121f1f91e30672de53e9b969babb49f8b..27539ea7934dc192e86bca3ea6bfd7999ee229a3 100644 --- a/deploy/cpp_infer/include/config.h +++ b/deploy/cpp_infer/include/config.h @@ -57,6 +57,12 @@ public: this->char_list_file.assign(config_map_["char_list_file"]); + this->use_angle_cls = bool(stoi(config_map_["use_angle_cls"])); + + this->cls_model_dir.assign(config_map_["cls_model_dir"]); + + this->cls_thresh = stod(config_map_["cls_thresh"]); + 
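+    // The three cls_* keys above are read unconditionally; std::stoi and
+    // std::stod throw std::invalid_argument on a missing or empty value, so
+    // config.txt must always define use_angle_cls, cls_model_dir and
+    // cls_thresh (the sample tools/config.txt in this patch sets all three).
+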
this->visualize = bool(stoi(config_map_["visualize"])); } @@ -84,8 +90,14 @@ public: std::string rec_model_dir; + bool use_angle_cls; + std::string char_list_file; + std::string cls_model_dir; + + double cls_thresh; + bool visualize = true; void PrintConfigInfo(); diff --git a/deploy/cpp_infer/include/ocr_cls.h b/deploy/cpp_infer/include/ocr_cls.h new file mode 100644 index 0000000000000000000000000000000000000000..38a37cff3c035eafe3617d83b2cc15ca47f30186 --- /dev/null +++ b/deploy/cpp_infer/include/ocr_cls.h @@ -0,0 +1,81 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace PaddleOCR { + +class Classifier { +public: + explicit Classifier(const std::string &model_dir, const bool &use_gpu, + const int &gpu_id, const int &gpu_mem, + const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const bool &use_zero_copy_run, + const double &cls_thresh) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + this->use_zero_copy_run_ = use_zero_copy_run; + + this->cls_thresh = cls_thresh; + + LoadModel(model_dir); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir); + + cv::Mat Run(cv::Mat &img); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + bool use_zero_copy_run_ = false; + double cls_thresh = 0.5; + + std::vector mean_ = {0.5f, 0.5f, 0.5f}; + std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; + bool is_scale_ = true; + + // pre-process + ClsResizeImg resize_op_; + Normalize normalize_op_; + Permute permute_op_; + +}; // class Classifier + +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index 520f0f2879dcec6b30861755b119227efa11b29c..a8b99a5960ac3e6238dfea2285ec51c9e80e1749 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -56,7 +57,8 @@ public: // Load Paddle inference model void LoadModel(const std::string &model_dir); - void Run(std::vector>> boxes, cv::Mat &img); + void Run(std::vector>> boxes, cv::Mat &img, + Classifier *cls); private: std::shared_ptr predictor_; diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h index 309d7fd4386330149afc91b474c330212fadd5e8..5cbc5cd7134238c4f09f536ca6b1153d2d703023 100644 --- a/deploy/cpp_infer/include/preprocess_op.h +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -56,4 +56,10 @@ 
public: const std::vector &rec_image_shape = {3, 32, 320}); }; +class ClsResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, + const std::vector &rec_image_shape = {3, 32, 320}); +}; + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 1dd33b301e8b7da1df2a6325cedb10b8156c43d2..e708a6e341e6dd5ba66abe46456e2d74a89e0cb5 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -53,6 +53,15 @@ int main(int argc, char **argv) { config.cpu_math_library_num_threads, config.use_mkldnn, config.use_zero_copy_run, config.max_side_len, config.det_db_thresh, config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize); + + Classifier *cls = nullptr; + if (config.use_angle_cls == true) { + cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id, + config.gpu_mem, config.cpu_math_library_num_threads, + config.use_mkldnn, config.use_zero_copy_run, + config.cls_thresh); + } + CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem, config.cpu_math_library_num_threads, config.use_mkldnn, config.use_zero_copy_run, @@ -62,7 +71,7 @@ int main(int argc, char **argv) { std::vector>> boxes; det.Run(srcimg, boxes); - rec.Run(boxes, srcimg); + rec.Run(boxes, srcimg, cls); auto end = std::chrono::system_clock::now(); auto duration = diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7cdaaab40108026edffe5cb1ca53ac3972768cc6 --- /dev/null +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -0,0 +1,110 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace PaddleOCR { + +cv::Mat Classifier::Run(cv::Mat &img) { + cv::Mat src_img; + img.copyTo(src_img); + cv::Mat resize_img; + + std::vector rec_image_shape = {3, 32, 100}; + int index = 0; + float wh_ratio = float(img.cols) / float(img.rows); + + this->resize_op_.Run(img, resize_img, rec_image_shape); + + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + + std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); + + this->permute_op_.Run(&resize_img, input.data()); + + // Inference. 
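+  // Two execution paths: with use_zero_copy_run_ the input is fed through a
+  // ZeroCopyTensor (Reshape + copy_from_cpu + ZeroCopyRun); otherwise the
+  // same buffer is wrapped in a paddle::PaddleTensor and passed to Run().
+  // In both cases the outputs are read back below via GetOutputTensor and
+  // copy_to_cpu: output 0 holds the softmax scores, output 1 the label.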
+ if (this->use_zero_copy_run_) { + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputTensor(input_names[0]); + input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + input_t->copy_from_cpu(input.data()); + this->predictor_->ZeroCopyRun(); + } else { + paddle::PaddleTensor input_t; + input_t.shape = {1, 3, resize_img.rows, resize_img.cols}; + input_t.data = + paddle::PaddleBuf(input.data(), input.size() * sizeof(float)); + input_t.dtype = PaddleDType::FLOAT32; + std::vector outputs; + this->predictor_->Run({input_t}, &outputs, 1); + } + + std::vector softmax_out; + std::vector label_out; + auto output_names = this->predictor_->GetOutputNames(); + auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]); + auto label_out_t = this->predictor_->GetOutputTensor(output_names[1]); + auto softmax_shape_out = softmax_out_t->shape(); + auto label_shape_out = label_out_t->shape(); + + int softmax_out_num = + std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1, + std::multiplies()); + + int label_out_num = + std::accumulate(label_shape_out.begin(), label_shape_out.end(), 1, + std::multiplies()); + softmax_out.resize(softmax_out_num); + label_out.resize(label_out_num); + + softmax_out_t->copy_to_cpu(softmax_out.data()); + label_out_t->copy_to_cpu(label_out.data()); + + int label = label_out[0]; + float score = softmax_out[label]; + // std::cout << "\nlabel "< this->cls_thresh) { + cv::rotate(src_img, src_img, 1); + } + return src_img; +} + +void Classifier::LoadModel(const std::string &model_dir) { + AnalysisConfig config; + config.SetModel(model_dir + "/model", model_dir + "/params"); + + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + // false for zero copy tensor + config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = CreatePaddlePredictor(config); +} +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index b997d8291a64f9b6042bce648bcd358e34d55a95..7f88adc54636b4ecc61d257b7cb9159ebcdb82af 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -17,7 +17,7 @@ namespace PaddleOCR { void CRNNRecognizer::Run(std::vector>> boxes, - cv::Mat &img) { + cv::Mat &img, Classifier *cls) { cv::Mat srcimg; img.copyTo(srcimg); cv::Mat crop_img; @@ -27,6 +27,9 @@ void CRNNRecognizer::Run(std::vector>> boxes, int index = 0; for (int i = boxes.size() - 1; i >= 0; i--) { crop_img = GetRotateCropImage(srcimg, boxes[i]); + if (cls != nullptr) { + crop_img = cls->Run(crop_img); + } float wh_ratio = float(crop_img.cols) / float(crop_img.rows); diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp index 0078063e06f05b1d260f76f3bbb2061b7c974f65..b44e9d022f2dcfb390cb28a7f34e0ba4b031832d 100644 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, cv::INTER_LINEAR); } +void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + const std::vector &rec_image_shape) { + int imgC, imgH, imgW; + 
imgC = rec_image_shape[0]; + imgH = rec_image_shape[1]; + imgW = rec_image_shape[2]; + + float ratio = float(img.cols) / float(img.rows); + int resize_w, resize_h; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_LINEAR); + if (resize_w < imgW) { + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w, + cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); + } +} + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/tools/config.txt b/deploy/cpp_infer/tools/config.txt index 6c53f29eeb310677815d106d3e0ae39fb03bc2e2..28bacba60d4a599ad951c9820938b38e55b07283 100644 --- a/deploy/cpp_infer/tools/config.txt +++ b/deploy/cpp_infer/tools/config.txt @@ -13,6 +13,11 @@ det_db_box_thresh 0.5 det_db_unclip_ratio 2.0 det_model_dir ./inference/det_db +# cls config +use_angle_cls 0 +cls_model_dir ../inference/cls +cls_thresh 0.9 + # rec config rec_model_dir ./inference/rec_crnn char_list_file ../../ppocr/utils/ppocr_keys_v1.txt diff --git a/deploy/lite/Makefile b/deploy/lite/Makefile index 96e05ecf01904fdcb21a103e78783da6dd748ca9..4c30d64475730f5b1cdc20713493b31d66540b4b 100644 --- a/deploy/lite/Makefile +++ b/deploy/lite/Makefile @@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY #CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) -ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o - $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS) +ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS) ocr_db_crnn.o: ocr_db_crnn.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc @@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc crnn_process.o: fetch_opencv crnn_process.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc +cls_process.o: fetch_opencv cls_process.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc + db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc @@ -73,5 +76,5 @@ fetch_opencv: .PHONY: clean clean: - rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o + rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o rm -f ocr_db_crnn diff --git a/deploy/lite/cls_process.cc b/deploy/lite/cls_process.cc new file mode 100644 index 0000000000000000000000000000000000000000..f522e4bc5a9050c981a91bcbd8dfefd719f034c2 --- /dev/null +++ b/deploy/lite/cls_process.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cls_process.h" //NOLINT +#include +#include +#include + +const std::vector rec_image_shape{3, 32, 100}; + +cv::Mat ClsResizeImg(cv::Mat img) { + int imgC, imgH, imgW; + imgC = rec_image_shape[0]; + imgH = rec_image_shape[1]; + imgW = rec_image_shape[2]; + + float ratio = static_cast(img.cols) / static_cast(img.rows); + + int resize_w, resize_h; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_LINEAR); + if (resize_w < imgW) { + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w, + cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); + } + return resize_img; +} \ No newline at end of file diff --git a/deploy/lite/cls_process.h b/deploy/lite/cls_process.h new file mode 100644 index 0000000000000000000000000000000000000000..eedeeb9ba7a14a5686fd27ce96ad26b74f2bf7ed --- /dev/null +++ b/deploy/lite/cls_process.h @@ -0,0 +1,29 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
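+
+// Standalone copy, for the Paddle-Lite demo, of the classifier resize in
+// deploy/cpp_infer/src/preprocess_op.cpp: scale to height 32 keeping the
+// aspect ratio, then pad to 3x32x100 (resize_h in the .cc is declared but
+// never used). Callers must pass a non-empty Mat: in ocr_db_crnn.cc below,
+// RunClsModel appears to resize `crop_img` before the input `img` is ever
+// copied into it, i.e. it operates on an empty Mat.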
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "math.h" //NOLINT +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" + +cv::Mat ClsResizeImg(cv::Mat img); \ No newline at end of file diff --git a/deploy/lite/ocr_db_crnn.cc b/deploy/lite/ocr_db_crnn.cc index c94062fdf1e1e6f20422e27a61e41db316ad3e41..fea093c3ac2867faf8d84e27ddcaeb55219a8cca 100644 --- a/deploy/lite/ocr_db_crnn.cc +++ b/deploy/lite/ocr_db_crnn.cc @@ -15,6 +15,7 @@ #include "paddle_api.h" // NOLINT #include +#include "cls_process.h" #include "crnn_process.h" #include "db_post_process.h" @@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len, return resize_img; } +cv::Mat RunClsModel(cv::Mat img, std::shared_ptr predictor_cls, + const float thresh = 0.5) { + std::vector mean = {0.5f, 0.5f, 0.5f}; + std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; + + cv::Mat srcimg; + img.copyTo(srcimg); + cv::Mat crop_img; + cv::Mat resize_img; + + int index = 0; + float wh_ratio = + static_cast(crop_img.cols) / static_cast(crop_img.rows); + + resize_img = ClsResizeImg(crop_img); + resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f); + + const float *dimg = reinterpret_cast(resize_img.data); + + std::unique_ptr input_tensor0(std::move(predictor_cls->GetInput(0))); + input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols}); + auto *data0 = input_tensor0->mutable_data(); + + NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale); + // Run CLS predictor + predictor_cls->Run(); + + // Get output and run postprocess + std::unique_ptr softmax_out( + std::move(predictor_cls->GetOutput(0))); + std::unique_ptr label_out( + std::move(predictor_cls->GetOutput(1))); + auto *softmax_scores = softmax_out->mutable_data(); + auto *label_idxs = label_out->data(); + int label_idx = label_idxs[0]; + float score = softmax_scores[label_idx]; + + if (label_idx % 2 == 1 && score > thresh) { + cv::rotate(srcimg, srcimg, 1); + } + return srcimg; +} + void RunRecModel(std::vector>> boxes, cv::Mat img, std::shared_ptr predictor_crnn, std::vector &rec_text, std::vector &rec_text_score, - std::vector charactor_dict) { + std::vector charactor_dict, + std::shared_ptr predictor_cls) { std::vector mean = {0.5f, 0.5f, 0.5f}; std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; @@ -121,6 +166,7 @@ void RunRecModel(std::vector>> boxes, cv::Mat img, int index = 0; for (int i = boxes.size() - 1; i >= 0; i--) { crop_img = GetRotateCropImage(srcimg, boxes[i]); + crop_img = RunClsModel(crop_img, predictor_cls); float wh_ratio = static_cast(crop_img.cols) / static_cast(crop_img.rows); @@ -323,8 +369,9 @@ int main(int argc, char **argv) { } std::string det_model_file = argv[1]; std::string rec_model_file = argv[2]; - std::string img_path = argv[3]; - std::string dict_path = argv[4]; + std::string cls_model_file = argv[3]; + std::string img_path = argv[4]; + std::string dict_path = argv[5]; //// load config from txt file auto Config = LoadConfigTxt("./config.txt"); @@ -333,6 +380,7 @@ int main(int argc, char **argv) { auto det_predictor = loadModel(det_model_file); auto rec_predictor = loadModel(rec_model_file); + auto cls_predictor = loadModel(cls_model_file); auto charactor_dict = ReadDict(dict_path); charactor_dict.push_back(" "); @@ -343,7 +391,7 @@ int main(int argc, char **argv) { std::vector rec_text; std::vector rec_text_score; RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score, - charactor_dict); + charactor_dict, 
 
     auto end = std::chrono::system_clock::now();
     auto duration =
diff --git a/doc/doc_ch/angle_class.md b/doc/doc_ch/angle_class.md
new file mode 100644
index 0000000000000000000000000000000000000000..b2118661290ac0b6f2731a8fd9ba76dadcb21ded
--- /dev/null
+++ b/doc/doc_ch/angle_class.md
@@ -0,0 +1,127 @@
+## 文字角度分类
+
+### 数据准备
+
+请按如下步骤设置数据集:
+
+训练数据的默认存储路径是 `PaddleOCR/train_data/cls`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录:
+
+```
+ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
+```
+
+请参考下文组织您的数据。
+
+- 训练集
+
+首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。
+
+**注意:** 默认请将图片路径和图片标签用 `\t` 分割,如用其他方式分割将造成训练报错。
+
+0和180分别表示图片的角度为0度和180度。
+
+```
+" 图像文件名                 图像标注信息 "
+
+train_data/cls/word_001.jpg   0
+train_data/cls/word_002.jpg   180
+```
+
+最终训练集应有如下文件结构:
+```
+|-train_data
+    |-cls
+        |- cls_gt_train.txt
+        |- train
+            |- word_001.png
+            |- word_002.jpg
+            |- word_003.jpg
+            | ...
+```
+
+- 测试集
+
+同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个cls_gt_test.txt,测试集的结构如下所示:
+
+```
+|-train_data
+    |-cls
+        |- cls_gt_test.txt
+        |- test
+            |- word_001.jpg
+            |- word_002.jpg
+            |- word_003.jpg
+            | ...
+```
+
+### 启动训练
+
+PaddleOCR提供了训练脚本、评估脚本和预测脚本。
+
+开始训练:
+
+*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
+
+```
+# 设置PYTHONPATH路径
+export PYTHONPATH=$PYTHONPATH:.
+# GPU训练 支持单卡、多卡训练,通过CUDA_VISIBLE_DEVICES指定卡号
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+# 启动训练
+python3 tools/train.py -c configs/cls/cls_mv3.yml
+```
+
+- 数据增强
+
+PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入扰动,请在配置文件中设置 `distort: true`。
+
+默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gaussian noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse)、随机数据增强(RandAugment)。
+
+训练过程中除随机数据增强外,每种扰动方式以50%的概率被选择,具体代码实现请参考:
+[randaugment.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py)
+[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
+
+*由于OpenCV的兼容性问题,扰动操作暂时只支持Linux*
+
+### 训练
+
+PaddleOCR支持训练和评估交替进行,可以在 `configs/cls/cls_mv3.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型保存为 `output/cls_mv3/best_accuracy`。
+
+如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。
+
+**注意,预测/评估时的配置文件请务必与训练一致。**
+
+### 评估
+
+评估数据集可以通过 `configs/cls/cls_reader.yml` 修改EvalReader中的 `label_file_path` 设置。
+
+*注意* 评估时必须确保配置文件中 infer_img 字段为空
+```
+export CUDA_VISIBLE_DEVICES=0
+# GPU 评估, Global.checkpoints 为待测权重
+python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
+```
+
+### 预测
+
+* 训练引擎的预测
+
+使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。
+
+默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 指定权重:
+
+```
+# 预测分类结果
+python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+预测图片:
+
+![](../imgs_words/en/word_1.png)
+
+得到输入图像的预测结果:
+
+```
+infer_img: doc/imgs_words/en/word_1.png
+     scores: [[0.93161047 0.06838956]]
+     label: [0]
+```
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index 293fee2f4291a8400661de1ed08f0c6807eef977..bfab955b75ad8aecce9ef124fc6df6d62be337a5 100644
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -11,24 +11,28 @@ inference 模型(`fluid.io.save_inference_model`保存的模型)
 - [一、训练模型转inference模型](#训练模型转inference模型)
     - [检测模型转inference模型](#检测模型转inference模型)
     - [识别模型转inference模型](#识别模型转inference模型)
-
+    - [方向分类模型转inference模型](#方向模型转inference模型)
+
 - [二、文本检测模型推理](#文本检测模型推理)
     - [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
     - [2. DB文本检测模型推理](#DB文本检测模型推理)
     - [3. EAST文本检测模型推理](#EAST文本检测模型推理)
     - [4. SAST文本检测模型推理](#SAST文本检测模型推理)
-
+
 - [三、文本识别模型推理](#文本识别模型推理)
     - [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
     - [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
     - [3. 基于Attention损失的识别模型推理](#基于Attention损失的识别模型推理)
-    - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
-
-- [四、文本检测、识别串联推理](#文本检测、识别串联推理)
+    - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
+
+- [四、方向分类模型推理](#方向识别模型推理)
+    - [1. 方向分类模型推理](#方向分类模型推理)
+
+- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
     - [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
     - [2. 其他模型推理](#其他模型推理)
-
-
+
+
 
 ## 一、训练模型转inference模型
@@ -84,6 +88,32 @@ python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Globa
 └─ params 识别inference模型的参数文件
 ```
 
+
+### 方向分类模型转inference模型
+
+下载方向分类模型:
+```
+wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
+```
+
+方向分类模型转inference模型与检测的方式相同,如下:
+```
+# -c后面设置训练算法的yml配置文件
+# -o配置可选参数
+# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel、.pdopt或.pdparams。
+# Global.save_inference_dir参数设置转换的模型将保存的地址。
+
+python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
+        Global.save_inference_dir=./inference/cls/
+```
+
+转换成功后,在目录下有两个文件:
+```
+/inference/cls/
+  └─  model     分类inference模型的program文件
+  └─  params    分类inference模型的参数文件
+```
+
 
 ## 二、文本检测模型推理
 
@@ -275,15 +305,36 @@ dict_character = list(self.character_str)
 python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
 ```
 
-
-## 四、文本检测、识别串联推理
+
+
+## 四、方向分类模型推理
+
+下面将介绍方向分类模型推理。
+
+
+### 1. 方向分类模型推理
+
+方向分类模型推理,可以执行如下命令:
+
+```
+python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/"
+```
+
+![](../imgs_words/ch/word_4.jpg)
+
+执行命令后,上面图像的预测结果(分类的方向和得分)会打印到屏幕上,示例如下:
+
+Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999963]
+
+
+## 五、文本检测、方向分类和文字识别串联推理
 
 ### 1. 超轻量中文OCR模型推理
 
-在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。
+在执行预测时,需要通过参数`image_dir`指定单张图像或者图像集合的路径,参数`det_model_dir`、`cls_model_dir`和`rec_model_dir`分别指定检测、方向分类和识别的inference模型路径。参数`use_angle_cls`用于控制是否启用方向分类模型。可视化识别结果默认保存到 ./inference_results 文件夹里面。
 
 ```
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/"
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls true
 ```
 
 执行命令后,识别结果图像如下:
diff --git a/doc/doc_en/angle_class_en.md b/doc/doc_en/angle_class_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..c7fff3a1833570cda7687b87efb7c3af2ec49120
--- /dev/null
+++ b/doc/doc_en/angle_class_en.md
@@ -0,0 +1,126 @@
+## TEXT ANGLE CLASSIFICATION
+
+### DATA PREPARATION
+
+Please organize the dataset as follows:
+
+The default storage path for training data is `PaddleOCR/train_data/cls`. If you already have a dataset on your disk, just create a soft link to the dataset directory:
+
+```
+ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
+```
+
+Please refer to the following to organize your data.
+
+- Training set
+
+First put the training images in the same folder (train_images), and use a txt file (cls_gt_train.txt) to store the image path and label.
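For reference, a label file in this format is easy to sanity-check before training. Below is a minimal sketch; the `check_cls_labels` helper and the exact paths are illustrative only, not part of this patch:

```python
# Validate a classification label file: each line must be "path<TAB>label",
# with the label drawn from label_list and the image present on disk.
import os

def check_cls_labels(label_path, img_root, label_list=("0", "180")):
    with open(label_path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            parts = line.rstrip("\n").split("\t")
            if len(parts) != 2:
                print("line {}: expected two tab-separated fields, got {!r}".format(lineno, line))
            elif parts[1] not in label_list:
                print("line {}: label {!r} not in {}".format(lineno, parts[1], label_list))
            elif not os.path.exists(os.path.join(img_root, parts[0])):
                print("line {}: missing image {}".format(lineno, parts[0]))

check_cls_labels("train_data/cls/cls_gt_train.txt", "train_data/cls")
```

The same check applies to cls_gt_test.txt.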
+
+* Note: By default, the image path and image label are split with `\t`; using any other separator will cause a training error.
+
+0 and 180 indicate that the angle of the image is 0 degrees and 180 degrees, respectively.
+
+```
+" Image file name            Image annotation "
+
+train_data/cls/word_001.jpg   0
+train_data/cls/word_002.jpg   180
+```
+
+The final training set should have the following file structure:
+
+```
+|-train_data
+    |-cls
+        |- cls_gt_train.txt
+        |- train
+            |- word_001.png
+            |- word_002.jpg
+            |- word_003.jpg
+            | ...
+```
+
+- Test set
+
+Similar to the training set, the test set also needs a folder containing all images (test) and a cls_gt_test.txt. The structure of the test set is as follows:
+
+```
+|-train_data
+    |-cls
+        |- cls_gt_test.txt
+        |- test
+            |- word_001.jpg
+            |- word_002.jpg
+            |- word_003.jpg
+            | ...
+```
+
+### TRAINING
+
+PaddleOCR provides training scripts, evaluation scripts, and prediction scripts.
+
+Start training:
+
+```
+# Set PYTHONPATH path
+export PYTHONPATH=$PYTHONPATH:.
+# GPU training: single-card and multi-card training are supported; specify the card numbers through CUDA_VISIBLE_DEVICES
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+# Start training
+python3 tools/train.py -c configs/cls/cls_mv3.yml
+```
+
+- Data Augmentation
+
+PaddleOCR provides a variety of data augmentation methods. If you want to add perturbations during training, please set `distort: true` in the configuration file.
+
+The default perturbation methods are: cvtColor, blur, jitter, Gaussian noise, random crop, perspective, color reverse, and RandAugment.
+
+Except for RandAugment, each perturbation method is selected with a 50% probability during the training process. For the concrete implementation, please refer to:
+[randaugment.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py)
+[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
+
+
+- Training
+
+PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/cls/cls_mv3.yml` to set the evaluation frequency. By default, evaluation runs every 500 iterations, and the best-accuracy model is saved as `output/cls_mv3/best_accuracy` during the evaluation process.
+
+If the evaluation set is large, evaluation will be time-consuming. It is recommended to reduce the number of evaluations, or to evaluate after training.
+
+**Note that the configuration file for prediction/evaluation must be consistent with the training one.**
+
+### EVALUATION
+
+The evaluation dataset can be set by modifying the `label_file_path` field of `EvalReader` in `configs/cls/cls_reader.yml`.
+
+```
+export CUDA_VISIBLE_DEVICES=0
+# GPU evaluation, Global.checkpoints is the weight to be tested
+python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
+```
+
+### PREDICTION
+
+* Training engine prediction
+
+Using a model trained by PaddleOCR, you can quickly run prediction with the following script.
+
+The default prediction image is stored in `infer_img`, and the trained weights are specified via `-o Global.checkpoints`:
+
+```
+# Predict the classification result
+python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+```
+
+Input image:
+
+![](../imgs_words/en/word_1.png)
+
+Get the prediction result of the input image:
+
+```
+infer_img: doc/imgs_words/en/word_1.png
+     scores: [[0.93161047 0.06838956]]
+     label: [0]
+```
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index 83ec2a90c45a320815e10e8572d894068c0b5130..db064f03e565b4fa3de3409d4a17459c636eae70 100644
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -12,25 +12,28 @@ Next, we first introduce how to convert a trained model into an inference model,
 - [CONVERT TRAINING MODEL TO INFERENCE MODEL](#CONVERT)
     - [Convert detection model to inference model](#Convert_detection_model)
     - [Convert recognition model to inference model](#Convert_recognition_model)
-
-
+    - [Convert angle classification model to inference model](#Convert_angle_class_model)
+
+
 - [TEXT DETECTION MODEL INFERENCE](#DETECTION_MODEL_INFERENCE)
     - [1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE](#LIGHTWEIGHT_DETECTION)
     - [2. DB TEXT DETECTION MODEL INFERENCE](#DB_DETECTION)
     - [3. EAST TEXT DETECTION MODEL INFERENCE](#EAST_DETECTION)
     - [4. SAST TEXT DETECTION MODEL INFERENCE](#SAST_DETECTION)
-
+
 - [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
     - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
     - [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
     - [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION)
     - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
-
-
-- [TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION)
+
+- [ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
+    - [1. ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
+
+- [TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION)
     - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_CHINESE_MODEL)
     - [2. OTHER MODELS](#OTHER_MODELS)
-
+
 
 ## CONVERT TRAINING MODEL TO INFERENCE MODEL
@@ -87,6 +90,33 @@ After the conversion is successful, there are two files in the directory:
 └─ params Identify the parameter files of the inference model
 ```
 
+
+### Convert angle classification model to inference model
+
+Download the angle classification model:
+```
+wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
+```
+
+The angle classification model is converted to an inference model in the same way as the detection model, as follows:
+```
+# -c Set the training algorithm yml configuration file
+# -o Set optional parameters
+# Global.checkpoints Set the training model address to be converted, without the file suffix .pdmodel, .pdopt or .pdparams.
+# Global.save_inference_dir Set the address where the converted model will be saved.
+ +python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \ + Global.save_inference_dir=./inference/cls/ +``` + +After the conversion is successful, there are two files in the directory: +``` +/inference/cls/ + └─ model Identify the saved model files + └─ params Identify the parameter files of the inference model +``` + + ## TEXT DETECTION MODEL INFERENCE @@ -276,16 +306,39 @@ If the chars dictionary is modified during training, you need to specify the new python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path" ``` + + +## ANGLE CLASSIFICATION MODEL INFERENCE + +The following will introduce the angle classification model inference. + + + +### 1.ANGLE CLASSIFICATION MODEL INFERENCE + +For angle classification model inference, you can execute the following commands: + +``` +python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/" +``` + +![](../imgs_words/ch/word_4.jpg) + +After executing the command, the prediction results (classification angle and score) of the above image will be printed on the screen. + +Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999963] + + -## TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION +## TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION ### 1. LIGHTWEIGHT CHINESE MODEL -When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, and the parameter `rec_model_dir` specifies the path to identify the inference model. The visualized recognition results are saved to the `./inference_results` folder by default. +When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, the parameter `cls_model_dir` specifies the path to angle classification inference model and the parameter `rec_model_dir` specifies the path to identify the inference model. The parameter `use_angle_cls` is used to control whether to enable the angle classification model.The visualized recognition results are saved to the `./inference_results` folder by default. ``` -python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" +python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls true ``` After executing the command, the recognition result image is as follows: diff --git a/ppocr/data/cls/__init__.py b/ppocr/data/cls/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54 --- /dev/null +++ b/ppocr/data/cls/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ppocr/data/cls/dataset_traversal.py b/ppocr/data/cls/dataset_traversal.py new file mode 100755 index 0000000000000000000000000000000000000000..01f8c89c839f0c8f6d07ca6ad9676947ce25f6ab --- /dev/null +++ b/ppocr/data/cls/dataset_traversal.py @@ -0,0 +1,144 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import math +import random +import numpy as np +import cv2 + +from ppocr.utils.utility import initial_logger +from ppocr.utils.utility import get_image_file_list + +logger = initial_logger() + +from ppocr.data.rec.img_tools import resize_norm_img, warp +from ppocr.data.cls.randaugment import RandAugment + + +def random_crop(img): + img_h, img_w = img.shape[:2] + if img_w > img_h * 4: + w = random.randint(img_h * 2, img_w) + i = random.randint(0, img_w - w) + + img = img[:, i:i + w, :] + return img + + +class SimpleReader(object): + def __init__(self, params): + if params['mode'] != 'train': + self.num_workers = 1 + else: + self.num_workers = params['num_workers'] + if params['mode'] != 'test': + self.img_set_dir = params['img_set_dir'] + self.label_file_path = params['label_file_path'] + self.use_gpu = params['use_gpu'] + self.image_shape = params['image_shape'] + self.mode = params['mode'] + self.infer_img = params['infer_img'] + self.use_distort = params['mode'] == 'train' and params['distort'] + self.randaug = RandAugment() + self.label_list = params['label_list'] + if "distort" in params: + self.use_distort = params['distort'] and params['use_gpu'] + if not params['use_gpu']: + logger.info( + "Distort operation can only support in GPU.Distort will be set to False." 
+ ) + if params['mode'] == 'train': + self.batch_size = params['train_batch_size_per_card'] + self.drop_last = True + else: + self.batch_size = params['test_batch_size_per_card'] + self.drop_last = False + self.use_distort = False + + def __call__(self, process_id): + if self.mode != 'train': + process_id = 0 + + def get_device_num(): + if self.use_gpu: + gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1) + gpu_num = len(gpus.split(',')) + return gpu_num + else: + cpu_num = os.environ.get("CPU_NUM", 1) + return int(cpu_num) + + def sample_iter_reader(): + if self.mode != 'train' and self.infer_img is not None: + image_file_list = get_image_file_list(self.infer_img) + for single_img in image_file_list: + img = cv2.imread(single_img) + if img.shape[-1] == 1 or len(list(img.shape)) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + norm_img = resize_norm_img(img, self.image_shape) + + norm_img = norm_img[np.newaxis, :] + yield norm_img + else: + with open(self.label_file_path, "rb") as fin: + label_infor_list = fin.readlines() + img_num = len(label_infor_list) + img_id_list = list(range(img_num)) + random.shuffle(img_id_list) + if sys.platform == "win32" and self.num_workers != 1: + print("multiprocess is not fully compatible with Windows." + "num_workers will be 1.") + self.num_workers = 1 + if self.batch_size * get_device_num( + ) * self.num_workers > img_num: + raise Exception( + "The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})". + format(img_num, self.batch_size * get_device_num() * + self.num_workers)) + for img_id in range(process_id, img_num, self.num_workers): + label_infor = label_infor_list[img_id_list[img_id]] + substr = label_infor.decode('utf-8').strip("\n").split("\t") + label = self.label_list.index(substr[1]) + + img_path = self.img_set_dir + "/" + substr[0] + img = cv2.imread(img_path) + if img is None: + logger.info("{} does not exist!".format(img_path)) + continue + if img.shape[-1] == 1 or len(list(img.shape)) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if self.use_distort: + img = warp(img, 10) + img = self.randaug(img) + norm_img = resize_norm_img(img, self.image_shape) + norm_img = norm_img[np.newaxis, :] + yield (norm_img, label) + + def batch_iter_reader(): + batch_outs = [] + for outs in sample_iter_reader(): + batch_outs.append(outs) + if len(batch_outs) == self.batch_size: + yield batch_outs + batch_outs = [] + if not self.drop_last: + if len(batch_outs) != 0: + yield batch_outs + + if self.infer_img is None: + return batch_iter_reader + return sample_iter_reader diff --git a/ppocr/data/cls/randaugment.py b/ppocr/data/cls/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..21345c05be59f6d1c9ae5a8d396ffed2dd9b0ca1 --- /dev/null +++ b/ppocr/data/cls/randaugment.py @@ -0,0 +1,135 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from PIL import Image, ImageEnhance, ImageOps +import numpy as np +import random +import six + + +class RawRandAugment(object): + def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)): + self.num_layers = num_layers + self.magnitude = magnitude + self.max_level = 10 + + abso_level = self.magnitude / self.max_level + self.level_map = { + "shearX": 0.3 * abso_level, + "shearY": 0.3 * abso_level, + "translateX": 150.0 / 331 * abso_level, + "translateY": 150.0 / 331 * abso_level, + "rotate": 30 * abso_level, + "color": 0.9 * abso_level, + "posterize": int(4.0 * abso_level), + "solarize": 256.0 * abso_level, + "contrast": 0.9 * abso_level, + "sharpness": 0.9 * abso_level, + "brightness": 0.9 * abso_level, + "autocontrast": 0, + "equalize": 0, + "invert": 0 + } + + # from https://stackoverflow.com/questions/5252170/ + # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, + Image.new("RGBA", rot.size, (128, ) * 4), + rot).convert(img.mode) + + rnd_ch_op = random.choice + + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0), + Image.BICUBIC, + fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0), + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])), + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "posterize": lambda img, magnitude: + ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: + ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: + ImageEnhance.Contrast(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "sharpness": lambda img, magnitude: + ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "brightness": lambda img, magnitude: + ImageEnhance.Brightness(img).enhance( + 1 + magnitude * rnd_ch_op([-1, 1])), + "autocontrast": lambda img, magnitude: + ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: ImageOps.equalize(img), + "invert": lambda img, magnitude: ImageOps.invert(img) + } + + def __call__(self, img): + avaiable_op_names = list(self.level_map.keys()) + for layer_num in range(self.num_layers): + op_name = np.random.choice(avaiable_op_names) + img = self.func[op_name](img, self.level_map[op_name]) + return img + + +class RandAugment(RawRandAugment): + """ RandAugment wrapper to auto fit different img types """ + + def __init__(self, *args, **kwargs): + if six.PY2: + super(RandAugment, self).__init__(*args, **kwargs) + else: + super().__init__(*args, **kwargs) + + def __call__(self, img): + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + + if six.PY2: + img = super(RandAugment, self).__call__(img) + 
else: + img = super().__call__(img) + + if isinstance(img, Image.Image): + img = np.asarray(img) + + return img diff --git a/ppocr/modeling/architectures/cls_model.py b/ppocr/modeling/architectures/cls_model.py new file mode 100755 index 0000000000000000000000000000000000000000..ad3ad0e7cf4010a14c70a700ed02d02ee1f1323b --- /dev/null +++ b/ppocr/modeling/architectures/cls_model.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid + +from ppocr.utils.utility import create_module +from ppocr.utils.utility import initial_logger + +logger = initial_logger() +from copy import deepcopy + + +class ClsModel(object): + def __init__(self, params): + super(ClsModel, self).__init__() + global_params = params['Global'] + self.infer_img = global_params['infer_img'] + + backbone_params = deepcopy(params["Backbone"]) + backbone_params.update(global_params) + self.backbone = create_module(backbone_params['function']) \ + (params=backbone_params) + + head_params = deepcopy(params["Head"]) + head_params.update(global_params) + self.head = create_module(head_params['function']) \ + (params=head_params) + + loss_params = deepcopy(params["Loss"]) + loss_params.update(global_params) + self.loss = create_module(loss_params['function']) \ + (params=loss_params) + + self.image_shape = global_params['image_shape'] + + def create_feed(self, mode): + image_shape = deepcopy(self.image_shape) + image_shape.insert(0, -1) + if mode == "train": + image = fluid.data(name='image', shape=image_shape, dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + feed_list = [image, label] + labels = {'label': label} + loader = fluid.io.DataLoader.from_generator( + feed_list=feed_list, + capacity=64, + use_double_buffer=True, + iterable=False) + else: + labels = None + loader = None + image = fluid.data(name='image', shape=image_shape, dtype='float32') + return image, labels, loader + + def __call__(self, mode): + image, labels, loader = self.create_feed(mode) + inputs = image + conv_feas = self.backbone(inputs) + predicts = self.head(conv_feas, labels, mode) + if mode == "train": + loss = self.loss(predicts, labels) + label = labels['label'] + acc = fluid.layers.accuracy(predicts['predict'], label, k=1) + outputs = {'total_loss': loss, 'decoded_out': \ + predicts['decoded_out'], 'label': label, 'acc': acc} + return loader, outputs + elif mode == "export": + return [image, predicts] + else: + return loader, predicts diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py new file mode 100644 index 0000000000000000000000000000000000000000..4567adcbaef3ddb14f3b46b27f772e3836db6793 --- /dev/null +++ b/ppocr/modeling/heads/cls_head.py @@ -0,0 +1,46 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 
+# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid + + +class ClsHead(object): + def __init__(self, params): + super(ClsHead, self).__init__() + self.class_dim = params['class_dim'] + + def __call__(self, inputs, labels=None, mode=None): + pool = fluid.layers.pool2d( + input=inputs, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + + out = fluid.layers.fc( + input=pool, + size=self.class_dim, + param_attr=fluid.param_attr.ParamAttr( + name="fc_0.w_0", + initializer=fluid.initializer.Uniform(-stdv, stdv)), + bias_attr=fluid.param_attr.ParamAttr(name="fc_0.b_0")) + + softmax_out = fluid.layers.softmax(out, use_cudnn=False) + out_label = fluid.layers.argmax(out, axis=1) + predicts = {'predict': softmax_out, 'decoded_out': out_label} + return predicts diff --git a/ppocr/modeling/losses/cls_loss.py b/ppocr/modeling/losses/cls_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..c187dce3618feef7503b01704a72a040145816a2 --- /dev/null +++ b/ppocr/modeling/losses/cls_loss.py @@ -0,0 +1,33 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
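ClsHead above is a global-average-pooling, fully-connected, softmax head, with `decoded_out` being the argmax label. The same computation in NumPy, as a shape walk-through; the random weights stand in for `fc_0.w_0`/`fc_0.b_0`, and the channel count 200 is arbitrary:

```python
import numpy as np

def cls_head_forward(feats, num_classes=2, seed=0):
    # feats: backbone output of shape [N, C, H, W]
    pooled = feats.mean(axis=(2, 3))                     # global average pooling -> [N, C]
    stdv = 1.0 / np.sqrt(pooled.shape[1])                # matches the Uniform(-stdv, stdv) init
    rng = np.random.default_rng(seed)
    w = rng.uniform(-stdv, stdv, (pooled.shape[1], num_classes))
    logits = pooled @ w                                  # fc layer (bias omitted for brevity)
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)             # softmax -> 'predict'
    return probs, logits.argmax(axis=1)                  # 'predict', 'decoded_out'

probs, labels = cls_head_forward(np.random.rand(4, 200, 1, 6))
```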
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid as fluid + + +class ClsLoss(object): + def __init__(self, params): + super(ClsLoss, self).__init__() + self.loss_func = fluid.layers.cross_entropy + + def __call__(self, predicts, labels): + predict = predicts['predict'] + label = labels['label'] + # softmax_out = fluid.layers.softmax(predict, use_cudnn=False) + cost = fluid.layers.cross_entropy(input=predict, label=label) + sum_cost = fluid.layers.mean(cost) + return sum_cost diff --git a/tools/eval.py b/tools/eval.py index 22185911db073cd096c9590781609f03feea4fdb..aff5fc7111a062c9b4346e9c2dcbc8f9225fe8da 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import test_rec_benchmark from eval_utils.eval_rec_utils import eval_rec_run +from eval_utils.eval_cls_utils import eval_cls_run def main(): - startup_prog, eval_program, place, config, train_alg_type = program.preprocess() + startup_prog, eval_program, place, config, train_alg_type = program.preprocess( + ) eval_build_outputs = program.build( config, eval_program, startup_prog, mode='test') eval_fetch_name_list = eval_build_outputs[1] @@ -67,6 +69,14 @@ def main(): 'fetch_varname_list':eval_fetch_varname_list} metrics = eval_det_run(exe, config, eval_info_dict, "eval") logger.info("Eval result: {}".format(metrics)) + elif train_alg_type == 'cls': + eval_reader = reader_main(config=config, mode="eval") + eval_info_dict = {'program': eval_program, \ + 'reader': eval_reader, \ + 'fetch_name_list': eval_fetch_name_list, \ + 'fetch_varname_list': eval_fetch_varname_list} + metrics = eval_cls_run(exe, eval_info_dict) + logger.info("Eval result: {}".format(metrics)) else: reader_type = config['Global']['reader_yml'] if "benchmark" not in reader_type: diff --git a/tools/eval_utils/eval_cls_utils.py b/tools/eval_utils/eval_cls_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c9b26677dc57b70f7641c26e5ef57ce1d77f1af --- /dev/null +++ b/tools/eval_utils/eval_cls_utils.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +__all__ = ['eval_cls_run'] + +import logging + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def eval_cls_run(exe, eval_info_dict): + """ + Run evaluation program, return program outputs. 
+ """ + total_sample_num = 0 + total_acc_num = 0 + total_batch_num = 0 + + for data in eval_info_dict['reader'](): + img_num = len(data) + img_list = [] + label_list = [] + for ino in range(img_num): + img_list.append(data[ino][0]) + label_list.append(data[ino][1]) + + img_list = np.concatenate(img_list, axis=0) + outs = exe.run(eval_info_dict['program'], \ + feed={'image': img_list}, \ + fetch_list=eval_info_dict['fetch_varname_list'], \ + return_numpy=False) + softmax_outs = np.array(outs[1]) + if len(softmax_outs.shape) != 1: + softmax_outs = np.array(outs[0]) + acc, acc_num = cal_cls_acc(softmax_outs, label_list) + total_acc_num += acc_num + total_sample_num += len(label_list) + # logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc)) + total_batch_num += 1 + avg_acc = total_acc_num * 1.0 / total_sample_num + metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \ + "total_sample_num": total_sample_num} + return metrics + + +def cal_cls_acc(preds, labels): + acc_num = 0 + for pred, label in zip(preds, labels): + if pred == label: + acc_num += 1 + return acc_num / len(preds), acc_num diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py new file mode 100755 index 0000000000000000000000000000000000000000..f5e358e95e5b1c9a0134c473877f1e53047f09db --- /dev/null +++ b/tools/infer/predict_cls.py @@ -0,0 +1,145 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +import tools.infer.utility as utility +from ppocr.utils.utility import initial_logger + +logger = initial_logger() +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +import cv2 +import copy +import numpy as np +import math +import time +from paddle import fluid + + +class TextClassifier(object): + def __init__(self, args): + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, mode="cls") + self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] + self.cls_batch_num = args.rec_batch_num + self.label_list = args.label_list + self.use_zero_copy_run = args.use_zero_copy_run + + def resize_norm_img(self, img): + imgC, imgH, imgW = self.cls_image_shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if self.cls_image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_list = copy.deepcopy(img_list) + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + + cls_res = [['', 0.0]] * img_num + batch_num = self.cls_batch_num + predict_time = 0 + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + starttime = time.time() + + if self.use_zero_copy_run: + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.zero_copy_run() + else: + norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) + self.predictor.run([norm_img_batch]) + + prob_out = self.output_tensors[0].copy_to_cpu() + label_out = self.output_tensors[1].copy_to_cpu() + if len(label_out.shape) != 1: + prob_out, label_out = label_out, prob_out + + elapse = time.time() - starttime + predict_time += elapse + for rno in range(len(label_out)): + label_idx = label_out[rno] + score = prob_out[rno][label_idx] + label = self.label_list[label_idx] + cls_res[indices[beg_img_no + rno]] = [label, score] + if '180' in label and score > 0.9999: + img_list[indices[beg_img_no + rno]] = cv2.rotate( + img_list[indices[beg_img_no + rno]], 1) + return img_list, cls_res, predict_time + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + text_classifier = TextClassifier(args) + valid_image_file_list = [] + img_list = [] + for 
image_file in image_file_list[:10]: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + valid_image_file_list.append(image_file) + img_list.append(img) + try: + img_list, cls_res, predict_time = text_classifier(img_list) + except Exception as e: + print(e) + exit() + for ino in range(len(img_list)): + print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino])) + print("Total predict time for %d images:%.3f" % + (len(img_list), predict_time)) + + +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index ff5d53e94e8ac110d58f2fda9afeb575cd7f0971..3e6be234c68dcd82f0f9e844f3ad2859000cec88 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -13,16 +13,19 @@ # limitations under the License. import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) import tools.infer.utility as utility from ppocr.utils.utility import initial_logger + logger = initial_logger() import cv2 import tools.infer.predict_det as predict_det import tools.infer.predict_rec as predict_rec +import tools.infer.predict_cls as predict_cls import copy import numpy as np import math @@ -37,6 +40,9 @@ class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) + self.use_angle_cls = args.use_angle_cls + if self.use_angle_cls: + self.text_classifier = predict_cls.TextClassifier(args) def get_rotate_crop_image(self, img, points): ''' @@ -91,6 +97,11 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) + if self.use_angle_cls: + img_crop_list, angle_list, elapse = self.text_classifier( + img_crop_list) + print("cls num : {}, elapse : {}".format( + len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) @@ -110,8 +121,8 @@ def sorted_boxes(dt_boxes): _boxes = list(sorted_boxes) for i in range(num_boxes - 1): - if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): tmp = _boxes[i] _boxes[i] = _boxes[i + 1] _boxes[i + 1] = tmp diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 3e1f07b8a7127e64a994c34d296c945ad1cafd0a..92212afd5f3e16601939d0ca7882fb3b90c3a9ac 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -15,6 +15,7 @@ import argparse import os, sys from ppocr.utils.utility import initial_logger + logger = initial_logger() from paddle.fluid.core import PaddleTensor from paddle.fluid.core import AnalysisConfig @@ -31,34 +32,34 @@ def parse_args(): return v.lower() in ("true", "t", "1") parser = argparse.ArgumentParser() - #params for prediction engine + # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--gpu_mem", type=int, default=8000) - #params for text detector + # 
params for text detector parser.add_argument("--image_dir", type=str) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_max_side_len", type=float, default=960) - #DB parmas + # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) - #EAST parmas + # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - #SAST parmas + # SAST parmas parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) parser.add_argument("--det_sast_polygon", type=bool, default=False) - #params for text recognizer + # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") @@ -70,14 +71,24 @@ def parse_args(): type=str, default="./ppocr/utils/ppocr_keys_v1.txt") parser.add_argument("--use_space_char", type=bool, default=True) - parser.add_argument("--enable_mkldnn", type=bool, default=False) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) + + # params for text classifier + parser.add_argument("--use_angle_cls", type=str2bool, default=False) + parser.add_argument("--cls_model_dir", type=str) + parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=['0', '180']) + parser.add_argument("--cls_batch_num", type=int, default=30) + + parser.add_argument("--enable_mkldnn", type=str2bool, default=False) + parser.add_argument("--use_zero_copy_run", type=str2bool, default=False) return parser.parse_args() def create_predictor(args, mode): if mode == "det": model_dir = args.det_model_dir + elif mode == 'cls': + model_dir = args.cls_model_dir else: model_dir = args.rec_model_dir @@ -105,7 +116,7 @@ def create_predictor(args, mode): config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() - #config.enable_memory_optim() + # config.enable_memory_optim() config.disable_glog_info() if args.use_zero_copy_run: diff --git a/tools/infer_cls.py b/tools/infer_cls.py new file mode 100755 index 0000000000000000000000000000000000000000..aebdc0761b7ec48f81143ecbb758ce0e4da2edf7 --- /dev/null +++ b/tools/infer_cls.py @@ -0,0 +1,114 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
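With the flags added to `parse_args` above, the classifier can be exercised on its own. Two things worth noting from the code: `TextClassifier` reads `args.rec_batch_num` rather than the new `--cls_batch_num` flag, and a crop is only rotated when its label contains '180' and the score exceeds a hard-coded 0.9999. A hedged usage sketch; the model and image paths are this example's:

```python
# Run the standalone classifier on one crop; assumes a model exported to ./inference/cls/.
import cv2
import tools.infer.utility as utility
import tools.infer.predict_cls as predict_cls

args = utility.parse_args()  # pass e.g. --cls_model_dir="./inference/cls/" on the command line
classifier = predict_cls.TextClassifier(args)
img_list, cls_res, elapse = classifier([cv2.imread("doc/imgs_words/ch/word_4.jpg")])
print(cls_res)               # [label, score] pairs, e.g. [['0', 0.9999963]]
```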
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) + + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + + +# NOTE(paddle-dev): All of these flags should be +# set before `import paddle`. Otherwise, it would +# not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +import tools.program as program +from paddle import fluid +from ppocr.utils.utility import initial_logger + +logger = initial_logger() +from ppocr.data.reader_main import reader_main +from ppocr.utils.save_load import init_model +from ppocr.utils.utility import create_module +from ppocr.utils.utility import get_image_file_list + + +def main(): + config = program.load_config(FLAGS.config) + program.merge_config(FLAGS.opt) + logger.info(config) + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + # check_gpu(use_gpu) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + rec_model = create_module(config['Architecture']['function'])(params=config) + startup_prog = fluid.Program() + eval_prog = fluid.Program() + with fluid.program_guard(eval_prog, startup_prog): + with fluid.unique_name.guard(): + _, outputs = rec_model(mode="test") + fetch_name_list = list(outputs.keys()) + fetch_varname_list = [outputs[v].name for v in fetch_name_list] + eval_prog = eval_prog.clone(for_test=True) + exe.run(startup_prog) + + init_model(config, eval_prog, exe) + + blobs = reader_main(config, 'test')() + infer_img = config['Global']['infer_img'] + infer_list = get_image_file_list(infer_img) + max_img_num = len(infer_list) + if len(infer_list) == 0: + logger.info("Can not find img in infer_img dir.") + for i in range(max_img_num): + logger.info("infer_img:%s" % infer_list[i]) + img = next(blobs) + predict = exe.run(program=eval_prog, + feed={"image": img}, + fetch_list=fetch_varname_list, + return_numpy=False) + scores = np.array(predict[0]) + label = np.array(predict[1]) + if len(label.shape) != 1: + label, scores = scores, label + logger.info('\t scores: {}'.format(scores)) + logger.info('\t label: {}'.format(label)) + # save for inference model + target_var = [] + for key, values in outputs.items(): + target_var.append(values) + + fluid.io.save_inference_model( + "./output", + feeded_var_names=['image'], + target_vars=target_var, + executor=exe, + main_program=eval_prog, + model_filename="model", + params_filename="params") + + +if __name__ == '__main__': + parser = program.ArgsParser() + FLAGS = parser.parse_args() + main() diff --git a/tools/program.py b/tools/program.py index be133ac2f0605abc39026587baaf884687e48911..72e479aa132e7e4a47ad87822693e6792f9ade53 100755 --- a/tools/program.py +++ b/tools/program.py @@ -30,6 +30,7 @@ import time from ppocr.utils.stats import TrainingStats from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import eval_rec_run +from eval_utils.eval_cls_utils import eval_cls_run from ppocr.utils.save_load import save_model import numpy as np from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps @@ -409,6 +410,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): return +def 
train_eval_cls_run(config, exe, train_info_dict, eval_info_dict): + train_batch_id = 0 + log_smooth_window = config['Global']['log_smooth_window'] + epoch_num = config['Global']['epoch_num'] + print_batch_step = config['Global']['print_batch_step'] + eval_batch_step = config['Global']['eval_batch_step'] + start_eval_step = 0 + if type(eval_batch_step) == list and len(eval_batch_step) >= 2: + start_eval_step = eval_batch_step[0] + eval_batch_step = eval_batch_step[1] + logger.info( + "During the training process, after the {}th iteration, an evaluation is run every {} iterations". + format(start_eval_step, eval_batch_step)) + save_epoch_step = config['Global']['save_epoch_step'] + save_model_dir = config['Global']['save_model_dir'] + if not os.path.exists(save_model_dir): + os.makedirs(save_model_dir) + train_stats = TrainingStats(log_smooth_window, ['loss', 'acc']) + best_eval_acc = -1 + best_batch_id = 0 + best_epoch = 0 + train_loader = train_info_dict['reader'] + for epoch in range(epoch_num): + train_loader.start() + try: + while True: + t1 = time.time() + train_outs = exe.run( + program=train_info_dict['compile_program'], + fetch_list=train_info_dict['fetch_varname_list'], + return_numpy=False) + fetch_map = dict( + zip(train_info_dict['fetch_name_list'], + range(len(train_outs)))) + + loss = np.mean(np.array(train_outs[fetch_map['total_loss']])) + lr = np.mean(np.array(train_outs[fetch_map['lr']])) + acc = np.mean(np.array(train_outs[fetch_map['acc']])) + + t2 = time.time() + train_batch_elapse = t2 - t1 + stats = {'loss': loss, 'acc': acc} + train_stats.update(stats) + if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \ + % print_batch_step == 0: + logs = train_stats.log() + strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format( + epoch, train_batch_id, lr, logs, train_batch_elapse) + logger.info(strs) + + if train_batch_id > 0 and\ + train_batch_id % eval_batch_step == 0: + model_average = train_info_dict['model_average'] + if model_average != None: + model_average.apply(exe) + metrics = eval_cls_run(exe, eval_info_dict) + eval_acc = metrics['avg_acc'] + eval_sample_num = metrics['total_sample_num'] + if eval_acc > best_eval_acc: + best_eval_acc = eval_acc + best_batch_id = train_batch_id + best_epoch = epoch + save_path = save_model_dir + "/best_accuracy" + save_model(train_info_dict['train_program'], save_path) + strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format( + train_batch_id, eval_acc, best_eval_acc, best_epoch, + best_batch_id, eval_sample_num) + logger.info(strs) + train_batch_id += 1 + + except fluid.core.EOFException: + train_loader.reset() + if epoch == 0 and save_epoch_step == 1: + save_path = save_model_dir + "/iter_epoch_0" + save_model(train_info_dict['train_program'], save_path) + if epoch > 0 and epoch % save_epoch_step == 0: + save_path = save_model_dir + "/iter_epoch_%d" % (epoch) + save_model(train_info_dict['train_program'], save_path) + return + + def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -421,7 +503,7 @@ def preprocess(): alg = config['Global']['algorithm'] assert alg in [ - 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN' + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' ] if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: config['Global']['char_ops'] = CharacterOps(config['Global']) @@ -432,7 +514,9 @@ def preprocess(): if alg in ['EAST', 'DB', 'SAST']: 
train_alg_type = 'det' - else: + elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: train_alg_type = 'rec' + else: + train_alg_type = 'cls' return startup_program, train_program, place, config, train_alg_type diff --git a/tools/train.py b/tools/train.py index 300705e0288c86ec1e9f2a00ef7311cd7c8e9897..531dd15933ebfd83527f091215c40b85253f7866 100755 --- a/tools/train.py +++ b/tools/train.py @@ -75,7 +75,8 @@ def main(): # dump mode structure if config['Global']['debug']: - if train_alg_type == 'rec' and 'attention' in config['Global']['loss_type']: + if train_alg_type == 'rec' and 'attention' in config['Global'][ + 'loss_type']: logger.warning('Does not suport dump attention...') else: summary(train_program) @@ -96,8 +97,10 @@ def main(): if train_alg_type == 'det': program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict) - else: + elif train_alg_type == 'rec': program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) + else: + program.train_eval_cls_run(config, exe, train_info_dict, eval_info_dict) def test_reader(): @@ -119,6 +122,7 @@ def test_reader(): if __name__ == '__main__': - startup_program, train_program, place, config, train_alg_type = program.preprocess() + startup_program, train_program, place, config, train_alg_type = program.preprocess( + ) main() # test_reader()
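After this patch, training, evaluation and inference all fan out on `train_alg_type`, with the new `CLS` algorithm falling through to the cls branch everywhere. The dispatch in `program.preprocess`, condensed for reference (illustrative only, not a drop-in replacement):

```python
DET_ALGS = ('EAST', 'DB', 'SAST')
REC_ALGS = ('Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN')

def dispatch(alg):
    # mirrors tools/program.py: det and rec keep their branches, CLS takes the new one
    if alg in DET_ALGS:
        return 'det'
    if alg in REC_ALGS:
        return 'rec'
    return 'cls'

assert dispatch('CLS') == 'cls'
```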