Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
c581ff51
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c581ff51
编写于
8月 24, 2020
作者:
W
WenmuZhou
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
model_storage_directory参数取消,改为det_model_dir和rec_model_dir
上级
09e15a68
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
56 addition
and
75 deletion
+56
-75
paddleocr.py
paddleocr.py
+56
-75
未找到文件。
paddleocr.py
浏览文件 @
c581ff51
...
@@ -29,25 +29,19 @@ from tools.infer import predict_system
...
@@ -29,25 +29,19 @@ from tools.infer import predict_system
from
ppocr.utils.utility
import
initial_logger
from
ppocr.utils.utility
import
initial_logger
logger
=
initial_logger
()
logger
=
initial_logger
()
from
ppocr.utils.utility
import
check_and_read_gif
from
ppocr.utils.utility
import
check_and_read_gif
,
get_image_file_list
__all__
=
[
'PaddleOCR'
]
__all__
=
[
'PaddleOCR'
]
model_params
=
{
model_params
=
{
'ch_det_mv3_db'
:
{
'det'
:
'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar'
,
'url'
:
'rec'
:
'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar'
,
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar'
,
'algorithm'
:
'DB'
,
},
'ch_rec_mv3_crnn_enhance'
:
{
'url'
:
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar'
,
'algorithm'
:
'CRNN'
},
}
}
SUPPORT_DET_MODEL
=
[
'DB'
]
SUPPORT_DET_MODEL
=
[
'DB'
]
SUPPORT_REC_MODEL
=
[
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
]
SUPPORT_REC_MODEL
=
[
'CRNN'
]
BASE_DIR
=
os
.
path
.
expanduser
(
"~/.paddleocr/"
)
def
download_with_progressbar
(
url
,
save_path
):
def
download_with_progressbar
(
url
,
save_path
):
...
@@ -65,34 +59,29 @@ def download_with_progressbar(url, save_path):
...
@@ -65,34 +59,29 @@ def download_with_progressbar(url, save_path):
sys
.
exit
(
0
)
sys
.
exit
(
0
)
def
download_and_unzip
(
url
,
model_storage_directory
):
def
maybe_download
(
model_storage_directory
,
url
):
tmp_path
=
os
.
path
.
join
(
model_storage_directory
,
url
.
split
(
'/'
)[
-
1
])
print
(
'download {} to {}'
.
format
(
url
,
tmp_path
))
os
.
makedirs
(
model_storage_directory
,
exist_ok
=
True
)
download_with_progressbar
(
url
,
tmp_path
)
with
tarfile
.
open
(
tmp_path
,
'r'
)
as
tarObj
:
for
filename
in
tarObj
.
getnames
():
tarObj
.
extract
(
filename
,
model_storage_directory
)
os
.
remove
(
tmp_path
)
def
maybe_download
(
model_storage_directory
,
model_name
,
mode
=
'det'
):
algorithm
=
None
# using custom model
# using custom model
if
os
.
path
.
exists
(
os
.
path
.
join
(
model_name
,
'model'
))
and
os
.
path
.
exists
(
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
os
.
path
.
join
(
model_name
,
'params'
)):
model_storage_directory
,
'model'
))
or
not
os
.
path
.
exists
(
return
model_name
,
algorithm
os
.
path
.
join
(
model_storage_directory
,
'params'
)):
# using the model of paddleocr
tmp_path
=
os
.
path
.
join
(
model_storage_directory
,
url
.
split
(
'/'
)[
-
1
])
model_path
=
os
.
path
.
join
(
model_storage_directory
,
model_name
)
print
(
'download {} to {}'
.
format
(
url
,
tmp_path
))
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
model_path
,
os
.
makedirs
(
model_storage_directory
,
exist_ok
=
True
)
'model'
))
or
not
os
.
path
.
exists
(
download_with_progressbar
(
url
,
tmp_path
)
os
.
path
.
join
(
model_path
,
'params'
)):
with
tarfile
.
open
(
tmp_path
,
'r'
)
as
tarObj
:
assert
model_name
in
model_params
,
'model must in {}'
.
format
(
for
member
in
tarObj
.
getmembers
():
model_params
.
keys
())
if
"model"
in
member
.
name
:
download_and_unzip
(
model_params
[
model_name
][
'url'
],
filename
=
'model'
model_storage_directory
)
elif
"params"
in
member
.
name
:
algorithm
=
model_params
[
model_name
][
'algorithm'
]
filename
=
'params'
return
model_path
,
algorithm
else
:
continue
file
=
tarObj
.
extractfile
(
member
)
with
open
(
os
.
path
.
join
(
model_storage_directory
,
filename
),
'wb'
)
as
f
:
f
.
write
(
file
.
read
())
os
.
remove
(
tmp_path
)
def
parse_args
():
def
parse_args
():
...
@@ -111,7 +100,7 @@ def parse_args():
...
@@ -111,7 +100,7 @@ def parse_args():
# params for text detector
# params for text detector
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_model_
name"
,
type
=
str
,
default
=
'ch_det_mv3_db'
)
parser
.
add_argument
(
"--det_model_
dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--det_max_side_len"
,
type
=
float
,
default
=
960
)
parser
.
add_argument
(
"--det_max_side_len"
,
type
=
float
,
default
=
960
)
# DB parmas
# DB parmas
...
@@ -126,11 +115,11 @@ def parse_args():
...
@@ -126,11 +115,11 @@ def parse_args():
# params for text recognizer
# params for text recognizer
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
,
default
=
None
)
"--rec_model_name"
,
type
=
str
,
default
=
'ch_rec_mv3_crnn_enhance'
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 32, 320"
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 32, 320"
)
parser
.
add_argument
(
"--rec_char_type"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--rec_char_type"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
30
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
30
)
parser
.
add_argument
(
"--max_text_length"
,
type
=
int
,
default
=
25
)
parser
.
add_argument
(
parser
.
add_argument
(
"--rec_char_dict_path"
,
"--rec_char_dict_path"
,
type
=
str
,
type
=
str
,
...
@@ -138,53 +127,30 @@ def parse_args():
...
@@ -138,53 +127,30 @@ def parse_args():
parser
.
add_argument
(
"--use_space_char"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--model_storage_directory"
,
type
=
str
,
default
=
False
)
parser
.
add_argument
(
"--det"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--det"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--rec"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--rec"
,
type
=
str2bool
,
default
=
True
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
class
PaddleOCR
(
predict_system
.
TextSystem
):
class
PaddleOCR
(
predict_system
.
TextSystem
):
def
__init__
(
self
,
def
__init__
(
self
,
**
kwargs
):
det_model_name
=
'ch_det_mv3_db'
,
rec_model_name
=
'ch_rec_mv3_crnn_enhance'
,
model_storage_directory
=
None
,
log_level
=
20
,
**
kwargs
):
"""
"""
paddleocr package
paddleocr package
args:
args:
det_model_name: det_model name, keep same with filename in paddleocr. default is ch_det_mv3_db
det_model_name: rec_model name, keep same with filename in paddleocr. default is ch_rec_mv3_crnn_enhance
model_storage_directory: model save path. default is ~/.paddleocr
det model will save to model_storage_directory/det_model
rec model will save to model_storage_directory/rec_model
log_level:
**kwargs: other params show in paddleocr --help
**kwargs: other params show in paddleocr --help
"""
"""
logger
.
setLevel
(
log_level
)
postprocess_params
=
parse_args
()
postprocess_params
=
parse_args
()
# init model dir
postprocess_params
.
__dict__
.
update
(
**
kwargs
)
if
model_storage_directory
:
self
.
model_storage_directory
=
model_storage_directory
else
:
self
.
model_storage_directory
=
os
.
path
.
expanduser
(
"~/.paddleocr/"
)
+
'/model'
Path
(
self
.
model_storage_directory
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
# init model dir
if
postprocess_params
.
det_model_dir
is
None
:
postprocess_params
.
det_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
'det'
)
if
postprocess_params
.
rec_model_dir
is
None
:
postprocess_params
.
rec_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
'rec'
)
print
(
postprocess_params
)
# download model
# download model
det_model_path
,
det_algorithm
=
maybe_download
(
maybe_download
(
postprocess_params
.
det_model_dir
,
model_params
[
'det'
])
self
.
model_storage_directory
,
det_model_name
,
'det'
)
maybe_download
(
postprocess_params
.
rec_model_dir
,
model_params
[
'rec'
])
rec_model_path
,
rec_algorithm
=
maybe_download
(
self
.
model_storage_directory
,
rec_model_name
,
'rec'
)
# update model and post_process params
postprocess_params
.
__dict__
.
update
(
**
kwargs
)
postprocess_params
.
det_model_dir
=
det_model_path
postprocess_params
.
rec_model_dir
=
rec_model_path
if
det_algorithm
is
not
None
:
postprocess_params
.
det_algorithm
=
det_algorithm
if
rec_algorithm
is
not
None
:
postprocess_params
.
rec_algorithm
=
rec_algorithm
if
postprocess_params
.
det_algorithm
not
in
SUPPORT_DET_MODEL
:
if
postprocess_params
.
det_algorithm
not
in
SUPPORT_DET_MODEL
:
logger
.
error
(
'det_algorithm must in {}'
.
format
(
SUPPORT_DET_MODEL
))
logger
.
error
(
'det_algorithm must in {}'
.
format
(
SUPPORT_DET_MODEL
))
...
@@ -229,3 +195,18 @@ class PaddleOCR(predict_system.TextSystem):
...
@@ -229,3 +195,18 @@ class PaddleOCR(predict_system.TextSystem):
img
=
[
img
]
img
=
[
img
]
rec_res
,
elapse
=
self
.
text_recognizer
(
img
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img
)
return
rec_res
return
rec_res
def
main
():
# for com
args
=
parse_args
()
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
if
len
(
image_file_list
)
==
0
:
logger
.
error
(
'no images find in {}'
.
format
(
args
.
image_dir
))
return
ocr_engine
=
PaddleOCR
()
for
img_path
in
image_file_list
:
print
(
img_path
)
result
=
ocr_engine
.
ocr
(
img_path
,
det
=
args
.
det
,
rec
=
args
.
rec
)
for
line
in
result
:
print
(
line
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录