Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
ad4853db
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ad4853db
编写于
6月 03, 2021
作者:
Z
zhoujun
提交者:
GitHub
6月 03, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2925 from WenmuZhou/Optimizing_parameters
combine args in paddleocr and ppocr/infer/utility
上级
1bc07888
bb5c6f3b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
35 addition
and
112 deletion
+35
-112
doc/doc_ch/whl.md
doc/doc_ch/whl.md
+1
-1
doc/doc_en/whl_en.md
doc/doc_en/whl_en.md
+1
-1
paddleocr.py
paddleocr.py
+20
-104
tools/infer/predict_system.py
tools/infer/predict_system.py
+2
-2
tools/infer/utility.py
tools/infer/utility.py
+11
-4
未找到文件。
doc/doc_ch/whl.md
浏览文件 @
ad4853db
...
...
@@ -59,7 +59,7 @@ im_show.save('result.jpg')
from
paddleocr
import
PaddleOCR
,
draw_ocr
ocr
=
PaddleOCR
()
# need to run only once to download and load model into memory
img_path
=
'PaddleOCR/doc/imgs/11.jpg'
result
=
ocr
.
ocr
(
img_path
)
result
=
ocr
.
ocr
(
img_path
,
cls
=
False
)
for
line
in
result
:
print
(
line
)
...
...
doc/doc_en/whl_en.md
浏览文件 @
ad4853db
...
...
@@ -59,7 +59,7 @@ Visualization of results
from
paddleocr
import
PaddleOCR
,
draw_ocr
ocr
=
PaddleOCR
(
lang
=
'en'
)
# need to run only once to download and load model into memory
img_path
=
'PaddleOCR/doc/imgs_en/img_12.jpg'
result
=
ocr
.
ocr
(
img_path
)
result
=
ocr
.
ocr
(
img_path
,
cls
=
False
)
for
line
in
result
:
print
(
line
)
...
...
paddleocr.py
浏览文件 @
ad4853db
...
...
@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger
logger
=
get_logger
()
from
ppocr.utils.utility
import
check_and_read_gif
,
get_image_file_list
from
tools.infer.utility
import
draw_ocr
from
tools.infer.utility
import
draw_ocr
,
init_args
,
str2bool
__all__
=
[
'PaddleOCR'
]
...
...
@@ -167,106 +167,24 @@ def maybe_download(model_storage_directory, url):
os
.
remove
(
tmp_path
)
def
parse_args
(
mMain
=
True
,
add_help
=
True
):
def
parse_args
(
mMain
=
True
):
import
argparse
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
parser
=
init_args
()
parser
.
add_help
=
mMain
parser
.
add_argument
(
"--lang"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--det"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--rec"
,
type
=
str2bool
,
default
=
True
)
for
action
in
parser
.
_actions
:
if
action
.
dest
==
'rec_char_dict_path'
:
action
.
default
=
None
if
mMain
:
parser
=
argparse
.
ArgumentParser
(
add_help
=
add_help
)
# params for prediction engine
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
8000
)
# params for text detector
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_model_dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--det_limit_side_len"
,
type
=
float
,
default
=
960
)
parser
.
add_argument
(
"--det_limit_type"
,
type
=
str
,
default
=
'max'
)
# DB parmas
parser
.
add_argument
(
"--det_db_thresh"
,
type
=
float
,
default
=
0.3
)
parser
.
add_argument
(
"--det_db_box_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
1.6
)
parser
.
add_argument
(
"--use_dilation"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--det_db_score_mode"
,
type
=
str
,
default
=
"fast"
)
# EAST parmas
parser
.
add_argument
(
"--det_east_score_thresh"
,
type
=
float
,
default
=
0.8
)
parser
.
add_argument
(
"--det_east_cover_thresh"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--det_east_nms_thresh"
,
type
=
float
,
default
=
0.2
)
# params for text recognizer
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 32, 320"
)
parser
.
add_argument
(
"--rec_char_type"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
6
)
parser
.
add_argument
(
"--max_text_length"
,
type
=
int
,
default
=
25
)
parser
.
add_argument
(
"--rec_char_dict_path"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--drop_score"
,
type
=
float
,
default
=
0.5
)
# params for text classifier
parser
.
add_argument
(
"--cls_model_dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--cls_image_shape"
,
type
=
str
,
default
=
"3, 48, 192"
)
parser
.
add_argument
(
"--label_list"
,
type
=
list
,
default
=
[
'0'
,
'180'
])
parser
.
add_argument
(
"--cls_batch_num"
,
type
=
int
,
default
=
6
)
parser
.
add_argument
(
"--cls_thresh"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--use_zero_copy_run"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--lang"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--det"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--rec"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_angle_cls"
,
type
=
str2bool
,
default
=
False
)
return
parser
.
parse_args
()
else
:
return
argparse
.
Namespace
(
use_gpu
=
True
,
ir_optim
=
True
,
use_tensorrt
=
False
,
gpu_mem
=
8000
,
image_dir
=
''
,
det_algorithm
=
'DB'
,
det_model_dir
=
None
,
det_limit_side_len
=
960
,
det_limit_type
=
'max'
,
det_db_thresh
=
0.3
,
det_db_box_thresh
=
0.5
,
det_db_unclip_ratio
=
1.6
,
use_dilation
=
False
,
det_db_score_mode
=
"fast"
,
det_east_score_thresh
=
0.8
,
det_east_cover_thresh
=
0.1
,
det_east_nms_thresh
=
0.2
,
rec_algorithm
=
'CRNN'
,
rec_model_dir
=
None
,
rec_image_shape
=
"3, 32, 320"
,
rec_char_type
=
'ch'
,
rec_batch_num
=
6
,
max_text_length
=
25
,
rec_char_dict_path
=
None
,
use_space_char
=
True
,
drop_score
=
0.5
,
cls_model_dir
=
None
,
cls_image_shape
=
"3, 48, 192"
,
label_list
=
[
'0'
,
'180'
],
cls_batch_num
=
6
,
cls_thresh
=
0.9
,
enable_mkldnn
=
False
,
use_zero_copy_run
=
False
,
use_pdserving
=
False
,
lang
=
'ch'
,
det
=
True
,
rec
=
True
,
use_angle_cls
=
False
)
inference_args_dict
=
{}
for
action
in
parser
.
_actions
:
inference_args_dict
[
action
.
dest
]
=
action
.
default
return
argparse
.
Namespace
(
**
inference_args_dict
)
class
PaddleOCR
(
predict_system
.
TextSystem
):
...
...
@@ -276,7 +194,7 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_params
=
parse_args
(
mMain
=
False
,
add_help
=
False
)
postprocess_params
=
parse_args
(
mMain
=
False
)
postprocess_params
.
__dict__
.
update
(
**
kwargs
)
self
.
use_angle_cls
=
postprocess_params
.
use_angle_cls
lang
=
postprocess_params
.
lang
...
...
@@ -346,7 +264,7 @@ class PaddleOCR(predict_system.TextSystem):
# init det_model and rec_model
super
().
__init__
(
postprocess_params
)
def
ocr
(
self
,
img
,
det
=
True
,
rec
=
True
,
cls
=
Fals
e
):
def
ocr
(
self
,
img
,
det
=
True
,
rec
=
True
,
cls
=
Tru
e
):
"""
ocr with paddleocr
args:
...
...
@@ -358,9 +276,7 @@ class PaddleOCR(predict_system.TextSystem):
if
isinstance
(
img
,
list
)
and
det
==
True
:
logger
.
error
(
'When input a list of images, det must be false'
)
exit
(
0
)
if
cls
==
False
:
self
.
use_angle_cls
=
False
elif
cls
==
True
and
self
.
use_angle_cls
==
False
:
if
cls
==
True
and
self
.
use_angle_cls
==
False
:
logger
.
warning
(
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
)
...
...
@@ -382,7 +298,7 @@ class PaddleOCR(predict_system.TextSystem):
if
isinstance
(
img
,
np
.
ndarray
)
and
len
(
img
.
shape
)
==
2
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_GRAY2BGR
)
if
det
and
rec
:
dt_boxes
,
rec_res
=
self
.
__call__
(
img
)
dt_boxes
,
rec_res
=
self
.
__call__
(
img
,
cls
)
return
[[
box
.
tolist
(),
res
]
for
box
,
res
in
zip
(
dt_boxes
,
rec_res
)]
elif
det
and
not
rec
:
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
...
...
@@ -392,7 +308,7 @@ class PaddleOCR(predict_system.TextSystem):
else
:
if
not
isinstance
(
img
,
list
):
img
=
[
img
]
if
self
.
use_angle_cls
:
if
self
.
use_angle_cls
and
cls
:
img
,
cls_res
,
elapse
=
self
.
text_classifier
(
img
)
if
not
rec
:
return
cls_res
...
...
tools/infer/predict_system.py
浏览文件 @
ad4853db
...
...
@@ -85,7 +85,7 @@ class TextSystem(object):
cv2
.
imwrite
(
"./output/img_crop_%d.jpg"
%
bno
,
img_crop_list
[
bno
])
logger
.
info
(
bno
,
rec_res
[
bno
])
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
,
cls
=
True
):
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
logger
.
info
(
"dt_boxes num : {}, elapse : {}"
.
format
(
...
...
@@ -100,7 +100,7 @@ class TextSystem(object):
tmp_box
=
copy
.
deepcopy
(
dt_boxes
[
bno
])
img_crop
=
self
.
get_rotate_crop_image
(
ori_im
,
tmp_box
)
img_crop_list
.
append
(
img_crop
)
if
self
.
use_angle_cls
:
if
self
.
use_angle_cls
and
cls
:
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
)
logger
.
info
(
"cls num : {}, elapse : {}"
.
format
(
...
...
tools/infer/utility.py
浏览文件 @
ad4853db
...
...
@@ -23,13 +23,15 @@ import math
from
paddle
import
inference
import
time
from
ppocr.utils.logging
import
get_logger
logger
=
get_logger
()
def
parse_args
(
):
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
def
init_args
():
parser
=
argparse
.
ArgumentParser
()
# params for prediction engine
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
...
...
@@ -108,6 +110,11 @@ def parse_args():
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
return
parser
def
parse_args
():
parser
=
init_args
()
return
parser
.
parse_args
()
...
...
@@ -141,7 +148,7 @@ def create_predictor(args, mode, logger):
config
.
enable_tensorrt_engine
(
precision_mode
=
inference
.
PrecisionType
.
Float32
,
max_batch_size
=
args
.
max_batch_size
,
min_subgraph_size
=
3
)
# skip the minmum trt subgraph
min_subgraph_size
=
3
)
# skip the minmum trt subgraph
if
mode
==
"det"
and
"mobile"
in
model_file_path
:
min_input_shape
=
{
"x"
:
[
1
,
3
,
50
,
50
],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录