Commit 3cb7a609, authored by: W WenmuZhou

split det cls rec mode

Parent: cbbd8f79
@@ -13,7 +13,7 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
 extern "C" JNIEXPORT jlong JNICALL
 Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
     JNIEnv *env, jobject thiz, jstring j_det_model_path,
-    jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num,
+    jstring j_rec_model_path, jstring j_cls_model_path, jint j_use_opencl, jint j_thread_num,
     jstring j_cpu_mode) {
   std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path);
   std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path);
@@ -21,6 +21,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
   int thread_num = j_thread_num;
   std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode);
   ppredictor::OCR_Config conf;
+  conf.use_opencl = j_use_opencl;
   conf.thread_num = thread_num;
   conf.mode = str_to_cpu_mode(cpu_mode);
   ppredictor::OCR_PPredictor *orc_predictor =
@@ -57,32 +58,31 @@ str_to_cpu_mode(const std::string &cpu_mode) {
 extern "C" JNIEXPORT jfloatArray JNICALL
 Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
-    JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf,
-    jfloatArray ddims, jobject original_image) {
+    JNIEnv *env, jobject thiz, jlong java_pointer, jobject original_image, jint j_max_size_len, jint j_run_det, jint j_run_cls, jint j_run_rec) {
   LOGI("begin to run native forward");
   if (java_pointer == 0) {
     LOGE("JAVA pointer is NULL");
     return cpp_array_to_jfloatarray(env, nullptr, 0);
   }
   cv::Mat origin = bitmap_to_cv_mat(env, original_image);
   if (origin.size == 0) {
     LOGE("origin bitmap cannot convert to CV Mat");
     return cpp_array_to_jfloatarray(env, nullptr, 0);
   }
+  int max_size_len = j_max_size_len;
+  int run_det = j_run_det;
+  int run_cls = j_run_cls;
+  int run_rec = j_run_rec;
   ppredictor::OCR_PPredictor *ppredictor =
       (ppredictor::OCR_PPredictor *)java_pointer;
-  std::vector<float> dims_float_arr = jfloatarray_to_float_vector(env, ddims);
   std::vector<int64_t> dims_arr;
-  dims_arr.resize(dims_float_arr.size());
-  std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin());
-  // the buffer is fairly large, so skip jfloatarray_to_float_vector here
-  int64_t buf_len = (int64_t)env->GetArrayLength(buf);
-  jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE);
-  float *data = (jfloat *)buf_data;
   std::vector<ppredictor::OCRPredictResult> results =
-      ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin);
+      ppredictor->infer_ocr(origin, max_size_len, run_det, run_cls, run_rec);
   LOGI("infer_ocr finished with boxes %ld", results.size());
   // Serialize std::vector<ppredictor::OCRPredictResult> into a float array,
   // pass it to the Java layer, and deserialize it there
   std::vector<float> float_arr;
@@ -90,13 +90,18 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
     float_arr.push_back(r.points.size());
     float_arr.push_back(r.word_index.size());
     float_arr.push_back(r.score);
+    // add det points
     for (const std::vector<int> &point : r.points) {
       float_arr.push_back(point.at(0));
       float_arr.push_back(point.at(1));
     }
+    // add rec word indices
     for (int index : r.word_index) {
       float_arr.push_back(index);
     }
+    // add cls result
+    float_arr.push_back(r.cls_label);
+    float_arr.push_back(r.cls_score);
   }
   return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size());
 }
......
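A note on the section above: the forward JNI wrapper now serializes each OCRPredictResult as [point_count, word_count, score, point_count x (x, y), word_count x word_index, cls_label, cls_score]. The minimal Java sketch below decodes one such record; RecordView and decodeOne are illustrative names, not part of the commit, but the offsets follow the push_back order shown above and match the `+ 2` added to the record stride in OCRPredictorNative.postprocess further down.

    // Illustrative decoder for one serialized record (names are hypothetical).
    final class RecordView {
        int pointNum;
        int wordNum;
        float score;
        int[][] points;      // pointNum x 2, stored as (x, y)
        int[] wordIndexes;   // wordNum entries
        float clsLabel;      // appended by this commit
        float clsScore;      // appended by this commit
    }

    static RecordView decodeOne(float[] raw, int begin) {
        RecordView r = new RecordView();
        int cur = begin;
        r.pointNum = Math.round(raw[cur++]);
        r.wordNum = Math.round(raw[cur++]);
        r.score = raw[cur++];
        r.points = new int[r.pointNum][2];
        for (int i = 0; i < r.pointNum; i++) {
            r.points[i][0] = Math.round(raw[cur++]);
            r.points[i][1] = Math.round(raw[cur++]);
        }
        r.wordIndexes = new int[r.wordNum];
        for (int i = 0; i < r.wordNum; i++) {
            r.wordIndexes[i] = Math.round(raw[cur++]);
        }
        r.clsLabel = raw[cur++];   // cls label is written before cls score
        r.clsScore = raw[cur++];
        return r;                  // record length: 2 + 1 + pointNum * 2 + wordNum + 2
    }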
@@ -17,15 +17,15 @@ int OCR_PPredictor::init(const std::string &det_model_content,
                          const std::string &rec_model_content,
                          const std::string &cls_model_content) {
   _det_predictor = std::unique_ptr<PPredictor>(
-      new PPredictor{_config.thread_num, NET_OCR, _config.mode});
+      new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR, _config.mode});
   _det_predictor->init_nb(det_model_content);

   _rec_predictor = std::unique_ptr<PPredictor>(
-      new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+      new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
   _rec_predictor->init_nb(rec_model_content);

   _cls_predictor = std::unique_ptr<PPredictor>(
-      new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+      new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
   _cls_predictor->init_nb(cls_model_content);
   return RETURN_OK;
 }
@@ -34,15 +34,16 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path,
                                    const std::string &rec_model_path,
                                    const std::string &cls_model_path) {
   _det_predictor = std::unique_ptr<PPredictor>(
-      new PPredictor{_config.thread_num, NET_OCR, _config.mode});
+      new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR, _config.mode});
   _det_predictor->init_from_file(det_model_path);

   _rec_predictor = std::unique_ptr<PPredictor>(
-      new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+      new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
   _rec_predictor->init_from_file(rec_model_path);

   _cls_predictor = std::unique_ptr<PPredictor>(
-      new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
+      new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR_INTERNAL, _config.mode});
   _cls_predictor->init_from_file(cls_model_path);
   return RETURN_OK;
 }
@@ -77,90 +78,173 @@ visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
 }

 std::vector<OCRPredictResult>
-OCR_PPredictor::infer_ocr(const std::vector<int64_t> &dims,
-                          const float *input_data, int input_len, int net_flag,
-                          cv::Mat &origin) {
+OCR_PPredictor::infer_ocr(cv::Mat &origin, int max_size_len, int run_det, int run_cls, int run_rec) {
+  LOGI("ocr cpp start *****************");
+  LOGI("ocr cpp det: %d, cls: %d, rec: %d", run_det, run_cls, run_rec);
+  std::vector<OCRPredictResult> ocr_results;
+  if (run_det) {
+    infer_det(origin, max_size_len, ocr_results);
+  }
+  if (run_rec) {
+    if (ocr_results.size() == 0) {
+      OCRPredictResult res;
+      ocr_results.emplace_back(std::move(res));
+    }
+    for (int i = 0; i < ocr_results.size(); i++) {
+      infer_rec(origin, run_cls, ocr_results[i]);
+    }
+  } else if (run_cls) {
+    ClsPredictResult cls_res = infer_cls(origin);
+    OCRPredictResult res;
+    res.cls_score = cls_res.cls_score;
+    res.cls_label = cls_res.cls_label;
+    ocr_results.push_back(res);
+  }
+  LOGI("ocr cpp end *****************");
+  return ocr_results;
+}
+
+cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
+                     std::vector<float> &ratio_hw) {
+  int w = img.cols;
+  int h = img.rows;
+
+  float ratio = 1.f;
+  int max_wh = w >= h ? w : h;
+  if (max_wh > max_size_len) {
+    if (h > w) {
+      ratio = static_cast<float>(max_size_len) / static_cast<float>(h);
+    } else {
+      ratio = static_cast<float>(max_size_len) / static_cast<float>(w);
+    }
+  }
+
+  int resize_h = static_cast<int>(float(h) * ratio);
+  int resize_w = static_cast<int>(float(w) * ratio);
+  if (resize_h % 32 == 0)
+    resize_h = resize_h;
+  else if (resize_h / 32 < 1 + 1e-5)
+    resize_h = 32;
+  else
+    resize_h = (resize_h / 32 - 1) * 32;
+
+  if (resize_w % 32 == 0)
+    resize_w = resize_w;
+  else if (resize_w / 32 < 1 + 1e-5)
+    resize_w = 32;
+  else
+    resize_w = (resize_w / 32 - 1) * 32;
+
+  cv::Mat resize_img;
+  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
+
+  ratio_hw.push_back(static_cast<float>(resize_h) / static_cast<float>(h));
+  ratio_hw.push_back(static_cast<float>(resize_w) / static_cast<float>(w));
+  return resize_img;
+}
+
+void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len, std::vector<OCRPredictResult> &ocr_results) {
+  std::vector<float> mean = {0.485f, 0.456f, 0.406f};
+  std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
+
   PredictorInput input = _det_predictor->get_first_input();
-  input.set_dims(dims);
-  input.set_data(input_data, input_len);
+
+  std::vector<float> ratio_hw;
+  cv::Mat input_image = DetResizeImg(origin, max_size_len, ratio_hw);
+  input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
+  const float *dimg = reinterpret_cast<const float *>(input_image.data);
+  int input_size = input_image.rows * input_image.cols;
+
+  input.set_dims({1, 3, input_image.rows, input_image.cols});
+
+  neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
+                  scale);
+  LOGI("ocr cpp det shape %d,%d", input_image.rows, input_image.cols);
   std::vector<PredictorOutput> results = _det_predictor->infer();
   PredictorOutput &res = results.at(0);
   std::vector<std::vector<std::vector<int>>> filtered_box = calc_filtered_boxes(
-      res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin);
-  LOGI("Filter_box size %ld", filtered_box.size());
-  return infer_rec(filtered_box, origin);
+      res.get_float_data(), res.get_size(), input_image.rows, input_image.cols, origin);
+  LOGI("ocr cpp det Filter_box size %ld", filtered_box.size());
+
+  for (int i = 0; i < filtered_box.size(); i++) {
+    LOGI("ocr cpp box %d,%d,%d,%d,%d,%d,%d,%d", filtered_box[i][0][0], filtered_box[i][0][1], filtered_box[i][1][0], filtered_box[i][1][1], filtered_box[i][2][0], filtered_box[i][2][1], filtered_box[i][3][0], filtered_box[i][3][1]);
+    OCRPredictResult res;
+    res.points = filtered_box[i];
+    ocr_results.push_back(res);
+  }
 }

-std::vector<OCRPredictResult> OCR_PPredictor::infer_rec(
-    const std::vector<std::vector<std::vector<int>>> &boxes,
-    const cv::Mat &origin_img) {
+void OCR_PPredictor::infer_rec(const cv::Mat &origin_img, int run_cls, OCRPredictResult &ocr_result) {
   std::vector<float> mean = {0.5f, 0.5f, 0.5f};
   std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
   std::vector<int64_t> dims = {1, 3, 0, 0};
-  std::vector<OCRPredictResult> ocr_results;

   PredictorInput input = _rec_predictor->get_first_input();
-  for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) {
-    const std::vector<std::vector<int>> &box = *bp;
-    cv::Mat crop_img = get_rotate_crop_image(origin_img, box);
-    crop_img = infer_cls(crop_img);

-    float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
-    cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
-    input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
-    const float *dimg = reinterpret_cast<const float *>(input_image.data);
-    int input_size = input_image.rows * input_image.cols;
+  const std::vector<std::vector<int>> &box = ocr_result.points;
+  cv::Mat crop_img;
+  if (box.size() > 0) {
+    crop_img = get_rotate_crop_image(origin_img, box);
+  } else {
+    crop_img = origin_img;
+  }

-    dims[2] = input_image.rows;
-    dims[3] = input_image.cols;
-    input.set_dims(dims);
+  if (run_cls) {
+    ClsPredictResult cls_res = infer_cls(crop_img);
+    crop_img = cls_res.img;
+    ocr_result.cls_score = cls_res.cls_score;
+    ocr_result.cls_label = cls_res.cls_label;
+  }

-    neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
-                    scale);
+  float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
+  cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
+  input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
+  const float *dimg = reinterpret_cast<const float *>(input_image.data);
+  int input_size = input_image.rows * input_image.cols;

-    std::vector<PredictorOutput> results = _rec_predictor->infer();
-    const float *predict_batch = results.at(0).get_float_data();
-    const std::vector<int64_t> predict_shape = results.at(0).get_shape();
+  dims[2] = input_image.rows;
+  dims[3] = input_image.cols;
+  input.set_dims(dims);

-    OCRPredictResult res;
+  neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
+                  scale);

-    // ctc decode
-    int argmax_idx;
-    int last_index = 0;
-    float score = 0.f;
-    int count = 0;
-    float max_value = 0.0f;
+  std::vector<PredictorOutput> results = _rec_predictor->infer();
+  const float *predict_batch = results.at(0).get_float_data();
+  const std::vector<int64_t> predict_shape = results.at(0).get_shape();

-    for (int n = 0; n < predict_shape[1]; n++) {
-      argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]],
-                              &predict_batch[(n + 1) * predict_shape[2]]));
-      max_value =
-          float(*std::max_element(&predict_batch[n * predict_shape[2]],
-                                  &predict_batch[(n + 1) * predict_shape[2]]));
-      if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
-        score += max_value;
-        count += 1;
-        res.word_index.push_back(argmax_idx);
-      }
-      last_index = argmax_idx;
-    }
-    score /= count;
-    if (res.word_index.empty()) {
-      continue;
-    }
-    res.score = score;
-    res.points = box;
-    ocr_results.emplace_back(std::move(res));
-  }
-  LOGI("ocr_results finished %lu", ocr_results.size());
-  return ocr_results;
+  // ctc decode
+  int argmax_idx;
+  int last_index = 0;
+  float score = 0.f;
+  int count = 0;
+  float max_value = 0.0f;
+
+  for (int n = 0; n < predict_shape[1]; n++) {
+    argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]],
+                            &predict_batch[(n + 1) * predict_shape[2]]));
+    max_value =
+        float(*std::max_element(&predict_batch[n * predict_shape[2]],
                                 &predict_batch[(n + 1) * predict_shape[2]]));
+    if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
+      score += max_value;
+      count += 1;
+      ocr_result.word_index.push_back(argmax_idx);
+    }
+    last_index = argmax_idx;
+  }
+  score /= count;
+  ocr_result.score = score;
+  LOGI("ocr cpp rec word size %ld", count);
 }

-cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
+ClsPredictResult OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
   std::vector<float> mean = {0.5f, 0.5f, 0.5f};
   std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
   std::vector<int64_t> dims = {1, 3, 0, 0};
-  std::vector<OCRPredictResult> ocr_results;

   PredictorInput input = _cls_predictor->get_first_input();
@@ -182,7 +266,7 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
   float score = 0;
   int label = 0;
   for (int64_t i = 0; i < results.at(0).get_size(); i++) {
-    LOGI("output scores [%f]", scores[i]);
+    LOGI("ocr cpp cls output scores [%f]", scores[i]);
     if (scores[i] > score) {
       score = scores[i];
       label = i;
@@ -193,7 +277,12 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
   if (label % 2 == 1 && score > thresh) {
     cv::rotate(srcimg, srcimg, 1);
   }
-  return srcimg;
+  ClsPredictResult res;
+  res.cls_label = label;
+  res.cls_score = score;
+  res.img = srcimg;
+  LOGI("ocr cpp cls word cls %ld, %f", label, score);
+  return res;
 }

 std::vector<std::vector<std::vector<int>>>
......
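A note on DetResizeImg above: it caps the longer side of the input at max_size_len, snaps both sides to a multiple of 32 (minimum 32), and records the height/width ratios so detected boxes can be mapped back to the original image. Below is a standalone sketch of just the size computation, assuming the same rounding rules as the C++ code; detTargetSize and snapTo32 are hypothetical helper names, not part of the commit.

    // Mirrors the resize-target math in DetResizeImg (illustrative names).
    static int[] detTargetSize(int w, int h, int maxSizeLen) {
        float ratio = 1.f;
        int maxWh = Math.max(w, h);
        if (maxWh > maxSizeLen) {
            ratio = (float) maxSizeLen / (float) maxWh;   // shrink the longer side to maxSizeLen
        }
        int resizeH = (int) (h * ratio);
        int resizeW = (int) (w * ratio);
        return new int[]{snapTo32(resizeW), snapTo32(resizeH)};
    }

    static int snapTo32(int v) {
        if (v % 32 == 0) return v;
        if (v / 32 < 1 + 1e-5) return 32;   // anything below 64 becomes 32
        return (v / 32 - 1) * 32;           // otherwise round down one extra step of 32
    }

For example, a 1000x600 input with max_size_len = 960 is resized to 960x576, and ratio_hw then holds 576/600 and 960/1000.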
@@ -15,7 +15,8 @@ namespace ppredictor {
 * Config
 */
struct OCR_Config {
-  int thread_num = 4; // Thread num
+  int use_opencl = 0;
+  int thread_num = 4; // Thread num
  paddle::lite_api::PowerMode mode =
      paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
};
@@ -27,8 +28,15 @@ struct OCRPredictResult {
  std::vector<int> word_index;
  std::vector<std::vector<int>> points;
  float score;
+  float cls_score;
+  int cls_label = -1;
};

+struct ClsPredictResult {
+  float cls_score;
+  int cls_label = -1;
+  cv::Mat img;
+};
+
/**
 * OCR there are 2 models
 * 1. First model(det),select polygones to show where are the texts
@@ -62,8 +70,7 @@ public:
   * @return
   */
  virtual std::vector<OCRPredictResult>
-  infer_ocr(const std::vector<int64_t> &dims, const float *input_data,
-            int input_len, int net_flag, cv::Mat &origin);
+  infer_ocr(cv::Mat &origin, int max_size_len, int run_det, int run_cls, int run_rec);

  virtual NET_TYPE get_net_flag() const;
@@ -80,25 +87,26 @@ private:
  calc_filtered_boxes(const float *pred, int pred_size, int output_height,
                      int output_width, const cv::Mat &origin);

+  void
+  infer_det(cv::Mat &origin, int max_side_len, std::vector<OCRPredictResult>& ocr_results);
  /**
-   * infer for second model
+   * infer for rec model
   *
   * @param boxes
   * @param origin
   * @return
   */
-  std::vector<OCRPredictResult>
-  infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes,
-            const cv::Mat &origin);
+  void
+  infer_rec(const cv::Mat &origin, int run_cls, OCRPredictResult& ocr_result);

  /**
   * infer for cls model
   *
   * @param boxes
   * @param origin
   * @return
   */
-  cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.9);
+  ClsPredictResult infer_cls(const cv::Mat &origin, float thresh = 0.9);

  /**
   * Postprocess or sencod model to extract text
......
@@ -13,6 +13,7 @@ import android.graphics.BitmapFactory;
import android.graphics.drawable.BitmapDrawable;
import android.media.ExifInterface;
import android.content.res.AssetManager;
+import android.media.FaceDetector;
import android.net.Uri;
import android.os.Bundle;
import android.os.Environment;
@@ -27,7 +28,9 @@ import android.view.Menu;
import android.view.MenuInflater;
import android.view.MenuItem;
import android.view.View;
+import android.widget.CheckBox;
import android.widget.ImageView;
+import android.widget.Spinner;
import android.widget.TextView;
import android.widget.Toast;
@@ -68,23 +71,24 @@ public class MainActivity extends AppCompatActivity {
    protected ImageView ivInputImage;
    protected TextView tvOutputResult;
    protected TextView tvInferenceTime;
+    protected CheckBox cbOpencl;
+    protected Spinner spRunMode;

-    // Model settings of object detection
+    // Model settings of ocr
    protected String modelPath = "";
    protected String labelPath = "";
    protected String imagePath = "";
    protected int cpuThreadNum = 1;
    protected String cpuPowerMode = "";
-    protected String inputColorFormat = "";
-    protected long[] inputShape = new long[]{};
-    protected float[] inputMean = new float[]{};
-    protected float[] inputStd = new float[]{};
+    protected int detLongSize = 960;
    protected float scoreThreshold = 0.1f;
    private String currentPhotoPath;
-    private AssetManager assetManager =null;
+    private AssetManager assetManager = null;

    protected Predictor predictor = new Predictor();

+    private Bitmap cur_predict_image = null;
+
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
@@ -98,10 +102,12 @@ public class MainActivity extends AppCompatActivity {
        // Setup the UI components
        tvInputSetting = findViewById(R.id.tv_input_setting);
+        cbOpencl = findViewById(R.id.cb_opencl);
        tvStatus = findViewById(R.id.tv_model_img_status);
        ivInputImage = findViewById(R.id.iv_input_image);
        tvInferenceTime = findViewById(R.id.tv_inference_time);
        tvOutputResult = findViewById(R.id.tv_output_result);
+        spRunMode = findViewById(R.id.sp_run_mode);
        tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance());
        tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance());
@@ -111,26 +117,26 @@ public class MainActivity extends AppCompatActivity {
            public void handleMessage(Message msg) {
                switch (msg.what) {
                    case RESPONSE_LOAD_MODEL_SUCCESSED:
-                        if(pbLoadModel!=null && pbLoadModel.isShowing()){
+                        if (pbLoadModel != null && pbLoadModel.isShowing()) {
                            pbLoadModel.dismiss();
                        }
                        onLoadModelSuccessed();
                        break;
                    case RESPONSE_LOAD_MODEL_FAILED:
-                        if(pbLoadModel!=null && pbLoadModel.isShowing()){
+                        if (pbLoadModel != null && pbLoadModel.isShowing()) {
                            pbLoadModel.dismiss();
                        }
                        Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
                        onLoadModelFailed();
                        break;
                    case RESPONSE_RUN_MODEL_SUCCESSED:
-                        if(pbRunModel!=null && pbRunModel.isShowing()){
+                        if (pbRunModel != null && pbRunModel.isShowing()) {
                            pbRunModel.dismiss();
                        }
                        onRunModelSuccessed();
                        break;
                    case RESPONSE_RUN_MODEL_FAILED:
-                        if(pbRunModel!=null && pbRunModel.isShowing()){
+                        if (pbRunModel != null && pbRunModel.isShowing()) {
                            pbRunModel.dismiss();
                        }
                        Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
@@ -185,7 +191,6 @@ public class MainActivity extends AppCompatActivity {
        model_settingsChanged |= !model_path.equalsIgnoreCase(modelPath);
        settingsChanged |= !label_path.equalsIgnoreCase(labelPath);
        settingsChanged |= !image_path.equalsIgnoreCase(imagePath);
        int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY),
                getString(R.string.CPU_THREAD_NUM_DEFAULT)));
        model_settingsChanged |= cpu_thread_num != cpuThreadNum;
@@ -194,33 +199,9 @@ public class MainActivity extends AppCompatActivity {
                getString(R.string.CPU_POWER_MODE_DEFAULT));
        model_settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode);

-        String input_color_format =
-                sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
-                        getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
-        settingsChanged |= !input_color_format.equalsIgnoreCase(inputColorFormat);
-        long[] input_shape =
-                Utils.parseLongsFromString(sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY),
-                        getString(R.string.INPUT_SHAPE_DEFAULT)), ",");
-        float[] input_mean =
-                Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_MEAN_KEY),
-                        getString(R.string.INPUT_MEAN_DEFAULT)), ",");
-        float[] input_std =
-                Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_STD_KEY),
-                        getString(R.string.INPUT_STD_DEFAULT)), ",");
-        settingsChanged |= input_shape.length != inputShape.length;
-        settingsChanged |= input_mean.length != inputMean.length;
-        settingsChanged |= input_std.length != inputStd.length;
-        if (!settingsChanged) {
-            for (int i = 0; i < input_shape.length; i++) {
-                settingsChanged |= input_shape[i] != inputShape[i];
-            }
-            for (int i = 0; i < input_mean.length; i++) {
-                settingsChanged |= input_mean[i] != inputMean[i];
-            }
-            for (int i = 0; i < input_std.length; i++) {
-                settingsChanged |= input_std[i] != inputStd[i];
-            }
-        }
+        int det_long_size = Integer.parseInt(sharedPreferences.getString(getString(R.string.DET_LONG_SIZE_KEY),
+                getString(R.string.DET_LONG_SIZE_DEFAULT)));
+        settingsChanged |= det_long_size != detLongSize;
        float score_threshold =
                Float.parseFloat(sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY),
                        getString(R.string.SCORE_THRESHOLD_DEFAULT)));
@@ -228,20 +209,16 @@ public class MainActivity extends AppCompatActivity {
        if (settingsChanged) {
            labelPath = label_path;
            imagePath = image_path;
-            inputColorFormat = input_color_format;
-            inputShape = input_shape;
-            inputMean = input_mean;
-            inputStd = input_std;
+            detLongSize = det_long_size;
            scoreThreshold = score_threshold;
            set_img();
        }
-        if (model_settingsChanged){
+        if (model_settingsChanged) {
            modelPath = model_path;
            cpuThreadNum = cpu_thread_num;
            cpuPowerMode = cpu_power_mode;
            // Update UI
-            tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\n" + "CPU" +
-                    " Thread Num: " + Integer.toString(cpuThreadNum) + "\n" + "CPU Power Mode: " + cpuPowerMode);
+            tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode);
            tvInputSetting.scrollTo(0, 0);
            // Reload model if configure has been changed
            loadModel();
@@ -259,20 +236,28 @@ public class MainActivity extends AppCompatActivity {
    }

    public boolean onLoadModel() {
-        return predictor.init(MainActivity.this, modelPath, labelPath, cpuThreadNum,
+        if (predictor.isLoaded()) {
+            predictor.releaseModel();
+        }
+        return predictor.init(MainActivity.this, modelPath, labelPath, cbOpencl.isChecked() ? 1 : 0, cpuThreadNum,
                cpuPowerMode,
-                inputColorFormat,
-                inputShape, inputMean,
-                inputStd, scoreThreshold);
+                detLongSize, scoreThreshold);
    }

    public boolean onRunModel() {
-        return predictor.isLoaded() && predictor.runModel();
+        String run_mode = spRunMode.getSelectedItem().toString();
+        int run_det = run_mode.contains("检测") ? 1 : 0;
+        int run_cls = run_mode.contains("分类") ? 1 : 0;
+        int run_rec = run_mode.contains("识别") ? 1 : 0;
+        return predictor.isLoaded() && predictor.runModel(run_det, run_cls, run_rec);
    }

    public void onLoadModelSuccessed() {
        // Load test image from path and run model
+        tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode);
+        tvInputSetting.scrollTo(0, 0);
        tvStatus.setText("STATUS: load model successed");
    }

    public void onLoadModelFailed() {
@@ -306,9 +291,9 @@ public class MainActivity extends AppCompatActivity {
    public void set_img() {
        // Load test image from path and run model
        try {
-            assetManager= getAssets();
-            InputStream in=assetManager.open(imagePath);
-            Bitmap bmp=BitmapFactory.decodeStream(in);
+            assetManager = getAssets();
+            InputStream in = assetManager.open(imagePath);
+            Bitmap bmp = BitmapFactory.decodeStream(in);
            ivInputImage.setImageBitmap(bmp);
        } catch (IOException e) {
            Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
@@ -469,28 +454,28 @@ public class MainActivity extends AppCompatActivity {
        }
    }

-    public void btn_load_model_click(View view) {
-        if (predictor.isLoaded()){
-            tvStatus.setText("STATUS: model has been loaded");
-        }else{
-            tvStatus.setText("STATUS: load model ......");
-            loadModel();
-        }
+    public void btn_reset_img_click(View view) {
+        ivInputImage.setImageBitmap(cur_predict_image);
+    }
+
+    public void cb_opencl_click(View view) {
+        tvStatus.setText("STATUS: load model ......");
+        loadModel();
    }

    public void btn_run_model_click(View view) {
-        Bitmap image =((BitmapDrawable)ivInputImage.getDrawable()).getBitmap();
-        if(image == null) {
+        cur_predict_image = ((BitmapDrawable) ivInputImage.getDrawable()).getBitmap();
+        if (cur_predict_image == null) {
            tvStatus.setText("STATUS: image is not exists");
-        }
-        else if (!predictor.isLoaded()){
+        } else if (!predictor.isLoaded()) {
            tvStatus.setText("STATUS: model is not loaded");
-        }else{
+        } else {
            tvStatus.setText("STATUS: run model ...... ");
-            predictor.setInputImage(image);
+            predictor.setInputImage(cur_predict_image);
            runModel();
        }
    }

    public void btn_choice_img_click(View view) {
        if (requestAllPermissions()) {
            openGallery();
@@ -511,4 +496,32 @@ public class MainActivity extends AppCompatActivity {
        worker.quit();
        super.onDestroy();
    }
+
+    public int get_run_mode() {
+        String run_mode = spRunMode.getSelectedItem().toString();
+        int mode;
+        switch (run_mode) {
+            case "检测+分类+识别":
+                mode = 1;
+                break;
+            case "检测+识别":
+                mode = 2;
+                break;
+            case "识别+分类":
+                mode = 3;
+                break;
+            case "检测":
+                mode = 4;
+                break;
+            case "识别":
+                mode = 5;
+                break;
+            case "分类":
+                mode = 6;
+                break;
+            default:
+                mode = 1;
+        }
+        return mode;
+    }
}
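A note on the run-mode handling above: onRunModel derives the three flags by substring matching on the spinner label, where 检测 is detection, 分类 is angle classification, and 识别 is recognition. The sketch below restates that mapping as a pure helper (runFlags is an illustrative name, not part of the commit), with the six spinner options from get_run_mode listed for reference.

    // "检测+分类+识别" -> {1, 1, 1}    "检测+识别" -> {1, 0, 1}    "识别+分类" -> {0, 1, 1}
    // "检测"          -> {1, 0, 0}    "识别"      -> {0, 0, 1}    "分类"      -> {0, 1, 0}
    static int[] runFlags(String runMode) {
        int runDet = runMode.contains("检测") ? 1 : 0;  // detection
        int runCls = runMode.contains("分类") ? 1 : 0;  // angle classification
        int runRec = runMode.contains("识别") ? 1 : 0;  // recognition
        return new int[]{runDet, runCls, runRec};
    }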
@@ -29,22 +29,22 @@ public class OCRPredictorNative {
    public OCRPredictorNative(Config config) {
        this.config = config;
        loadLibrary();
-        nativePointer = init(config.detModelFilename, config.recModelFilename,config.clsModelFilename,
+        nativePointer = init(config.detModelFilename, config.recModelFilename, config.clsModelFilename, config.useOpencl,
                config.cpuThreadNum, config.cpuPower);
        Log.i("OCRPredictorNative", "load success " + nativePointer);
    }

-    public ArrayList<OcrResultModel> runImage(float[] inputData, int width, int height, int channels, Bitmap originalImage) {
-        Log.i("OCRPredictorNative", "begin to run image " + inputData.length + " " + width + " " + height);
-        float[] dims = new float[]{1, channels, height, width};
-        float[] rawResults = forward(nativePointer, inputData, dims, originalImage);
+    public ArrayList<OcrResultModel> runImage(Bitmap originalImage, int max_size_len, int run_det, int run_cls, int run_rec) {
+        Log.i("OCRPredictorNative", "begin to run image ");
+        float[] rawResults = forward(nativePointer, originalImage, max_size_len, run_det, run_cls, run_rec);
        ArrayList<OcrResultModel> results = postprocess(rawResults);
        return results;
    }

    public static class Config {
+        public int useOpencl;
        public int cpuThreadNum;
        public String cpuPower;
        public String detModelFilename;
@@ -53,16 +53,16 @@ public class OCRPredictorNative {
    }

-    public void destory(){
+    public void destory() {
        if (nativePointer > 0) {
            release(nativePointer);
            nativePointer = 0;
        }
    }

-    protected native long init(String detModelPath, String recModelPath,String clsModelPath, int threadNum, String cpuMode);
+    protected native long init(String detModelPath, String recModelPath, String clsModelPath, int useOpencl, int threadNum, String cpuMode);

-    protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage);
+    protected native float[] forward(long pointer, Bitmap originalImage, int max_size_len, int run_det, int run_cls, int run_rec);

    protected native void release(long pointer);
@@ -73,9 +73,9 @@ public class OCRPredictorNative {
        while (begin < raw.length) {
            int point_num = Math.round(raw[begin]);
            int word_num = Math.round(raw[begin + 1]);
-            OcrResultModel model = parse(raw, begin + 2, point_num, word_num);
-            begin += 2 + 1 + point_num * 2 + word_num;
-            results.add(model);
+            OcrResultModel res = parse(raw, begin + 2, point_num, word_num);
+            begin += 2 + 1 + point_num * 2 + word_num + 2;
+            results.add(res);
        }
        return results;
@@ -83,19 +83,22 @@ public class OCRPredictorNative {
    private OcrResultModel parse(float[] raw, int begin, int pointNum, int wordNum) {
        int current = begin;
-        OcrResultModel model = new OcrResultModel();
-        model.setConfidence(raw[current]);
+        OcrResultModel res = new OcrResultModel();
+        res.setConfidence(raw[current]);
        current++;
        for (int i = 0; i < pointNum; i++) {
-            model.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1]));
+            res.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1]));
        }
        current += (pointNum * 2);
        for (int i = 0; i < wordNum; i++) {
            int index = Math.round(raw[current + i]);
-            model.addWordIndex(index);
+            res.addWordIndex(index);
        }
+        current += wordNum;
+        res.setClsIdx(raw[current]);
+        res.setClsConfidence(raw[current + 1]);
        Log.i("OCRPredictorNative", "word finished " + wordNum);
-        return model;
+        return res;
    }
......
@@ -10,6 +10,9 @@ public class OcrResultModel {
    private List<Integer> wordIndex;
    private String label;
    private float confidence;
+    private float cls_idx;
+    private String cls_label;
+    private float cls_confidence;

    public OcrResultModel() {
        super();
@@ -49,4 +52,28 @@ public class OcrResultModel {
    public void setConfidence(float confidence) {
        this.confidence = confidence;
    }
+
+    public float getClsIdx() {
+        return cls_idx;
+    }
+
+    public void setClsIdx(float idx) {
+        this.cls_idx = idx;
+    }
+
+    public String getClsLabel() {
+        return cls_label;
+    }
+
+    public void setClsLabel(String label) {
+        this.cls_label = label;
+    }
+
+    public float getClsConfidence() {
+        return cls_confidence;
+    }
+
+    public void setClsConfidence(float confidence) {
+        this.cls_confidence = confidence;
+    }
}
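Taken together, the Java-side changes let a caller drive the split pipeline directly through OCRPredictorNative. A minimal usage sketch under assumed inputs — the model paths, power-mode string, and the `bitmap` variable are placeholders, not values from this commit; only fields and methods shown in the diff are used.

    // Hypothetical set-up for the split det/cls/rec pipeline.
    OCRPredictorNative.Config config = new OCRPredictorNative.Config();
    config.useOpencl = 0;                               // 1 requests the OpenCL path
    config.cpuThreadNum = 4;
    config.cpuPower = "LITE_POWER_HIGH";                // placeholder power-mode string
    config.detModelFilename = "/sdcard/models/det.nb";  // placeholder model paths
    config.recModelFilename = "/sdcard/models/rec.nb";
    config.clsModelFilename = "/sdcard/models/cls.nb";

    OCRPredictorNative predictor = new OCRPredictorNative(config);
    // max_size_len = 960, run_det = 1, run_cls = 1, run_rec = 1
    ArrayList<OcrResultModel> results = predictor.runImage(bitmap, 960, 1, 1, 1);
    for (OcrResultModel r : results) {
        Log.i("demo", "cls " + r.getClsIdx() + " (" + r.getClsConfidence() + ")");
    }
    predictor.destory();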