From 3cb7a6094c09bd0fff22b513aeaac3eb41d891f0 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Sat, 26 Feb 2022 23:02:15 +0800 Subject: [PATCH] split det cls rec mode --- .../android_demo/app/src/main/cpp/native.cpp | 29 ++- .../app/src/main/cpp/ocr_ppredictor.cpp | 225 ++++++++++++------ .../app/src/main/cpp/ocr_ppredictor.h | 26 +- .../paddle/lite/demo/ocr/MainActivity.java | 145 ++++++----- .../lite/demo/ocr/OCRPredictorNative.java | 35 +-- .../paddle/lite/demo/ocr/OcrResultModel.java | 27 +++ 6 files changed, 316 insertions(+), 171 deletions(-) diff --git a/deploy/android_demo/app/src/main/cpp/native.cpp b/deploy/android_demo/app/src/main/cpp/native.cpp index 963c5246..ced93255 100644 --- a/deploy/android_demo/app/src/main/cpp/native.cpp +++ b/deploy/android_demo/app/src/main/cpp/native.cpp @@ -13,7 +13,7 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode); extern "C" JNIEXPORT jlong JNICALL Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( JNIEnv *env, jobject thiz, jstring j_det_model_path, - jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num, + jstring j_rec_model_path, jstring j_cls_model_path, jint j_use_opencl, jint j_thread_num, jstring j_cpu_mode) { std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path); std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path); @@ -21,6 +21,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( int thread_num = j_thread_num; std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode); ppredictor::OCR_Config conf; + conf.use_opencl = j_use_opencl; conf.thread_num = thread_num; conf.mode = str_to_cpu_mode(cpu_mode); ppredictor::OCR_PPredictor *orc_predictor = @@ -57,32 +58,31 @@ str_to_cpu_mode(const std::string &cpu_mode) { extern "C" JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( - JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf, - jfloatArray ddims, jobject original_image) { + JNIEnv *env, jobject thiz, jlong java_pointer, jobject original_image,jint j_max_size_len, jint j_run_det, jint j_run_cls, jint j_run_rec) { LOGI("begin to run native forward"); if (java_pointer == 0) { LOGE("JAVA pointer is NULL"); return cpp_array_to_jfloatarray(env, nullptr, 0); } + cv::Mat origin = bitmap_to_cv_mat(env, original_image); if (origin.size == 0) { LOGE("origin bitmap cannot convert to CV Mat"); return cpp_array_to_jfloatarray(env, nullptr, 0); } + + int max_size_len = j_max_size_len; + int run_det = j_run_det; + int run_cls = j_run_cls; + int run_rec = j_run_rec; + ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *)java_pointer; - std::vector dims_float_arr = jfloatarray_to_float_vector(env, ddims); std::vector dims_arr; - dims_arr.resize(dims_float_arr.size()); - std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin()); - - // 这里值有点大,就不调用jfloatarray_to_float_vector了 - int64_t buf_len = (int64_t)env->GetArrayLength(buf); - jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE); - float *data = (jfloat *)buf_data; std::vector results = - ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin); + ppredictor->infer_ocr(origin, max_size_len, run_det, run_cls, run_rec); LOGI("infer_ocr finished with boxes %ld", results.size()); + // 这里将std::vector 序列化成 // float数组,传输到java层再反序列化 std::vector float_arr; @@ -90,13 +90,18 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( float_arr.push_back(r.points.size()); float_arr.push_back(r.word_index.size()); float_arr.push_back(r.score); + // add det point for (const std::vector &point : r.points) { float_arr.push_back(point.at(0)); float_arr.push_back(point.at(1)); } + // add rec word idx for (int index : r.word_index) { float_arr.push_back(index); } + // add cls result + float_arr.push_back(r.cls_label); + float_arr.push_back(r.cls_score); } return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size()); } diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp index c68456e1..1bd989c9 100644 --- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp +++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp @@ -17,15 +17,15 @@ int OCR_PPredictor::init(const std::string &det_model_content, const std::string &rec_model_content, const std::string &cls_model_content) { _det_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR, _config.mode}); _det_predictor->init_nb(det_model_content); _rec_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _rec_predictor->init_nb(rec_model_content); _cls_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _cls_predictor->init_nb(cls_model_content); return RETURN_OK; } @@ -34,15 +34,16 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std::string &rec_model_path, const std::string &cls_model_path) { _det_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR, _config.mode}); + new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR, _config.mode}); _det_predictor->init_from_file(det_model_path); + _rec_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _rec_predictor->init_from_file(rec_model_path); _cls_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _cls_predictor->init_from_file(cls_model_path); return RETURN_OK; } @@ -77,90 +78,173 @@ visual_img(const std::vector>> &filter_boxes, } std::vector -OCR_PPredictor::infer_ocr(const std::vector &dims, - const float *input_data, int input_len, int net_flag, - cv::Mat &origin) { +OCR_PPredictor::infer_ocr(cv::Mat &origin,int max_size_len, int run_det, int run_cls, int run_rec) { + LOGI("ocr cpp start *****************"); + LOGI("ocr cpp det: %d, cls: %d, rec: %d", run_det, run_cls, run_rec); + std::vector ocr_results; + if(run_det){ + infer_det(origin, max_size_len, ocr_results); + } + if(run_rec){ + if(ocr_results.size()==0){ + OCRPredictResult res; + ocr_results.emplace_back(std::move(res)); + } + for(int i = 0; i < ocr_results.size();i++) { + infer_rec(origin, run_cls, ocr_results[i]); + } + }else if(run_cls){ + ClsPredictResult cls_res = infer_cls(origin); + OCRPredictResult res; + res.cls_score = cls_res.cls_score; + res.cls_label = cls_res.cls_label; + ocr_results.push_back(res); + } + + LOGI("ocr cpp end *****************"); + return ocr_results; +} + +cv::Mat DetResizeImg(const cv::Mat img, int max_size_len, + std::vector &ratio_hw) { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + int max_wh = w >= h ? w : h; + if (max_wh > max_size_len) { + if (h > w) { + ratio = static_cast(max_size_len) / static_cast(h); + } else { + ratio = static_cast(max_size_len) / static_cast(w); + } + } + + int resize_h = static_cast(float(h) * ratio); + int resize_w = static_cast(float(w) * ratio); + if (resize_h % 32 == 0) + resize_h = resize_h; + else if (resize_h / 32 < 1 + 1e-5) + resize_h = 32; + else + resize_h = (resize_h / 32 - 1) * 32; + + if (resize_w % 32 == 0) + resize_w = resize_w; + else if (resize_w / 32 < 1 + 1e-5) + resize_w = 32; + else + resize_w = (resize_w / 32 - 1) * 32; + + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); + + ratio_hw.push_back(static_cast(resize_h) / static_cast(h)); + ratio_hw.push_back(static_cast(resize_w) / static_cast(w)); + return resize_img; +} + +void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len, std::vector &ocr_results) { + std::vector mean = {0.485f, 0.456f, 0.406f}; + std::vector scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + PredictorInput input = _det_predictor->get_first_input(); - input.set_dims(dims); - input.set_data(input_data, input_len); + + std::vector ratio_hw; + cv::Mat input_image = DetResizeImg(origin, max_size_len, ratio_hw); + input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); + const float *dimg = reinterpret_cast(input_image.data); + int input_size = input_image.rows * input_image.cols; + + input.set_dims({1, 3, input_image.rows, input_image.cols}); + + neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, + scale); + LOGI("ocr cpp det shape %d,%d", input_image.rows,input_image.cols); std::vector results = _det_predictor->infer(); PredictorOutput &res = results.at(0); std::vector>> filtered_box = calc_filtered_boxes( - res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin); - LOGI("Filter_box size %ld", filtered_box.size()); - return infer_rec(filtered_box, origin); + res.get_float_data(), res.get_size(), input_image.rows, input_image.cols, origin); + LOGI("ocr cpp det Filter_box size %ld", filtered_box.size()); + + for(int i = 0;i OCR_PPredictor::infer_rec( - const std::vector>> &boxes, - const cv::Mat &origin_img) { +void OCR_PPredictor::infer_rec(const cv::Mat &origin_img, int run_cls, OCRPredictResult& ocr_result) { std::vector mean = {0.5f, 0.5f, 0.5f}; std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; std::vector dims = {1, 3, 0, 0}; - std::vector ocr_results; PredictorInput input = _rec_predictor->get_first_input(); - for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) { - const std::vector> &box = *bp; - cv::Mat crop_img = get_rotate_crop_image(origin_img, box); - crop_img = infer_cls(crop_img); - float wh_ratio = float(crop_img.cols) / float(crop_img.rows); - cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio); - input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); - const float *dimg = reinterpret_cast(input_image.data); - int input_size = input_image.rows * input_image.cols; + const std::vector> &box = ocr_result.points; + cv::Mat crop_img; + if(box.size()>0){ + crop_img = get_rotate_crop_image(origin_img, box); + } + else{ + crop_img = origin_img; + } - dims[2] = input_image.rows; - dims[3] = input_image.cols; - input.set_dims(dims); + if(run_cls){ + ClsPredictResult cls_res = infer_cls(crop_img); + crop_img = cls_res.img; + ocr_result.cls_score = cls_res.cls_score; + ocr_result.cls_label = cls_res.cls_label; + } - neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, - scale); - std::vector results = _rec_predictor->infer(); - const float *predict_batch = results.at(0).get_float_data(); - const std::vector predict_shape = results.at(0).get_shape(); + float wh_ratio = float(crop_img.cols) / float(crop_img.rows); + cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio); + input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); + const float *dimg = reinterpret_cast(input_image.data); + int input_size = input_image.rows * input_image.cols; - OCRPredictResult res; + dims[2] = input_image.rows; + dims[3] = input_image.cols; + input.set_dims(dims); - // ctc decode - int argmax_idx; - int last_index = 0; - float score = 0.f; - int count = 0; - float max_value = 0.0f; - - for (int n = 0; n < predict_shape[1]; n++) { - argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]], - &predict_batch[(n + 1) * predict_shape[2]])); - max_value = - float(*std::max_element(&predict_batch[n * predict_shape[2]], - &predict_batch[(n + 1) * predict_shape[2]])); - if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { - score += max_value; - count += 1; - res.word_index.push_back(argmax_idx); - } - last_index = argmax_idx; - } - score /= count; - if (res.word_index.empty()) { - continue; + neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, + scale); + + std::vector results = _rec_predictor->infer(); + const float *predict_batch = results.at(0).get_float_data(); + const std::vector predict_shape = results.at(0).get_shape(); + + // ctc decode + int argmax_idx; + int last_index = 0; + float score = 0.f; + int count = 0; + float max_value = 0.0f; + + for (int n = 0; n < predict_shape[1]; n++) { + argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]], + &predict_batch[(n + 1) * predict_shape[2]])); + max_value = + float(*std::max_element(&predict_batch[n * predict_shape[2]], + &predict_batch[(n + 1) * predict_shape[2]])); + if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { + score += max_value; + count += 1; + ocr_result.word_index.push_back(argmax_idx); } - res.score = score; - res.points = box; - ocr_results.emplace_back(std::move(res)); + last_index = argmax_idx; } - LOGI("ocr_results finished %lu", ocr_results.size()); - return ocr_results; + score /= count; + ocr_result.score = score; + LOGI("ocr cpp rec word size %ld", count); } -cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { +ClsPredictResult OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { std::vector mean = {0.5f, 0.5f, 0.5f}; std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; std::vector dims = {1, 3, 0, 0}; - std::vector ocr_results; PredictorInput input = _cls_predictor->get_first_input(); @@ -182,7 +266,7 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { float score = 0; int label = 0; for (int64_t i = 0; i < results.at(0).get_size(); i++) { - LOGI("output scores [%f]", scores[i]); + LOGI("ocr cpp cls output scores [%f]", scores[i]); if (scores[i] > score) { score = scores[i]; label = i; @@ -193,7 +277,12 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { if (label % 2 == 1 && score > thresh) { cv::rotate(srcimg, srcimg, 1); } - return srcimg; + ClsPredictResult res; + res.cls_label = label; + res.cls_score = score; + res.img = srcimg; + LOGI("ocr cpp cls word cls %ld, %f", label, score); + return res; } std::vector>> diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h index 588f25cb..f0bff93f 100644 --- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h +++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h @@ -15,7 +15,8 @@ namespace ppredictor { * Config */ struct OCR_Config { - int thread_num = 4; // Thread num + int use_opencl = 0; + int thread_num = 4; // Thread num paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode }; @@ -27,8 +28,15 @@ struct OCRPredictResult { std::vector word_index; std::vector> points; float score; + float cls_score; + int cls_label=-1; }; +struct ClsPredictResult { + float cls_score; + int cls_label=-1; + cv::Mat img; +}; /** * OCR there are 2 models * 1. First model(det),select polygones to show where are the texts @@ -62,8 +70,7 @@ public: * @return */ virtual std::vector - infer_ocr(const std::vector &dims, const float *input_data, - int input_len, int net_flag, cv::Mat &origin); + infer_ocr(cv::Mat &origin, int max_size_len, int run_det, int run_cls, int run_rec); virtual NET_TYPE get_net_flag() const; @@ -80,25 +87,26 @@ private: calc_filtered_boxes(const float *pred, int pred_size, int output_height, int output_width, const cv::Mat &origin); + void + infer_det(cv::Mat &origin, int max_side_len, std::vector& ocr_results); /** - * infer for second model + * infer for rec model * * @param boxes * @param origin * @return */ - std::vector - infer_rec(const std::vector>> &boxes, - const cv::Mat &origin); + void + infer_rec(const cv::Mat &origin, int run_cls, OCRPredictResult& ocr_result); - /** + /** * infer for cls model * * @param boxes * @param origin * @return */ - cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.9); + ClsPredictResult infer_cls(const cv::Mat &origin, float thresh = 0.9); /** * Postprocess or sencod model to extract text diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java index 44fa7374..32a82c6f 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java @@ -13,6 +13,7 @@ import android.graphics.BitmapFactory; import android.graphics.drawable.BitmapDrawable; import android.media.ExifInterface; import android.content.res.AssetManager; +import android.media.FaceDetector; import android.net.Uri; import android.os.Bundle; import android.os.Environment; @@ -27,7 +28,9 @@ import android.view.Menu; import android.view.MenuInflater; import android.view.MenuItem; import android.view.View; +import android.widget.CheckBox; import android.widget.ImageView; +import android.widget.Spinner; import android.widget.TextView; import android.widget.Toast; @@ -68,23 +71,24 @@ public class MainActivity extends AppCompatActivity { protected ImageView ivInputImage; protected TextView tvOutputResult; protected TextView tvInferenceTime; + protected CheckBox cbOpencl; + protected Spinner spRunMode; - // Model settings of object detection + // Model settings of ocr protected String modelPath = ""; protected String labelPath = ""; protected String imagePath = ""; protected int cpuThreadNum = 1; protected String cpuPowerMode = ""; - protected String inputColorFormat = ""; - protected long[] inputShape = new long[]{}; - protected float[] inputMean = new float[]{}; - protected float[] inputStd = new float[]{}; + protected int detLongSize = 960; protected float scoreThreshold = 0.1f; private String currentPhotoPath; - private AssetManager assetManager =null; + private AssetManager assetManager = null; protected Predictor predictor = new Predictor(); + private Bitmap cur_predict_image = null; + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -98,10 +102,12 @@ public class MainActivity extends AppCompatActivity { // Setup the UI components tvInputSetting = findViewById(R.id.tv_input_setting); + cbOpencl = findViewById(R.id.cb_opencl); tvStatus = findViewById(R.id.tv_model_img_status); ivInputImage = findViewById(R.id.iv_input_image); tvInferenceTime = findViewById(R.id.tv_inference_time); tvOutputResult = findViewById(R.id.tv_output_result); + spRunMode = findViewById(R.id.sp_run_mode); tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance()); tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance()); @@ -111,26 +117,26 @@ public class MainActivity extends AppCompatActivity { public void handleMessage(Message msg) { switch (msg.what) { case RESPONSE_LOAD_MODEL_SUCCESSED: - if(pbLoadModel!=null && pbLoadModel.isShowing()){ + if (pbLoadModel != null && pbLoadModel.isShowing()) { pbLoadModel.dismiss(); } onLoadModelSuccessed(); break; case RESPONSE_LOAD_MODEL_FAILED: - if(pbLoadModel!=null && pbLoadModel.isShowing()){ + if (pbLoadModel != null && pbLoadModel.isShowing()) { pbLoadModel.dismiss(); } Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show(); onLoadModelFailed(); break; case RESPONSE_RUN_MODEL_SUCCESSED: - if(pbRunModel!=null && pbRunModel.isShowing()){ + if (pbRunModel != null && pbRunModel.isShowing()) { pbRunModel.dismiss(); } onRunModelSuccessed(); break; case RESPONSE_RUN_MODEL_FAILED: - if(pbRunModel!=null && pbRunModel.isShowing()){ + if (pbRunModel != null && pbRunModel.isShowing()) { pbRunModel.dismiss(); } Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show(); @@ -185,7 +191,6 @@ public class MainActivity extends AppCompatActivity { model_settingsChanged |= !model_path.equalsIgnoreCase(modelPath); settingsChanged |= !label_path.equalsIgnoreCase(labelPath); settingsChanged |= !image_path.equalsIgnoreCase(imagePath); - int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY), getString(R.string.CPU_THREAD_NUM_DEFAULT))); model_settingsChanged |= cpu_thread_num != cpuThreadNum; @@ -194,33 +199,9 @@ public class MainActivity extends AppCompatActivity { getString(R.string.CPU_POWER_MODE_DEFAULT)); model_settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode); - String input_color_format = - sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY), - getString(R.string.INPUT_COLOR_FORMAT_DEFAULT)); - settingsChanged |= !input_color_format.equalsIgnoreCase(inputColorFormat); - long[] input_shape = - Utils.parseLongsFromString(sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY), - getString(R.string.INPUT_SHAPE_DEFAULT)), ","); - float[] input_mean = - Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_MEAN_KEY), - getString(R.string.INPUT_MEAN_DEFAULT)), ","); - float[] input_std = - Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_STD_KEY) - , getString(R.string.INPUT_STD_DEFAULT)), ","); - settingsChanged |= input_shape.length != inputShape.length; - settingsChanged |= input_mean.length != inputMean.length; - settingsChanged |= input_std.length != inputStd.length; - if (!settingsChanged) { - for (int i = 0; i < input_shape.length; i++) { - settingsChanged |= input_shape[i] != inputShape[i]; - } - for (int i = 0; i < input_mean.length; i++) { - settingsChanged |= input_mean[i] != inputMean[i]; - } - for (int i = 0; i < input_std.length; i++) { - settingsChanged |= input_std[i] != inputStd[i]; - } - } + int det_long_size = Integer.parseInt(sharedPreferences.getString(getString(R.string.DET_LONG_SIZE_KEY), + getString(R.string.DET_LONG_SIZE_DEFAULT))); + settingsChanged |= det_long_size != detLongSize; float score_threshold = Float.parseFloat(sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY), getString(R.string.SCORE_THRESHOLD_DEFAULT))); @@ -228,20 +209,16 @@ public class MainActivity extends AppCompatActivity { if (settingsChanged) { labelPath = label_path; imagePath = image_path; - inputColorFormat = input_color_format; - inputShape = input_shape; - inputMean = input_mean; - inputStd = input_std; + detLongSize = det_long_size; scoreThreshold = score_threshold; set_img(); } - if (model_settingsChanged){ + if (model_settingsChanged) { modelPath = model_path; cpuThreadNum = cpu_thread_num; cpuPowerMode = cpu_power_mode; // Update UI - tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\n" + "CPU" + - " Thread Num: " + Integer.toString(cpuThreadNum) + "\n" + "CPU Power Mode: " + cpuPowerMode); + tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode); tvInputSetting.scrollTo(0, 0); // Reload model if configure has been changed loadModel(); @@ -259,20 +236,28 @@ public class MainActivity extends AppCompatActivity { } public boolean onLoadModel() { - return predictor.init(MainActivity.this, modelPath, labelPath, cpuThreadNum, + if (predictor.isLoaded()) { + predictor.releaseModel(); + } + return predictor.init(MainActivity.this, modelPath, labelPath, cbOpencl.isChecked() ? 1 : 0, cpuThreadNum, cpuPowerMode, - inputColorFormat, - inputShape, inputMean, - inputStd, scoreThreshold); + detLongSize, scoreThreshold); } public boolean onRunModel() { - return predictor.isLoaded() && predictor.runModel(); + String run_mode = spRunMode.getSelectedItem().toString(); + int run_det = run_mode.contains("检测") ? 1 : 0; + int run_cls = run_mode.contains("分类") ? 1 : 0; + int run_rec = run_mode.contains("识别") ? 1 : 0; + return predictor.isLoaded() && predictor.runModel(run_det, run_cls, run_rec); } public void onLoadModelSuccessed() { // Load test image from path and run model + tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode); + tvInputSetting.scrollTo(0, 0); tvStatus.setText("STATUS: load model successed"); + } public void onLoadModelFailed() { @@ -306,9 +291,9 @@ public class MainActivity extends AppCompatActivity { public void set_img() { // Load test image from path and run model try { - assetManager= getAssets(); - InputStream in=assetManager.open(imagePath); - Bitmap bmp=BitmapFactory.decodeStream(in); + assetManager = getAssets(); + InputStream in = assetManager.open(imagePath); + Bitmap bmp = BitmapFactory.decodeStream(in); ivInputImage.setImageBitmap(bmp); } catch (IOException e) { Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show(); @@ -469,28 +454,28 @@ public class MainActivity extends AppCompatActivity { } } - public void btn_load_model_click(View view) { - if (predictor.isLoaded()){ - tvStatus.setText("STATUS: model has been loaded"); - }else{ - tvStatus.setText("STATUS: load model ......"); - loadModel(); - } + public void btn_reset_img_click(View view) { + ivInputImage.setImageBitmap(cur_predict_image); + } + + public void cb_opencl_click(View view) { + tvStatus.setText("STATUS: load model ......"); + loadModel(); } public void btn_run_model_click(View view) { - Bitmap image =((BitmapDrawable)ivInputImage.getDrawable()).getBitmap(); - if(image == null) { + cur_predict_image = ((BitmapDrawable) ivInputImage.getDrawable()).getBitmap(); + if (cur_predict_image == null) { tvStatus.setText("STATUS: image is not exists"); - } - else if (!predictor.isLoaded()){ + } else if (!predictor.isLoaded()) { tvStatus.setText("STATUS: model is not loaded"); - }else{ + } else { tvStatus.setText("STATUS: run model ...... "); - predictor.setInputImage(image); + predictor.setInputImage(cur_predict_image); runModel(); } } + public void btn_choice_img_click(View view) { if (requestAllPermissions()) { openGallery(); @@ -511,4 +496,32 @@ public class MainActivity extends AppCompatActivity { worker.quit(); super.onDestroy(); } + + public int get_run_mode() { + String run_mode = spRunMode.getSelectedItem().toString(); + int mode; + switch (run_mode) { + case "检测+分类+识别": + mode = 1; + break; + case "检测+识别": + mode = 2; + break; + case "识别+分类": + mode = 3; + break; + case "检测": + mode = 4; + break; + case "识别": + mode = 5; + break; + case "分类": + mode = 6; + break; + default: + mode = 1; + } + return mode; + } } diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java index 1fa419e3..622da2a3 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java @@ -29,22 +29,22 @@ public class OCRPredictorNative { public OCRPredictorNative(Config config) { this.config = config; loadLibrary(); - nativePointer = init(config.detModelFilename, config.recModelFilename,config.clsModelFilename, + nativePointer = init(config.detModelFilename, config.recModelFilename, config.clsModelFilename, config.useOpencl, config.cpuThreadNum, config.cpuPower); Log.i("OCRPredictorNative", "load success " + nativePointer); } - public ArrayList runImage(float[] inputData, int width, int height, int channels, Bitmap originalImage) { - Log.i("OCRPredictorNative", "begin to run image " + inputData.length + " " + width + " " + height); - float[] dims = new float[]{1, channels, height, width}; - float[] rawResults = forward(nativePointer, inputData, dims, originalImage); + public ArrayList runImage(Bitmap originalImage, int max_size_len, int run_det, int run_cls, int run_rec) { + Log.i("OCRPredictorNative", "begin to run image "); + float[] rawResults = forward(nativePointer, originalImage, max_size_len, run_det, run_cls, run_rec); ArrayList results = postprocess(rawResults); return results; } public static class Config { + public int useOpencl; public int cpuThreadNum; public String cpuPower; public String detModelFilename; @@ -53,16 +53,16 @@ public class OCRPredictorNative { } - public void destory(){ + public void destory() { if (nativePointer > 0) { release(nativePointer); nativePointer = 0; } } - protected native long init(String detModelPath, String recModelPath,String clsModelPath, int threadNum, String cpuMode); + protected native long init(String detModelPath, String recModelPath, String clsModelPath, int useOpencl, int threadNum, String cpuMode); - protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage); + protected native float[] forward(long pointer, Bitmap originalImage,int max_size_len, int run_det, int run_cls, int run_rec); protected native void release(long pointer); @@ -73,9 +73,9 @@ public class OCRPredictorNative { while (begin < raw.length) { int point_num = Math.round(raw[begin]); int word_num = Math.round(raw[begin + 1]); - OcrResultModel model = parse(raw, begin + 2, point_num, word_num); - begin += 2 + 1 + point_num * 2 + word_num; - results.add(model); + OcrResultModel res = parse(raw, begin + 2, point_num, word_num); + begin += 2 + 1 + point_num * 2 + word_num + 2; + results.add(res); } return results; @@ -83,19 +83,22 @@ public class OCRPredictorNative { private OcrResultModel parse(float[] raw, int begin, int pointNum, int wordNum) { int current = begin; - OcrResultModel model = new OcrResultModel(); - model.setConfidence(raw[current]); + OcrResultModel res = new OcrResultModel(); + res.setConfidence(raw[current]); current++; for (int i = 0; i < pointNum; i++) { - model.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1])); + res.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1])); } current += (pointNum * 2); for (int i = 0; i < wordNum; i++) { int index = Math.round(raw[current + i]); - model.addWordIndex(index); + res.addWordIndex(index); } + current += wordNum; + res.setClsIdx(raw[current]); + res.setClsConfidence(raw[current + 1]); Log.i("OCRPredictorNative", "word finished " + wordNum); - return model; + return res; } diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java index 9494574e..1bccbc7d 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java @@ -10,6 +10,9 @@ public class OcrResultModel { private List wordIndex; private String label; private float confidence; + private float cls_idx; + private String cls_label; + private float cls_confidence; public OcrResultModel() { super(); @@ -49,4 +52,28 @@ public class OcrResultModel { public void setConfidence(float confidence) { this.confidence = confidence; } + + public float getClsIdx() { + return cls_idx; + } + + public void setClsIdx(float idx) { + this.cls_idx = idx; + } + + public String getClsLabel() { + return cls_label; + } + + public void setClsLabel(String label) { + this.cls_label = label; + } + + public float getClsConfidence() { + return cls_confidence; + } + + public void setClsConfidence(float confidence) { + this.cls_confidence = confidence; + } } -- GitLab