Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
3cb7a609
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3cb7a609
编写于
2月 26, 2022
作者:
W
WenmuZhou
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
split det cls rec mode
上级
cbbd8f79
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
316 addition
and
171 deletion
+316
-171
deploy/android_demo/app/src/main/cpp/native.cpp
deploy/android_demo/app/src/main/cpp/native.cpp
+17
-12
deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp
deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp
+157
-68
deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
+17
-9
deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java
...ain/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java
+79
-66
deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
...va/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
+19
-16
deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java
...n/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java
+27
-0
未找到文件。
deploy/android_demo/app/src/main/cpp/native.cpp
浏览文件 @
3cb7a609
...
@@ -13,7 +13,7 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
...
@@ -13,7 +13,7 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
extern
"C"
JNIEXPORT
jlong
JNICALL
extern
"C"
JNIEXPORT
jlong
JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init
(
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init
(
JNIEnv
*
env
,
jobject
thiz
,
jstring
j_det_model_path
,
JNIEnv
*
env
,
jobject
thiz
,
jstring
j_det_model_path
,
jstring
j_rec_model_path
,
jstring
j_cls_model_path
,
jint
j_thread_num
,
jstring
j_rec_model_path
,
jstring
j_cls_model_path
,
jint
j_
use_opencl
,
jint
j_
thread_num
,
jstring
j_cpu_mode
)
{
jstring
j_cpu_mode
)
{
std
::
string
det_model_path
=
jstring_to_cpp_string
(
env
,
j_det_model_path
);
std
::
string
det_model_path
=
jstring_to_cpp_string
(
env
,
j_det_model_path
);
std
::
string
rec_model_path
=
jstring_to_cpp_string
(
env
,
j_rec_model_path
);
std
::
string
rec_model_path
=
jstring_to_cpp_string
(
env
,
j_rec_model_path
);
...
@@ -21,6 +21,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
...
@@ -21,6 +21,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
int
thread_num
=
j_thread_num
;
int
thread_num
=
j_thread_num
;
std
::
string
cpu_mode
=
jstring_to_cpp_string
(
env
,
j_cpu_mode
);
std
::
string
cpu_mode
=
jstring_to_cpp_string
(
env
,
j_cpu_mode
);
ppredictor
::
OCR_Config
conf
;
ppredictor
::
OCR_Config
conf
;
conf
.
use_opencl
=
j_use_opencl
;
conf
.
thread_num
=
thread_num
;
conf
.
thread_num
=
thread_num
;
conf
.
mode
=
str_to_cpu_mode
(
cpu_mode
);
conf
.
mode
=
str_to_cpu_mode
(
cpu_mode
);
ppredictor
::
OCR_PPredictor
*
orc_predictor
=
ppredictor
::
OCR_PPredictor
*
orc_predictor
=
...
@@ -57,32 +58,31 @@ str_to_cpu_mode(const std::string &cpu_mode) {
...
@@ -57,32 +58,31 @@ str_to_cpu_mode(const std::string &cpu_mode) {
extern
"C"
JNIEXPORT
jfloatArray
JNICALL
extern
"C"
JNIEXPORT
jfloatArray
JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward
(
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward
(
JNIEnv
*
env
,
jobject
thiz
,
jlong
java_pointer
,
jfloatArray
buf
,
JNIEnv
*
env
,
jobject
thiz
,
jlong
java_pointer
,
jobject
original_image
,
jint
j_max_size_len
,
jint
j_run_det
,
jint
j_run_cls
,
jint
j_run_rec
)
{
jfloatArray
ddims
,
jobject
original_image
)
{
LOGI
(
"begin to run native forward"
);
LOGI
(
"begin to run native forward"
);
if
(
java_pointer
==
0
)
{
if
(
java_pointer
==
0
)
{
LOGE
(
"JAVA pointer is NULL"
);
LOGE
(
"JAVA pointer is NULL"
);
return
cpp_array_to_jfloatarray
(
env
,
nullptr
,
0
);
return
cpp_array_to_jfloatarray
(
env
,
nullptr
,
0
);
}
}
cv
::
Mat
origin
=
bitmap_to_cv_mat
(
env
,
original_image
);
cv
::
Mat
origin
=
bitmap_to_cv_mat
(
env
,
original_image
);
if
(
origin
.
size
==
0
)
{
if
(
origin
.
size
==
0
)
{
LOGE
(
"origin bitmap cannot convert to CV Mat"
);
LOGE
(
"origin bitmap cannot convert to CV Mat"
);
return
cpp_array_to_jfloatarray
(
env
,
nullptr
,
0
);
return
cpp_array_to_jfloatarray
(
env
,
nullptr
,
0
);
}
}
int
max_size_len
=
j_max_size_len
;
int
run_det
=
j_run_det
;
int
run_cls
=
j_run_cls
;
int
run_rec
=
j_run_rec
;
ppredictor
::
OCR_PPredictor
*
ppredictor
=
ppredictor
::
OCR_PPredictor
*
ppredictor
=
(
ppredictor
::
OCR_PPredictor
*
)
java_pointer
;
(
ppredictor
::
OCR_PPredictor
*
)
java_pointer
;
std
::
vector
<
float
>
dims_float_arr
=
jfloatarray_to_float_vector
(
env
,
ddims
);
std
::
vector
<
int64_t
>
dims_arr
;
std
::
vector
<
int64_t
>
dims_arr
;
dims_arr
.
resize
(
dims_float_arr
.
size
());
std
::
copy
(
dims_float_arr
.
cbegin
(),
dims_float_arr
.
cend
(),
dims_arr
.
begin
());
// 这里值有点大,就不调用jfloatarray_to_float_vector了
int64_t
buf_len
=
(
int64_t
)
env
->
GetArrayLength
(
buf
);
jfloat
*
buf_data
=
env
->
GetFloatArrayElements
(
buf
,
JNI_FALSE
);
float
*
data
=
(
jfloat
*
)
buf_data
;
std
::
vector
<
ppredictor
::
OCRPredictResult
>
results
=
std
::
vector
<
ppredictor
::
OCRPredictResult
>
results
=
ppredictor
->
infer_ocr
(
dims_arr
,
data
,
buf_len
,
NET_OCR
,
origin
);
ppredictor
->
infer_ocr
(
origin
,
max_size_len
,
run_det
,
run_cls
,
run_rec
);
LOGI
(
"infer_ocr finished with boxes %ld"
,
results
.
size
());
LOGI
(
"infer_ocr finished with boxes %ld"
,
results
.
size
());
// 这里将std::vector<ppredictor::OCRPredictResult> 序列化成
// 这里将std::vector<ppredictor::OCRPredictResult> 序列化成
// float数组,传输到java层再反序列化
// float数组,传输到java层再反序列化
std
::
vector
<
float
>
float_arr
;
std
::
vector
<
float
>
float_arr
;
...
@@ -90,13 +90,18 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
...
@@ -90,13 +90,18 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
float_arr
.
push_back
(
r
.
points
.
size
());
float_arr
.
push_back
(
r
.
points
.
size
());
float_arr
.
push_back
(
r
.
word_index
.
size
());
float_arr
.
push_back
(
r
.
word_index
.
size
());
float_arr
.
push_back
(
r
.
score
);
float_arr
.
push_back
(
r
.
score
);
// add det point
for
(
const
std
::
vector
<
int
>
&
point
:
r
.
points
)
{
for
(
const
std
::
vector
<
int
>
&
point
:
r
.
points
)
{
float_arr
.
push_back
(
point
.
at
(
0
));
float_arr
.
push_back
(
point
.
at
(
0
));
float_arr
.
push_back
(
point
.
at
(
1
));
float_arr
.
push_back
(
point
.
at
(
1
));
}
}
// add rec word idx
for
(
int
index
:
r
.
word_index
)
{
for
(
int
index
:
r
.
word_index
)
{
float_arr
.
push_back
(
index
);
float_arr
.
push_back
(
index
);
}
}
// add cls result
float_arr
.
push_back
(
r
.
cls_label
);
float_arr
.
push_back
(
r
.
cls_score
);
}
}
return
cpp_array_to_jfloatarray
(
env
,
float_arr
.
data
(),
float_arr
.
size
());
return
cpp_array_to_jfloatarray
(
env
,
float_arr
.
data
(),
float_arr
.
size
());
}
}
...
...
deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp
浏览文件 @
3cb7a609
...
@@ -17,15 +17,15 @@ int OCR_PPredictor::init(const std::string &det_model_content,
...
@@ -17,15 +17,15 @@ int OCR_PPredictor::init(const std::string &det_model_content,
const
std
::
string
&
rec_model_content
,
const
std
::
string
&
rec_model_content
,
const
std
::
string
&
cls_model_content
)
{
const
std
::
string
&
cls_model_content
)
{
_det_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
_det_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
new
PPredictor
{
_config
.
thread_num
,
NET_OCR
,
_config
.
mode
});
new
PPredictor
{
_config
.
use_opencl
,
_config
.
thread_num
,
NET_OCR
,
_config
.
mode
});
_det_predictor
->
init_nb
(
det_model_content
);
_det_predictor
->
init_nb
(
det_model_content
);
_rec_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
_rec_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
new
PPredictor
{
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
new
PPredictor
{
_config
.
use_opencl
,
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
_rec_predictor
->
init_nb
(
rec_model_content
);
_rec_predictor
->
init_nb
(
rec_model_content
);
_cls_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
_cls_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
new
PPredictor
{
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
new
PPredictor
{
_config
.
use_opencl
,
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
_cls_predictor
->
init_nb
(
cls_model_content
);
_cls_predictor
->
init_nb
(
cls_model_content
);
return
RETURN_OK
;
return
RETURN_OK
;
}
}
...
@@ -34,15 +34,16 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path,
...
@@ -34,15 +34,16 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path,
const
std
::
string
&
rec_model_path
,
const
std
::
string
&
rec_model_path
,
const
std
::
string
&
cls_model_path
)
{
const
std
::
string
&
cls_model_path
)
{
_det_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
_det_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
new
PPredictor
{
_config
.
thread_num
,
NET_OCR
,
_config
.
mode
});
new
PPredictor
{
_config
.
use_opencl
,
_config
.
thread_num
,
NET_OCR
,
_config
.
mode
});
_det_predictor
->
init_from_file
(
det_model_path
);
_det_predictor
->
init_from_file
(
det_model_path
);
_rec_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
_rec_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
new
PPredictor
{
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
new
PPredictor
{
_config
.
use_opencl
,
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
_rec_predictor
->
init_from_file
(
rec_model_path
);
_rec_predictor
->
init_from_file
(
rec_model_path
);
_cls_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
_cls_predictor
=
std
::
unique_ptr
<
PPredictor
>
(
new
PPredictor
{
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
new
PPredictor
{
_config
.
use_opencl
,
_config
.
thread_num
,
NET_OCR_INTERNAL
,
_config
.
mode
});
_cls_predictor
->
init_from_file
(
cls_model_path
);
_cls_predictor
->
init_from_file
(
cls_model_path
);
return
RETURN_OK
;
return
RETURN_OK
;
}
}
...
@@ -77,90 +78,173 @@ visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
...
@@ -77,90 +78,173 @@ visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
}
}
std
::
vector
<
OCRPredictResult
>
std
::
vector
<
OCRPredictResult
>
OCR_PPredictor
::
infer_ocr
(
const
std
::
vector
<
int64_t
>
&
dims
,
OCR_PPredictor
::
infer_ocr
(
cv
::
Mat
&
origin
,
int
max_size_len
,
int
run_det
,
int
run_cls
,
int
run_rec
)
{
const
float
*
input_data
,
int
input_len
,
int
net_flag
,
LOGI
(
"ocr cpp start *****************"
);
cv
::
Mat
&
origin
)
{
LOGI
(
"ocr cpp det: %d, cls: %d, rec: %d"
,
run_det
,
run_cls
,
run_rec
);
std
::
vector
<
OCRPredictResult
>
ocr_results
;
if
(
run_det
){
infer_det
(
origin
,
max_size_len
,
ocr_results
);
}
if
(
run_rec
){
if
(
ocr_results
.
size
()
==
0
){
OCRPredictResult
res
;
ocr_results
.
emplace_back
(
std
::
move
(
res
));
}
for
(
int
i
=
0
;
i
<
ocr_results
.
size
();
i
++
)
{
infer_rec
(
origin
,
run_cls
,
ocr_results
[
i
]);
}
}
else
if
(
run_cls
){
ClsPredictResult
cls_res
=
infer_cls
(
origin
);
OCRPredictResult
res
;
res
.
cls_score
=
cls_res
.
cls_score
;
res
.
cls_label
=
cls_res
.
cls_label
;
ocr_results
.
push_back
(
res
);
}
LOGI
(
"ocr cpp end *****************"
);
return
ocr_results
;
}
cv
::
Mat
DetResizeImg
(
const
cv
::
Mat
img
,
int
max_size_len
,
std
::
vector
<
float
>
&
ratio_hw
)
{
int
w
=
img
.
cols
;
int
h
=
img
.
rows
;
float
ratio
=
1.
f
;
int
max_wh
=
w
>=
h
?
w
:
h
;
if
(
max_wh
>
max_size_len
)
{
if
(
h
>
w
)
{
ratio
=
static_cast
<
float
>
(
max_size_len
)
/
static_cast
<
float
>
(
h
);
}
else
{
ratio
=
static_cast
<
float
>
(
max_size_len
)
/
static_cast
<
float
>
(
w
);
}
}
int
resize_h
=
static_cast
<
int
>
(
float
(
h
)
*
ratio
);
int
resize_w
=
static_cast
<
int
>
(
float
(
w
)
*
ratio
);
if
(
resize_h
%
32
==
0
)
resize_h
=
resize_h
;
else
if
(
resize_h
/
32
<
1
+
1e-5
)
resize_h
=
32
;
else
resize_h
=
(
resize_h
/
32
-
1
)
*
32
;
if
(
resize_w
%
32
==
0
)
resize_w
=
resize_w
;
else
if
(
resize_w
/
32
<
1
+
1e-5
)
resize_w
=
32
;
else
resize_w
=
(
resize_w
/
32
-
1
)
*
32
;
cv
::
Mat
resize_img
;
cv
::
resize
(
img
,
resize_img
,
cv
::
Size
(
resize_w
,
resize_h
));
ratio_hw
.
push_back
(
static_cast
<
float
>
(
resize_h
)
/
static_cast
<
float
>
(
h
));
ratio_hw
.
push_back
(
static_cast
<
float
>
(
resize_w
)
/
static_cast
<
float
>
(
w
));
return
resize_img
;
}
void
OCR_PPredictor
::
infer_det
(
cv
::
Mat
&
origin
,
int
max_size_len
,
std
::
vector
<
OCRPredictResult
>
&
ocr_results
)
{
std
::
vector
<
float
>
mean
=
{
0.485
f
,
0.456
f
,
0.406
f
};
std
::
vector
<
float
>
scale
=
{
1
/
0.229
f
,
1
/
0.224
f
,
1
/
0.225
f
};
PredictorInput
input
=
_det_predictor
->
get_first_input
();
PredictorInput
input
=
_det_predictor
->
get_first_input
();
input
.
set_dims
(
dims
);
input
.
set_data
(
input_data
,
input_len
);
std
::
vector
<
float
>
ratio_hw
;
cv
::
Mat
input_image
=
DetResizeImg
(
origin
,
max_size_len
,
ratio_hw
);
input_image
.
convertTo
(
input_image
,
CV_32FC3
,
1
/
255.0
f
);
const
float
*
dimg
=
reinterpret_cast
<
const
float
*>
(
input_image
.
data
);
int
input_size
=
input_image
.
rows
*
input_image
.
cols
;
input
.
set_dims
({
1
,
3
,
input_image
.
rows
,
input_image
.
cols
});
neon_mean_scale
(
dimg
,
input
.
get_mutable_float_data
(),
input_size
,
mean
,
scale
);
LOGI
(
"ocr cpp det shape %d,%d"
,
input_image
.
rows
,
input_image
.
cols
);
std
::
vector
<
PredictorOutput
>
results
=
_det_predictor
->
infer
();
std
::
vector
<
PredictorOutput
>
results
=
_det_predictor
->
infer
();
PredictorOutput
&
res
=
results
.
at
(
0
);
PredictorOutput
&
res
=
results
.
at
(
0
);
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
filtered_box
=
calc_filtered_boxes
(
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
filtered_box
=
calc_filtered_boxes
(
res
.
get_float_data
(),
res
.
get_size
(),
(
int
)
dims
[
2
],
(
int
)
dims
[
3
],
origin
);
res
.
get_float_data
(),
res
.
get_size
(),
input_image
.
rows
,
input_image
.
cols
,
origin
);
LOGI
(
"Filter_box size %ld"
,
filtered_box
.
size
());
LOGI
(
"ocr cpp det Filter_box size %ld"
,
filtered_box
.
size
());
return
infer_rec
(
filtered_box
,
origin
);
for
(
int
i
=
0
;
i
<
filtered_box
.
size
();
i
++
){
LOGI
(
"ocr cpp box %d,%d,%d,%d,%d,%d,%d,%d"
,
filtered_box
[
i
][
0
][
0
],
filtered_box
[
i
][
0
][
1
],
filtered_box
[
i
][
1
][
0
],
filtered_box
[
i
][
1
][
1
],
filtered_box
[
i
][
2
][
0
],
filtered_box
[
i
][
2
][
1
],
filtered_box
[
i
][
3
][
0
],
filtered_box
[
i
][
3
][
1
]);
OCRPredictResult
res
;
res
.
points
=
filtered_box
[
i
];
ocr_results
.
push_back
(
res
);
}
}
}
std
::
vector
<
OCRPredictResult
>
OCR_PPredictor
::
infer_rec
(
void
OCR_PPredictor
::
infer_rec
(
const
cv
::
Mat
&
origin_img
,
int
run_cls
,
OCRPredictResult
&
ocr_result
)
{
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
&
boxes
,
const
cv
::
Mat
&
origin_img
)
{
std
::
vector
<
float
>
mean
=
{
0.5
f
,
0.5
f
,
0.5
f
};
std
::
vector
<
float
>
mean
=
{
0.5
f
,
0.5
f
,
0.5
f
};
std
::
vector
<
float
>
scale
=
{
1
/
0.5
f
,
1
/
0.5
f
,
1
/
0.5
f
};
std
::
vector
<
float
>
scale
=
{
1
/
0.5
f
,
1
/
0.5
f
,
1
/
0.5
f
};
std
::
vector
<
int64_t
>
dims
=
{
1
,
3
,
0
,
0
};
std
::
vector
<
int64_t
>
dims
=
{
1
,
3
,
0
,
0
};
std
::
vector
<
OCRPredictResult
>
ocr_results
;
PredictorInput
input
=
_rec_predictor
->
get_first_input
();
PredictorInput
input
=
_rec_predictor
->
get_first_input
();
for
(
auto
bp
=
boxes
.
crbegin
();
bp
!=
boxes
.
crend
();
++
bp
)
{
const
std
::
vector
<
std
::
vector
<
int
>>
&
box
=
*
bp
;
cv
::
Mat
crop_img
=
get_rotate_crop_image
(
origin_img
,
box
);
crop_img
=
infer_cls
(
crop_img
);
float
wh_ratio
=
float
(
crop_img
.
cols
)
/
float
(
crop_img
.
rows
);
const
std
::
vector
<
std
::
vector
<
int
>>
&
box
=
ocr_result
.
points
;
cv
::
Mat
input_image
=
crnn_resize_img
(
crop_img
,
wh_ratio
);
cv
::
Mat
crop_img
;
input_image
.
convertTo
(
input_image
,
CV_32FC3
,
1
/
255.0
f
);
if
(
box
.
size
()
>
0
){
const
float
*
dimg
=
reinterpret_cast
<
const
float
*>
(
input_image
.
data
);
crop_img
=
get_rotate_crop_image
(
origin_img
,
box
);
int
input_size
=
input_image
.
rows
*
input_image
.
cols
;
}
else
{
crop_img
=
origin_img
;
}
dims
[
2
]
=
input_image
.
rows
;
if
(
run_cls
){
dims
[
3
]
=
input_image
.
cols
;
ClsPredictResult
cls_res
=
infer_cls
(
crop_img
);
input
.
set_dims
(
dims
);
crop_img
=
cls_res
.
img
;
ocr_result
.
cls_score
=
cls_res
.
cls_score
;
ocr_result
.
cls_label
=
cls_res
.
cls_label
;
}
neon_mean_scale
(
dimg
,
input
.
get_mutable_float_data
(),
input_size
,
mean
,
scale
);
std
::
vector
<
PredictorOutput
>
results
=
_rec_predictor
->
infer
();
float
wh_ratio
=
float
(
crop_img
.
cols
)
/
float
(
crop_img
.
rows
);
const
float
*
predict_batch
=
results
.
at
(
0
).
get_float_data
();
cv
::
Mat
input_image
=
crnn_resize_img
(
crop_img
,
wh_ratio
);
const
std
::
vector
<
int64_t
>
predict_shape
=
results
.
at
(
0
).
get_shape
();
input_image
.
convertTo
(
input_image
,
CV_32FC3
,
1
/
255.0
f
);
const
float
*
dimg
=
reinterpret_cast
<
const
float
*>
(
input_image
.
data
);
int
input_size
=
input_image
.
rows
*
input_image
.
cols
;
OCRPredictResult
res
;
dims
[
2
]
=
input_image
.
rows
;
dims
[
3
]
=
input_image
.
cols
;
input
.
set_dims
(
dims
);
// ctc decode
neon_mean_scale
(
dimg
,
input
.
get_mutable_float_data
(),
input_size
,
mean
,
int
argmax_idx
;
scale
);
int
last_index
=
0
;
float
score
=
0.
f
;
std
::
vector
<
PredictorOutput
>
results
=
_rec_predictor
->
infer
();
int
count
=
0
;
const
float
*
predict_batch
=
results
.
at
(
0
).
get_float_data
();
float
max_value
=
0.0
f
;
const
std
::
vector
<
int64_t
>
predict_shape
=
results
.
at
(
0
).
get_shape
();
for
(
int
n
=
0
;
n
<
predict_shape
[
1
];
n
++
)
{
// ctc decode
argmax_idx
=
int
(
argmax
(
&
predict_batch
[
n
*
predict_shape
[
2
]],
int
argmax_idx
;
&
predict_batch
[(
n
+
1
)
*
predict_shape
[
2
]]));
int
last_index
=
0
;
max_value
=
float
score
=
0.
f
;
float
(
*
std
::
max_element
(
&
predict_batch
[
n
*
predict_shape
[
2
]],
int
count
=
0
;
&
predict_batch
[(
n
+
1
)
*
predict_shape
[
2
]]));
float
max_value
=
0.0
f
;
if
(
argmax_idx
>
0
&&
(
!
(
n
>
0
&&
argmax_idx
==
last_index
)))
{
score
+=
max_value
;
for
(
int
n
=
0
;
n
<
predict_shape
[
1
];
n
++
)
{
count
+=
1
;
argmax_idx
=
int
(
argmax
(
&
predict_batch
[
n
*
predict_shape
[
2
]],
res
.
word_index
.
push_back
(
argmax_idx
);
&
predict_batch
[(
n
+
1
)
*
predict_shape
[
2
]]));
}
max_value
=
last_index
=
argmax_idx
;
float
(
*
std
::
max_element
(
&
predict_batch
[
n
*
predict_shape
[
2
]],
}
&
predict_batch
[(
n
+
1
)
*
predict_shape
[
2
]]));
score
/=
count
;
if
(
argmax_idx
>
0
&&
(
!
(
n
>
0
&&
argmax_idx
==
last_index
)))
{
if
(
res
.
word_index
.
empty
())
{
score
+=
max_value
;
continue
;
count
+=
1
;
ocr_result
.
word_index
.
push_back
(
argmax_idx
);
}
}
res
.
score
=
score
;
last_index
=
argmax_idx
;
res
.
points
=
box
;
ocr_results
.
emplace_back
(
std
::
move
(
res
));
}
}
LOGI
(
"ocr_results finished %lu"
,
ocr_results
.
size
());
score
/=
count
;
return
ocr_results
;
ocr_result
.
score
=
score
;
LOGI
(
"ocr cpp rec word size %ld"
,
count
);
}
}
cv
::
Ma
t
OCR_PPredictor
::
infer_cls
(
const
cv
::
Mat
&
img
,
float
thresh
)
{
ClsPredictResul
t
OCR_PPredictor
::
infer_cls
(
const
cv
::
Mat
&
img
,
float
thresh
)
{
std
::
vector
<
float
>
mean
=
{
0.5
f
,
0.5
f
,
0.5
f
};
std
::
vector
<
float
>
mean
=
{
0.5
f
,
0.5
f
,
0.5
f
};
std
::
vector
<
float
>
scale
=
{
1
/
0.5
f
,
1
/
0.5
f
,
1
/
0.5
f
};
std
::
vector
<
float
>
scale
=
{
1
/
0.5
f
,
1
/
0.5
f
,
1
/
0.5
f
};
std
::
vector
<
int64_t
>
dims
=
{
1
,
3
,
0
,
0
};
std
::
vector
<
int64_t
>
dims
=
{
1
,
3
,
0
,
0
};
std
::
vector
<
OCRPredictResult
>
ocr_results
;
PredictorInput
input
=
_cls_predictor
->
get_first_input
();
PredictorInput
input
=
_cls_predictor
->
get_first_input
();
...
@@ -182,7 +266,7 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
...
@@ -182,7 +266,7 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
float
score
=
0
;
float
score
=
0
;
int
label
=
0
;
int
label
=
0
;
for
(
int64_t
i
=
0
;
i
<
results
.
at
(
0
).
get_size
();
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
results
.
at
(
0
).
get_size
();
i
++
)
{
LOGI
(
"output scores [%f]"
,
scores
[
i
]);
LOGI
(
"o
cr cpp cls o
utput scores [%f]"
,
scores
[
i
]);
if
(
scores
[
i
]
>
score
)
{
if
(
scores
[
i
]
>
score
)
{
score
=
scores
[
i
];
score
=
scores
[
i
];
label
=
i
;
label
=
i
;
...
@@ -193,7 +277,12 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
...
@@ -193,7 +277,12 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
if
(
label
%
2
==
1
&&
score
>
thresh
)
{
if
(
label
%
2
==
1
&&
score
>
thresh
)
{
cv
::
rotate
(
srcimg
,
srcimg
,
1
);
cv
::
rotate
(
srcimg
,
srcimg
,
1
);
}
}
return
srcimg
;
ClsPredictResult
res
;
res
.
cls_label
=
label
;
res
.
cls_score
=
score
;
res
.
img
=
srcimg
;
LOGI
(
"ocr cpp cls word cls %ld, %f"
,
label
,
score
);
return
res
;
}
}
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
...
...
deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h
浏览文件 @
3cb7a609
...
@@ -15,7 +15,8 @@ namespace ppredictor {
...
@@ -15,7 +15,8 @@ namespace ppredictor {
* Config
* Config
*/
*/
struct
OCR_Config
{
struct
OCR_Config
{
int
thread_num
=
4
;
// Thread num
int
use_opencl
=
0
;
int
thread_num
=
4
;
// Thread num
paddle
::
lite_api
::
PowerMode
mode
=
paddle
::
lite_api
::
PowerMode
mode
=
paddle
::
lite_api
::
LITE_POWER_HIGH
;
// PaddleLite Mode
paddle
::
lite_api
::
LITE_POWER_HIGH
;
// PaddleLite Mode
};
};
...
@@ -27,8 +28,15 @@ struct OCRPredictResult {
...
@@ -27,8 +28,15 @@ struct OCRPredictResult {
std
::
vector
<
int
>
word_index
;
std
::
vector
<
int
>
word_index
;
std
::
vector
<
std
::
vector
<
int
>>
points
;
std
::
vector
<
std
::
vector
<
int
>>
points
;
float
score
;
float
score
;
float
cls_score
;
int
cls_label
=-
1
;
};
};
struct
ClsPredictResult
{
float
cls_score
;
int
cls_label
=-
1
;
cv
::
Mat
img
;
};
/**
/**
* OCR there are 2 models
* OCR there are 2 models
* 1. First model(det),select polygones to show where are the texts
* 1. First model(det),select polygones to show where are the texts
...
@@ -62,8 +70,7 @@ public:
...
@@ -62,8 +70,7 @@ public:
* @return
* @return
*/
*/
virtual
std
::
vector
<
OCRPredictResult
>
virtual
std
::
vector
<
OCRPredictResult
>
infer_ocr
(
const
std
::
vector
<
int64_t
>
&
dims
,
const
float
*
input_data
,
infer_ocr
(
cv
::
Mat
&
origin
,
int
max_size_len
,
int
run_det
,
int
run_cls
,
int
run_rec
);
int
input_len
,
int
net_flag
,
cv
::
Mat
&
origin
);
virtual
NET_TYPE
get_net_flag
()
const
;
virtual
NET_TYPE
get_net_flag
()
const
;
...
@@ -80,25 +87,26 @@ private:
...
@@ -80,25 +87,26 @@ private:
calc_filtered_boxes
(
const
float
*
pred
,
int
pred_size
,
int
output_height
,
calc_filtered_boxes
(
const
float
*
pred
,
int
pred_size
,
int
output_height
,
int
output_width
,
const
cv
::
Mat
&
origin
);
int
output_width
,
const
cv
::
Mat
&
origin
);
void
infer_det
(
cv
::
Mat
&
origin
,
int
max_side_len
,
std
::
vector
<
OCRPredictResult
>&
ocr_results
);
/**
/**
* infer for
second
model
* infer for
rec
model
*
*
* @param boxes
* @param boxes
* @param origin
* @param origin
* @return
* @return
*/
*/
std
::
vector
<
OCRPredictResult
>
void
infer_rec
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
int
>>>
&
boxes
,
infer_rec
(
const
cv
::
Mat
&
origin
,
int
run_cls
,
OCRPredictResult
&
ocr_result
);
const
cv
::
Mat
&
origin
);
/**
/**
* infer for cls model
* infer for cls model
*
*
* @param boxes
* @param boxes
* @param origin
* @param origin
* @return
* @return
*/
*/
cv
::
Ma
t
infer_cls
(
const
cv
::
Mat
&
origin
,
float
thresh
=
0.9
);
ClsPredictResul
t
infer_cls
(
const
cv
::
Mat
&
origin
,
float
thresh
=
0.9
);
/**
/**
* Postprocess or sencod model to extract text
* Postprocess or sencod model to extract text
...
...
deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java
浏览文件 @
3cb7a609
...
@@ -13,6 +13,7 @@ import android.graphics.BitmapFactory;
...
@@ -13,6 +13,7 @@ import android.graphics.BitmapFactory;
import
android.graphics.drawable.BitmapDrawable
;
import
android.graphics.drawable.BitmapDrawable
;
import
android.media.ExifInterface
;
import
android.media.ExifInterface
;
import
android.content.res.AssetManager
;
import
android.content.res.AssetManager
;
import
android.media.FaceDetector
;
import
android.net.Uri
;
import
android.net.Uri
;
import
android.os.Bundle
;
import
android.os.Bundle
;
import
android.os.Environment
;
import
android.os.Environment
;
...
@@ -27,7 +28,9 @@ import android.view.Menu;
...
@@ -27,7 +28,9 @@ import android.view.Menu;
import
android.view.MenuInflater
;
import
android.view.MenuInflater
;
import
android.view.MenuItem
;
import
android.view.MenuItem
;
import
android.view.View
;
import
android.view.View
;
import
android.widget.CheckBox
;
import
android.widget.ImageView
;
import
android.widget.ImageView
;
import
android.widget.Spinner
;
import
android.widget.TextView
;
import
android.widget.TextView
;
import
android.widget.Toast
;
import
android.widget.Toast
;
...
@@ -68,23 +71,24 @@ public class MainActivity extends AppCompatActivity {
...
@@ -68,23 +71,24 @@ public class MainActivity extends AppCompatActivity {
protected
ImageView
ivInputImage
;
protected
ImageView
ivInputImage
;
protected
TextView
tvOutputResult
;
protected
TextView
tvOutputResult
;
protected
TextView
tvInferenceTime
;
protected
TextView
tvInferenceTime
;
protected
CheckBox
cbOpencl
;
protected
Spinner
spRunMode
;
// Model settings of o
bject detection
// Model settings of o
cr
protected
String
modelPath
=
""
;
protected
String
modelPath
=
""
;
protected
String
labelPath
=
""
;
protected
String
labelPath
=
""
;
protected
String
imagePath
=
""
;
protected
String
imagePath
=
""
;
protected
int
cpuThreadNum
=
1
;
protected
int
cpuThreadNum
=
1
;
protected
String
cpuPowerMode
=
""
;
protected
String
cpuPowerMode
=
""
;
protected
String
inputColorFormat
=
""
;
protected
int
detLongSize
=
960
;
protected
long
[]
inputShape
=
new
long
[]{};
protected
float
[]
inputMean
=
new
float
[]{};
protected
float
[]
inputStd
=
new
float
[]{};
protected
float
scoreThreshold
=
0.1f
;
protected
float
scoreThreshold
=
0.1f
;
private
String
currentPhotoPath
;
private
String
currentPhotoPath
;
private
AssetManager
assetManager
=
null
;
private
AssetManager
assetManager
=
null
;
protected
Predictor
predictor
=
new
Predictor
();
protected
Predictor
predictor
=
new
Predictor
();
private
Bitmap
cur_predict_image
=
null
;
@Override
@Override
protected
void
onCreate
(
Bundle
savedInstanceState
)
{
protected
void
onCreate
(
Bundle
savedInstanceState
)
{
super
.
onCreate
(
savedInstanceState
);
super
.
onCreate
(
savedInstanceState
);
...
@@ -98,10 +102,12 @@ public class MainActivity extends AppCompatActivity {
...
@@ -98,10 +102,12 @@ public class MainActivity extends AppCompatActivity {
// Setup the UI components
// Setup the UI components
tvInputSetting
=
findViewById
(
R
.
id
.
tv_input_setting
);
tvInputSetting
=
findViewById
(
R
.
id
.
tv_input_setting
);
cbOpencl
=
findViewById
(
R
.
id
.
cb_opencl
);
tvStatus
=
findViewById
(
R
.
id
.
tv_model_img_status
);
tvStatus
=
findViewById
(
R
.
id
.
tv_model_img_status
);
ivInputImage
=
findViewById
(
R
.
id
.
iv_input_image
);
ivInputImage
=
findViewById
(
R
.
id
.
iv_input_image
);
tvInferenceTime
=
findViewById
(
R
.
id
.
tv_inference_time
);
tvInferenceTime
=
findViewById
(
R
.
id
.
tv_inference_time
);
tvOutputResult
=
findViewById
(
R
.
id
.
tv_output_result
);
tvOutputResult
=
findViewById
(
R
.
id
.
tv_output_result
);
spRunMode
=
findViewById
(
R
.
id
.
sp_run_mode
);
tvInputSetting
.
setMovementMethod
(
ScrollingMovementMethod
.
getInstance
());
tvInputSetting
.
setMovementMethod
(
ScrollingMovementMethod
.
getInstance
());
tvOutputResult
.
setMovementMethod
(
ScrollingMovementMethod
.
getInstance
());
tvOutputResult
.
setMovementMethod
(
ScrollingMovementMethod
.
getInstance
());
...
@@ -111,26 +117,26 @@ public class MainActivity extends AppCompatActivity {
...
@@ -111,26 +117,26 @@ public class MainActivity extends AppCompatActivity {
public
void
handleMessage
(
Message
msg
)
{
public
void
handleMessage
(
Message
msg
)
{
switch
(
msg
.
what
)
{
switch
(
msg
.
what
)
{
case
RESPONSE_LOAD_MODEL_SUCCESSED:
case
RESPONSE_LOAD_MODEL_SUCCESSED:
if
(
pbLoadModel
!=
null
&&
pbLoadModel
.
isShowing
())
{
if
(
pbLoadModel
!=
null
&&
pbLoadModel
.
isShowing
())
{
pbLoadModel
.
dismiss
();
pbLoadModel
.
dismiss
();
}
}
onLoadModelSuccessed
();
onLoadModelSuccessed
();
break
;
break
;
case
RESPONSE_LOAD_MODEL_FAILED:
case
RESPONSE_LOAD_MODEL_FAILED:
if
(
pbLoadModel
!=
null
&&
pbLoadModel
.
isShowing
())
{
if
(
pbLoadModel
!=
null
&&
pbLoadModel
.
isShowing
())
{
pbLoadModel
.
dismiss
();
pbLoadModel
.
dismiss
();
}
}
Toast
.
makeText
(
MainActivity
.
this
,
"Load model failed!"
,
Toast
.
LENGTH_SHORT
).
show
();
Toast
.
makeText
(
MainActivity
.
this
,
"Load model failed!"
,
Toast
.
LENGTH_SHORT
).
show
();
onLoadModelFailed
();
onLoadModelFailed
();
break
;
break
;
case
RESPONSE_RUN_MODEL_SUCCESSED:
case
RESPONSE_RUN_MODEL_SUCCESSED:
if
(
pbRunModel
!=
null
&&
pbRunModel
.
isShowing
())
{
if
(
pbRunModel
!=
null
&&
pbRunModel
.
isShowing
())
{
pbRunModel
.
dismiss
();
pbRunModel
.
dismiss
();
}
}
onRunModelSuccessed
();
onRunModelSuccessed
();
break
;
break
;
case
RESPONSE_RUN_MODEL_FAILED:
case
RESPONSE_RUN_MODEL_FAILED:
if
(
pbRunModel
!=
null
&&
pbRunModel
.
isShowing
())
{
if
(
pbRunModel
!=
null
&&
pbRunModel
.
isShowing
())
{
pbRunModel
.
dismiss
();
pbRunModel
.
dismiss
();
}
}
Toast
.
makeText
(
MainActivity
.
this
,
"Run model failed!"
,
Toast
.
LENGTH_SHORT
).
show
();
Toast
.
makeText
(
MainActivity
.
this
,
"Run model failed!"
,
Toast
.
LENGTH_SHORT
).
show
();
...
@@ -185,7 +191,6 @@ public class MainActivity extends AppCompatActivity {
...
@@ -185,7 +191,6 @@ public class MainActivity extends AppCompatActivity {
model_settingsChanged
|=
!
model_path
.
equalsIgnoreCase
(
modelPath
);
model_settingsChanged
|=
!
model_path
.
equalsIgnoreCase
(
modelPath
);
settingsChanged
|=
!
label_path
.
equalsIgnoreCase
(
labelPath
);
settingsChanged
|=
!
label_path
.
equalsIgnoreCase
(
labelPath
);
settingsChanged
|=
!
image_path
.
equalsIgnoreCase
(
imagePath
);
settingsChanged
|=
!
image_path
.
equalsIgnoreCase
(
imagePath
);
int
cpu_thread_num
=
Integer
.
parseInt
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
CPU_THREAD_NUM_KEY
),
int
cpu_thread_num
=
Integer
.
parseInt
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
CPU_THREAD_NUM_KEY
),
getString
(
R
.
string
.
CPU_THREAD_NUM_DEFAULT
)));
getString
(
R
.
string
.
CPU_THREAD_NUM_DEFAULT
)));
model_settingsChanged
|=
cpu_thread_num
!=
cpuThreadNum
;
model_settingsChanged
|=
cpu_thread_num
!=
cpuThreadNum
;
...
@@ -194,33 +199,9 @@ public class MainActivity extends AppCompatActivity {
...
@@ -194,33 +199,9 @@ public class MainActivity extends AppCompatActivity {
getString
(
R
.
string
.
CPU_POWER_MODE_DEFAULT
));
getString
(
R
.
string
.
CPU_POWER_MODE_DEFAULT
));
model_settingsChanged
|=
!
cpu_power_mode
.
equalsIgnoreCase
(
cpuPowerMode
);
model_settingsChanged
|=
!
cpu_power_mode
.
equalsIgnoreCase
(
cpuPowerMode
);
String
input_color_format
=
int
det_long_size
=
Integer
.
parseInt
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
DET_LONG_SIZE_KEY
),
sharedPreferences
.
getString
(
getString
(
R
.
string
.
INPUT_COLOR_FORMAT_KEY
),
getString
(
R
.
string
.
DET_LONG_SIZE_DEFAULT
)));
getString
(
R
.
string
.
INPUT_COLOR_FORMAT_DEFAULT
));
settingsChanged
|=
det_long_size
!=
detLongSize
;
settingsChanged
|=
!
input_color_format
.
equalsIgnoreCase
(
inputColorFormat
);
long
[]
input_shape
=
Utils
.
parseLongsFromString
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
INPUT_SHAPE_KEY
),
getString
(
R
.
string
.
INPUT_SHAPE_DEFAULT
)),
","
);
float
[]
input_mean
=
Utils
.
parseFloatsFromString
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
INPUT_MEAN_KEY
),
getString
(
R
.
string
.
INPUT_MEAN_DEFAULT
)),
","
);
float
[]
input_std
=
Utils
.
parseFloatsFromString
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
INPUT_STD_KEY
)
,
getString
(
R
.
string
.
INPUT_STD_DEFAULT
)),
","
);
settingsChanged
|=
input_shape
.
length
!=
inputShape
.
length
;
settingsChanged
|=
input_mean
.
length
!=
inputMean
.
length
;
settingsChanged
|=
input_std
.
length
!=
inputStd
.
length
;
if
(!
settingsChanged
)
{
for
(
int
i
=
0
;
i
<
input_shape
.
length
;
i
++)
{
settingsChanged
|=
input_shape
[
i
]
!=
inputShape
[
i
];
}
for
(
int
i
=
0
;
i
<
input_mean
.
length
;
i
++)
{
settingsChanged
|=
input_mean
[
i
]
!=
inputMean
[
i
];
}
for
(
int
i
=
0
;
i
<
input_std
.
length
;
i
++)
{
settingsChanged
|=
input_std
[
i
]
!=
inputStd
[
i
];
}
}
float
score_threshold
=
float
score_threshold
=
Float
.
parseFloat
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
SCORE_THRESHOLD_KEY
),
Float
.
parseFloat
(
sharedPreferences
.
getString
(
getString
(
R
.
string
.
SCORE_THRESHOLD_KEY
),
getString
(
R
.
string
.
SCORE_THRESHOLD_DEFAULT
)));
getString
(
R
.
string
.
SCORE_THRESHOLD_DEFAULT
)));
...
@@ -228,20 +209,16 @@ public class MainActivity extends AppCompatActivity {
...
@@ -228,20 +209,16 @@ public class MainActivity extends AppCompatActivity {
if
(
settingsChanged
)
{
if
(
settingsChanged
)
{
labelPath
=
label_path
;
labelPath
=
label_path
;
imagePath
=
image_path
;
imagePath
=
image_path
;
inputColorFormat
=
input_color_format
;
detLongSize
=
det_long_size
;
inputShape
=
input_shape
;
inputMean
=
input_mean
;
inputStd
=
input_std
;
scoreThreshold
=
score_threshold
;
scoreThreshold
=
score_threshold
;
set_img
();
set_img
();
}
}
if
(
model_settingsChanged
){
if
(
model_settingsChanged
)
{
modelPath
=
model_path
;
modelPath
=
model_path
;
cpuThreadNum
=
cpu_thread_num
;
cpuThreadNum
=
cpu_thread_num
;
cpuPowerMode
=
cpu_power_mode
;
cpuPowerMode
=
cpu_power_mode
;
// Update UI
// Update UI
tvInputSetting
.
setText
(
"Model: "
+
modelPath
.
substring
(
modelPath
.
lastIndexOf
(
"/"
)
+
1
)
+
"\n"
+
"CPU"
+
tvInputSetting
.
setText
(
"Model: "
+
modelPath
.
substring
(
modelPath
.
lastIndexOf
(
"/"
)
+
1
)
+
"\nOPENCL: "
+
cbOpencl
.
isChecked
()
+
"\nCPU Thread Num: "
+
cpuThreadNum
+
"\nCPU Power Mode: "
+
cpuPowerMode
);
" Thread Num: "
+
Integer
.
toString
(
cpuThreadNum
)
+
"\n"
+
"CPU Power Mode: "
+
cpuPowerMode
);
tvInputSetting
.
scrollTo
(
0
,
0
);
tvInputSetting
.
scrollTo
(
0
,
0
);
// Reload model if configure has been changed
// Reload model if configure has been changed
loadModel
();
loadModel
();
...
@@ -259,20 +236,28 @@ public class MainActivity extends AppCompatActivity {
...
@@ -259,20 +236,28 @@ public class MainActivity extends AppCompatActivity {
}
}
public
boolean
onLoadModel
()
{
public
boolean
onLoadModel
()
{
return
predictor
.
init
(
MainActivity
.
this
,
modelPath
,
labelPath
,
cpuThreadNum
,
if
(
predictor
.
isLoaded
())
{
predictor
.
releaseModel
();
}
return
predictor
.
init
(
MainActivity
.
this
,
modelPath
,
labelPath
,
cbOpencl
.
isChecked
()
?
1
:
0
,
cpuThreadNum
,
cpuPowerMode
,
cpuPowerMode
,
inputColorFormat
,
detLongSize
,
scoreThreshold
);
inputShape
,
inputMean
,
inputStd
,
scoreThreshold
);
}
}
public
boolean
onRunModel
()
{
public
boolean
onRunModel
()
{
return
predictor
.
isLoaded
()
&&
predictor
.
runModel
();
String
run_mode
=
spRunMode
.
getSelectedItem
().
toString
();
int
run_det
=
run_mode
.
contains
(
"检测"
)
?
1
:
0
;
int
run_cls
=
run_mode
.
contains
(
"分类"
)
?
1
:
0
;
int
run_rec
=
run_mode
.
contains
(
"识别"
)
?
1
:
0
;
return
predictor
.
isLoaded
()
&&
predictor
.
runModel
(
run_det
,
run_cls
,
run_rec
);
}
}
public
void
onLoadModelSuccessed
()
{
public
void
onLoadModelSuccessed
()
{
// Load test image from path and run model
// Load test image from path and run model
tvInputSetting
.
setText
(
"Model: "
+
modelPath
.
substring
(
modelPath
.
lastIndexOf
(
"/"
)
+
1
)
+
"\nOPENCL: "
+
cbOpencl
.
isChecked
()
+
"\nCPU Thread Num: "
+
cpuThreadNum
+
"\nCPU Power Mode: "
+
cpuPowerMode
);
tvInputSetting
.
scrollTo
(
0
,
0
);
tvStatus
.
setText
(
"STATUS: load model successed"
);
tvStatus
.
setText
(
"STATUS: load model successed"
);
}
}
public
void
onLoadModelFailed
()
{
public
void
onLoadModelFailed
()
{
...
@@ -306,9 +291,9 @@ public class MainActivity extends AppCompatActivity {
...
@@ -306,9 +291,9 @@ public class MainActivity extends AppCompatActivity {
public
void
set_img
()
{
public
void
set_img
()
{
// Load test image from path and run model
// Load test image from path and run model
try
{
try
{
assetManager
=
getAssets
();
assetManager
=
getAssets
();
InputStream
in
=
assetManager
.
open
(
imagePath
);
InputStream
in
=
assetManager
.
open
(
imagePath
);
Bitmap
bmp
=
BitmapFactory
.
decodeStream
(
in
);
Bitmap
bmp
=
BitmapFactory
.
decodeStream
(
in
);
ivInputImage
.
setImageBitmap
(
bmp
);
ivInputImage
.
setImageBitmap
(
bmp
);
}
catch
(
IOException
e
)
{
}
catch
(
IOException
e
)
{
Toast
.
makeText
(
MainActivity
.
this
,
"Load image failed!"
,
Toast
.
LENGTH_SHORT
).
show
();
Toast
.
makeText
(
MainActivity
.
this
,
"Load image failed!"
,
Toast
.
LENGTH_SHORT
).
show
();
...
@@ -469,28 +454,28 @@ public class MainActivity extends AppCompatActivity {
...
@@ -469,28 +454,28 @@ public class MainActivity extends AppCompatActivity {
}
}
}
}
public
void
btn_
load_model
_click
(
View
view
)
{
public
void
btn_
reset_img
_click
(
View
view
)
{
i
f
(
predictor
.
isLoaded
()){
i
vInputImage
.
setImageBitmap
(
cur_predict_image
);
tvStatus
.
setText
(
"STATUS: model has been loaded"
);
}
}
else
{
tvStatus
.
setText
(
"STATUS: load model ......"
);
public
void
cb_opencl_click
(
View
view
)
{
loadModel
(
);
tvStatus
.
setText
(
"STATUS: load model ......"
);
}
loadModel
();
}
}
public
void
btn_run_model_click
(
View
view
)
{
public
void
btn_run_model_click
(
View
view
)
{
Bitmap
image
=((
BitmapDrawable
)
ivInputImage
.
getDrawable
()).
getBitmap
();
cur_predict_image
=
((
BitmapDrawable
)
ivInputImage
.
getDrawable
()).
getBitmap
();
if
(
image
==
null
)
{
if
(
cur_predict_
image
==
null
)
{
tvStatus
.
setText
(
"STATUS: image is not exists"
);
tvStatus
.
setText
(
"STATUS: image is not exists"
);
}
}
else
if
(!
predictor
.
isLoaded
())
{
else
if
(!
predictor
.
isLoaded
()){
tvStatus
.
setText
(
"STATUS: model is not loaded"
);
tvStatus
.
setText
(
"STATUS: model is not loaded"
);
}
else
{
}
else
{
tvStatus
.
setText
(
"STATUS: run model ...... "
);
tvStatus
.
setText
(
"STATUS: run model ...... "
);
predictor
.
setInputImage
(
image
);
predictor
.
setInputImage
(
cur_predict_
image
);
runModel
();
runModel
();
}
}
}
}
public
void
btn_choice_img_click
(
View
view
)
{
public
void
btn_choice_img_click
(
View
view
)
{
if
(
requestAllPermissions
())
{
if
(
requestAllPermissions
())
{
openGallery
();
openGallery
();
...
@@ -511,4 +496,32 @@ public class MainActivity extends AppCompatActivity {
...
@@ -511,4 +496,32 @@ public class MainActivity extends AppCompatActivity {
worker
.
quit
();
worker
.
quit
();
super
.
onDestroy
();
super
.
onDestroy
();
}
}
public
int
get_run_mode
()
{
String
run_mode
=
spRunMode
.
getSelectedItem
().
toString
();
int
mode
;
switch
(
run_mode
)
{
case
"检测+分类+识别"
:
mode
=
1
;
break
;
case
"检测+识别"
:
mode
=
2
;
break
;
case
"识别+分类"
:
mode
=
3
;
break
;
case
"检测"
:
mode
=
4
;
break
;
case
"识别"
:
mode
=
5
;
break
;
case
"分类"
:
mode
=
6
;
break
;
default
:
mode
=
1
;
}
return
mode
;
}
}
}
deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
浏览文件 @
3cb7a609
...
@@ -29,22 +29,22 @@ public class OCRPredictorNative {
...
@@ -29,22 +29,22 @@ public class OCRPredictorNative {
public
OCRPredictorNative
(
Config
config
)
{
public
OCRPredictorNative
(
Config
config
)
{
this
.
config
=
config
;
this
.
config
=
config
;
loadLibrary
();
loadLibrary
();
nativePointer
=
init
(
config
.
detModelFilename
,
config
.
recModelFilename
,
config
.
clsModelFilename
,
nativePointer
=
init
(
config
.
detModelFilename
,
config
.
recModelFilename
,
config
.
clsModelFilename
,
config
.
useOpencl
,
config
.
cpuThreadNum
,
config
.
cpuPower
);
config
.
cpuThreadNum
,
config
.
cpuPower
);
Log
.
i
(
"OCRPredictorNative"
,
"load success "
+
nativePointer
);
Log
.
i
(
"OCRPredictorNative"
,
"load success "
+
nativePointer
);
}
}
public
ArrayList
<
OcrResultModel
>
runImage
(
float
[]
inputData
,
int
width
,
int
height
,
int
channels
,
Bitmap
originalImage
)
{
public
ArrayList
<
OcrResultModel
>
runImage
(
Bitmap
originalImage
,
int
max_size_len
,
int
run_det
,
int
run_cls
,
int
run_rec
)
{
Log
.
i
(
"OCRPredictorNative"
,
"begin to run image "
+
inputData
.
length
+
" "
+
width
+
" "
+
height
);
Log
.
i
(
"OCRPredictorNative"
,
"begin to run image "
);
float
[]
dims
=
new
float
[]{
1
,
channels
,
height
,
width
};
float
[]
rawResults
=
forward
(
nativePointer
,
originalImage
,
max_size_len
,
run_det
,
run_cls
,
run_rec
);
float
[]
rawResults
=
forward
(
nativePointer
,
inputData
,
dims
,
originalImage
);
ArrayList
<
OcrResultModel
>
results
=
postprocess
(
rawResults
);
ArrayList
<
OcrResultModel
>
results
=
postprocess
(
rawResults
);
return
results
;
return
results
;
}
}
public
static
class
Config
{
public
static
class
Config
{
public
int
useOpencl
;
public
int
cpuThreadNum
;
public
int
cpuThreadNum
;
public
String
cpuPower
;
public
String
cpuPower
;
public
String
detModelFilename
;
public
String
detModelFilename
;
...
@@ -53,16 +53,16 @@ public class OCRPredictorNative {
...
@@ -53,16 +53,16 @@ public class OCRPredictorNative {
}
}
public
void
destory
(){
public
void
destory
()
{
if
(
nativePointer
>
0
)
{
if
(
nativePointer
>
0
)
{
release
(
nativePointer
);
release
(
nativePointer
);
nativePointer
=
0
;
nativePointer
=
0
;
}
}
}
}
protected
native
long
init
(
String
detModelPath
,
String
recModelPath
,
String
clsModelPath
,
int
threadNum
,
String
cpuMode
);
protected
native
long
init
(
String
detModelPath
,
String
recModelPath
,
String
clsModelPath
,
int
useOpencl
,
int
threadNum
,
String
cpuMode
);
protected
native
float
[]
forward
(
long
pointer
,
float
[]
buf
,
float
[]
ddims
,
Bitmap
originalImage
);
protected
native
float
[]
forward
(
long
pointer
,
Bitmap
originalImage
,
int
max_size_len
,
int
run_det
,
int
run_cls
,
int
run_rec
);
protected
native
void
release
(
long
pointer
);
protected
native
void
release
(
long
pointer
);
...
@@ -73,9 +73,9 @@ public class OCRPredictorNative {
...
@@ -73,9 +73,9 @@ public class OCRPredictorNative {
while
(
begin
<
raw
.
length
)
{
while
(
begin
<
raw
.
length
)
{
int
point_num
=
Math
.
round
(
raw
[
begin
]);
int
point_num
=
Math
.
round
(
raw
[
begin
]);
int
word_num
=
Math
.
round
(
raw
[
begin
+
1
]);
int
word_num
=
Math
.
round
(
raw
[
begin
+
1
]);
OcrResultModel
model
=
parse
(
raw
,
begin
+
2
,
point_num
,
word_num
);
OcrResultModel
res
=
parse
(
raw
,
begin
+
2
,
point_num
,
word_num
);
begin
+=
2
+
1
+
point_num
*
2
+
word_num
;
begin
+=
2
+
1
+
point_num
*
2
+
word_num
+
2
;
results
.
add
(
model
);
results
.
add
(
res
);
}
}
return
results
;
return
results
;
...
@@ -83,19 +83,22 @@ public class OCRPredictorNative {
...
@@ -83,19 +83,22 @@ public class OCRPredictorNative {
private
OcrResultModel
parse
(
float
[]
raw
,
int
begin
,
int
pointNum
,
int
wordNum
)
{
private
OcrResultModel
parse
(
float
[]
raw
,
int
begin
,
int
pointNum
,
int
wordNum
)
{
int
current
=
begin
;
int
current
=
begin
;
OcrResultModel
model
=
new
OcrResultModel
();
OcrResultModel
res
=
new
OcrResultModel
();
model
.
setConfidence
(
raw
[
current
]);
res
.
setConfidence
(
raw
[
current
]);
current
++;
current
++;
for
(
int
i
=
0
;
i
<
pointNum
;
i
++)
{
for
(
int
i
=
0
;
i
<
pointNum
;
i
++)
{
model
.
addPoints
(
Math
.
round
(
raw
[
current
+
i
*
2
]),
Math
.
round
(
raw
[
current
+
i
*
2
+
1
]));
res
.
addPoints
(
Math
.
round
(
raw
[
current
+
i
*
2
]),
Math
.
round
(
raw
[
current
+
i
*
2
+
1
]));
}
}
current
+=
(
pointNum
*
2
);
current
+=
(
pointNum
*
2
);
for
(
int
i
=
0
;
i
<
wordNum
;
i
++)
{
for
(
int
i
=
0
;
i
<
wordNum
;
i
++)
{
int
index
=
Math
.
round
(
raw
[
current
+
i
]);
int
index
=
Math
.
round
(
raw
[
current
+
i
]);
model
.
addWordIndex
(
index
);
res
.
addWordIndex
(
index
);
}
}
current
+=
wordNum
;
res
.
setClsIdx
(
raw
[
current
]);
res
.
setClsConfidence
(
raw
[
current
+
1
]);
Log
.
i
(
"OCRPredictorNative"
,
"word finished "
+
wordNum
);
Log
.
i
(
"OCRPredictorNative"
,
"word finished "
+
wordNum
);
return
model
;
return
res
;
}
}
...
...
deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java
浏览文件 @
3cb7a609
...
@@ -10,6 +10,9 @@ public class OcrResultModel {
...
@@ -10,6 +10,9 @@ public class OcrResultModel {
private
List
<
Integer
>
wordIndex
;
private
List
<
Integer
>
wordIndex
;
private
String
label
;
private
String
label
;
private
float
confidence
;
private
float
confidence
;
private
float
cls_idx
;
private
String
cls_label
;
private
float
cls_confidence
;
public
OcrResultModel
()
{
public
OcrResultModel
()
{
super
();
super
();
...
@@ -49,4 +52,28 @@ public class OcrResultModel {
...
@@ -49,4 +52,28 @@ public class OcrResultModel {
public
void
setConfidence
(
float
confidence
)
{
public
void
setConfidence
(
float
confidence
)
{
this
.
confidence
=
confidence
;
this
.
confidence
=
confidence
;
}
}
public
float
getClsIdx
()
{
return
cls_idx
;
}
public
void
setClsIdx
(
float
idx
)
{
this
.
cls_idx
=
idx
;
}
public
String
getClsLabel
()
{
return
cls_label
;
}
public
void
setClsLabel
(
String
label
)
{
this
.
cls_label
=
label
;
}
public
float
getClsConfidence
()
{
return
cls_confidence
;
}
public
void
setClsConfidence
(
float
confidence
)
{
this
.
cls_confidence
=
confidence
;
}
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录