Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Just_Paranoid
CnOCR
提交
5d338906
CnOCR
项目概览
Just_Paranoid
/
CnOCR
与 Fork 源项目一致
Fork自
Cloud IDE / CnOCR
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
CnOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
5d338906
编写于
5月 14, 2022
作者:
B
breezedeus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support onnx models for predictions
上级
2f8a95fe
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
144 additions
and
61 deletions
+144
-61
cnocr/cn_ocr.py
cnocr/cn_ocr.py
+82
-22
cnocr/consts.py
cnocr/consts.py
+10
-6
cnocr/utils.py
cnocr/utils.py
+14
-8
tests/test_cnocr.py
tests/test_cnocr.py
+38
-25
未找到文件。
cnocr/cn_ocr.py
浏览文件 @
5d338906
...
...
@@ -39,9 +39,11 @@ from cnocr.utils import (
load_model_params
,
rescale_img
,
pad_img_seq
,
to_numpy
,
)
from
.data_utils.aug
import
NormalizeAug
from
.line_split
import
line_split
from
.models.ctc
import
CTCPostProcessor
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -62,7 +64,9 @@ class CnOcr(object):
cand_alphabet
:
Optional
[
Union
[
Collection
,
str
]]
=
None
,
context
:
str
=
'cpu'
,
# ['cpu', 'gpu', 'cuda']
model_fp
:
Optional
[
str
]
=
None
,
model_backend
:
str
=
'onnx'
,
# ['pytorch', 'onnx']
root
:
Union
[
str
,
Path
]
=
data_dir
(),
vocab_fp
:
Union
[
str
,
Path
]
=
VOCAB_FP
,
**
kwargs
,
):
"""
...
...
@@ -73,9 +77,11 @@ class CnOcr(object):
cand_alphabet (Optional[Union[Collection, str]]): 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
context (str): 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为 `cpu`
model_fp (Optional[str]): 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件)
model_backend (str): 'pytorch', or 'onnx'。表明预测时是使用 PyTorch 模型,还是使用 ONNX 模型。默认为 `onnx`
root (Union[str, Path]): 模型文件所在的根目录。
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`。
Windows下默认值为 `C:/Users/<username>/AppData/Roaming/cnocr`。
vocab_fp (Union[str, Path]): 字符集合的文件路径,即 `label_cn.txt` 文件路径
**kwargs: 目前未被使用。
Examples:
...
...
@@ -89,6 +95,8 @@ class CnOcr(object):
>>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
"""
model_backend
=
model_backend
.
lower
()
assert
model_backend
in
(
'pytorch'
,
'onnx'
)
if
'name'
in
kwargs
:
logger
.
warning
(
'param `name` is useless and deprecated since version %s'
...
...
@@ -96,22 +104,31 @@ class CnOcr(object):
)
check_model_name
(
model_name
)
check_context
(
context
)
self
.
_model_name
=
model_name
self
.
_model_backend
=
model_backend
if
context
==
'gpu'
:
context
=
'cuda'
self
.
context
=
context
self
.
_model_file_prefix
=
'{}-{}'
.
format
(
self
.
MODEL_FILE_PREFIX
,
model_name
)
model_epoch
=
AVAILABLE_MODELS
.
get
(
model_name
,
[
None
])[
0
]
if
model_epoch
is
not
None
:
self
.
_model_file_prefix
=
'%s-epoch=%03d'
%
(
self
.
_model_file_prefix
,
model_epoch
,
try
:
self
.
_assert_and_prepare_model_files
(
model_fp
,
root
)
except
NotImplementedError
:
logger
.
warning
(
'no available model is found for name %s and backend %s'
%
(
self
.
_model_name
,
self
.
_model_backend
)
)
self
.
_model_backend
=
(
'onnx'
if
self
.
_model_backend
==
'pytorch'
else
'pytorch'
)
logger
.
warning
(
'trying to use name %s and backend %s'
%
(
self
.
_model_name
,
self
.
_model_backend
)
)
self
.
_assert_and_prepare_model_files
(
model_fp
,
root
)
self
.
_
assert_and_prepare_model_files
(
model_fp
,
root
)
self
.
_vocab
,
self
.
_letter2id
=
read_charset
(
VOCAB_FP
)
self
.
_
vocab
,
self
.
_letter2id
=
read_charset
(
vocab_fp
)
self
.
postprocessor
=
CTCPostProcessor
(
vocab
=
self
.
_vocab
)
self
.
_candidates
=
None
self
.
set_cand_alphabet
(
cand_alphabet
)
...
...
@@ -119,6 +136,15 @@ class CnOcr(object):
self
.
_model
=
self
.
_get_model
(
context
)
def
_assert_and_prepare_model_files
(
self
,
model_fp
,
root
):
self
.
_model_file_prefix
=
'{}-{}'
.
format
(
self
.
MODEL_FILE_PREFIX
,
self
.
_model_name
)
model_epoch
=
AVAILABLE_MODELS
.
get
((
self
.
_model_name
,
self
.
_model_backend
),
[
None
])[
0
]
if
model_epoch
is
not
None
:
self
.
_model_file_prefix
=
'%s-epoch=%03d'
%
(
self
.
_model_file_prefix
,
model_epoch
,
)
if
model_fp
is
not
None
and
not
os
.
path
.
isfile
(
model_fp
):
raise
FileNotFoundError
(
'can not find model file %s'
%
model_fp
)
...
...
@@ -128,25 +154,37 @@ class CnOcr(object):
root
=
os
.
path
.
join
(
root
,
MODEL_VERSION
)
self
.
_model_dir
=
os
.
path
.
join
(
root
,
self
.
_model_name
)
fps
=
glob
(
'%s/%s*.ckpt'
%
(
self
.
_model_dir
,
self
.
_model_file_prefix
))
model_ext
=
'ckpt'
if
self
.
_model_backend
==
'pytorch'
else
'onnx'
fps
=
glob
(
'%s/%s*.%s'
%
(
self
.
_model_dir
,
self
.
_model_file_prefix
,
model_ext
))
if
len
(
fps
)
>
1
:
raise
ValueError
(
'multiple
ckpt
files are found in %s, not sure which one should be used'
%
self
.
_model_dir
'multiple
%s
files are found in %s, not sure which one should be used'
%
(
model_ext
,
self
.
_model_dir
)
)
elif
len
(
fps
)
<
1
:
logger
.
warning
(
'no ckpt file is found in %s'
%
self
.
_model_dir
)
get_model_file
(
self
.
_model_dir
)
# download the .zip file and unzip
fps
=
glob
(
'%s/%s*.ckpt'
%
(
self
.
_model_dir
,
self
.
_model_file_prefix
))
logger
.
warning
(
'no %s file is found in %s'
%
(
model_ext
,
self
.
_model_dir
))
get_model_file
(
self
.
_model_name
,
self
.
_model_backend
,
self
.
_model_dir
)
# download the .zip file and unzip
fps
=
glob
(
'%s/%s*.%s'
%
(
self
.
_model_dir
,
self
.
_model_file_prefix
,
model_ext
)
)
self
.
_model_fp
=
fps
[
0
]
def
_get_model
(
self
,
context
):
logger
.
info
(
'use model: %s'
%
self
.
_model_fp
)
model
=
gen_model
(
self
.
_model_name
,
self
.
_vocab
)
model
.
eval
()
model
.
to
(
self
.
context
)
model
=
load_model_params
(
model
,
self
.
_model_fp
,
context
)
if
self
.
_model_backend
==
'pytorch'
:
model
=
gen_model
(
self
.
_model_name
,
self
.
_vocab
)
model
.
eval
()
model
.
to
(
self
.
context
)
model
=
load_model_params
(
model
,
self
.
_model_fp
,
context
)
elif
self
.
_model_backend
==
'onnx'
:
import
onnxruntime
model
=
onnxruntime
.
InferenceSession
(
self
.
_model_fp
)
else
:
raise
NotImplementedError
(
f
'
{
self
.
_model_backend
}
is not supported yet'
)
return
model
...
...
@@ -335,11 +373,33 @@ class CnOcr(object):
img
=
rescale_img
(
img
.
transpose
((
2
,
0
,
1
)))
# res: [C, H, W]
return
NormalizeAug
()(
img
).
to
(
device
=
torch
.
device
(
self
.
context
))
@
torch
.
no_grad
()
def
_predict
(
self
,
img_list
:
List
[
torch
.
Tensor
]):
img_lengths
=
torch
.
tensor
([
img
.
shape
[
2
]
for
img
in
img_list
])
imgs
=
pad_img_seq
(
img_list
)
out
=
self
.
_model
(
imgs
,
img_lengths
,
candidates
=
self
.
_candidates
,
return_preds
=
True
if
self
.
_model_backend
==
'pytorch'
:
with
torch
.
no_grad
():
out
=
self
.
_model
(
imgs
,
img_lengths
,
candidates
=
self
.
_candidates
,
return_preds
=
True
)
else
:
# onnx
out
=
self
.
_onnx_predict
(
imgs
,
img_lengths
)
return
out
def _onnx_predict(self, imgs, img_lengths):
    """Run the ONNX inference session on a batch and decode the result.

    Args:
        imgs: batch of images; converted with ``to_numpy`` and fed to the
            session's first input.
        img_lengths: per-image lengths; converted with ``to_numpy`` and fed
            to the session's second input.

    Returns:
        dict with three entries:
            * ``'logits'``: first session output as a tensor, after
              ``OcrModel.mask_by_candidates`` has been applied with the
              current candidate set.
            * ``'output_lengths'``: second session output as a tensor.
            * ``'preds'``: result of ``self.postprocessor`` on the masked
              logits and the output lengths.
    """
    ort_session = self._model

    # Query the input descriptors once instead of once per input.
    session_inputs = ort_session.get_inputs()
    ort_inputs = {
        session_inputs[0].name: to_numpy(imgs),
        session_inputs[1].name: to_numpy(img_lengths),
    }
    ort_outs = ort_session.run(None, ort_inputs)

    out = {
        'logits': torch.from_numpy(ort_outs[0]),
        'output_lengths': torch.from_numpy(ort_outs[1]),
    }
    # Apply the same candidate masking and CTC post-processing that the
    # PyTorch path requests via `candidates=` / `return_preds=True`.
    out['logits'] = OcrModel.mask_by_candidates(
        out['logits'], self._candidates, self._vocab, self._letter2id
    )
    out["preds"] = self.postprocessor(out['logits'], out['output_lengths'])
    return out
cnocr/consts.py
浏览文件 @
5d338906
...
...
@@ -107,12 +107,16 @@ root_url = (
)
# name: (epoch, url)
AVAILABLE_MODELS
=
{
'densenet_lite_114-fc'
:
(
37
,
root_url
+
'densenet_lite_114-fc.zip'
),
'densenet_lite_124-fc'
:
(
39
,
root_url
+
'densenet_lite_124-fc.zip'
),
'densenet_lite_134-fc'
:
(
34
,
root_url
+
'densenet_lite_134-fc.zip'
),
'densenet_lite_136-fc'
:
(
39
,
root_url
+
'densenet_lite_136-fc.zip'
),
'densenet_lite_134-gru'
:
(
2
,
root_url
+
'densenet_lite_134-gru.zip'
),
'densenet_lite_136-gru'
:
(
2
,
root_url
+
'densenet_lite_136-gru.zip'
),
(
'densenet_lite_114-fc'
,
'pytorch'
):
(
37
,
root_url
+
'densenet_lite_114-fc.zip'
),
(
'densenet_lite_124-fc'
,
'pytorch'
):
(
39
,
root_url
+
'densenet_lite_124-fc.zip'
),
(
'densenet_lite_134-fc'
,
'pytorch'
):
(
34
,
root_url
+
'densenet_lite_134-fc.zip'
),
(
'densenet_lite_136-fc'
,
'pytorch'
):
(
39
,
root_url
+
'densenet_lite_136-fc.zip'
),
(
'densenet_lite_114-fc'
,
'onnx'
):
(
37
,
root_url
+
'densenet_lite_114-fc-onnx.zip'
),
(
'densenet_lite_124-fc'
,
'onnx'
):
(
39
,
root_url
+
'densenet_lite_124-fc-onnx.zip'
),
(
'densenet_lite_134-fc'
,
'onnx'
):
(
34
,
root_url
+
'densenet_lite_134-fc-onnx.zip'
),
(
'densenet_lite_136-fc'
,
'onnx'
):
(
39
,
root_url
+
'densenet_lite_136-fc-onnx.zip'
),
(
'densenet_lite_134-gru'
,
'pytorch'
):
(
2
,
root_url
+
'densenet_lite_134-gru.zip'
),
(
'densenet_lite_136-gru'
,
'pytorch'
):
(
2
,
root_url
+
'densenet_lite_136-gru.zip'
),
}
# 候选字符集合
...
...
cnocr/utils.py
浏览文件 @
5d338906
...
...
@@ -107,6 +107,12 @@ def check_model_name(model_name):
assert
decoder_type
in
DECODER_CONFIGS
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a PyTorch tensor into a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion. ``detach()`` is valid on tensors that do not require
    gradients as well, so there is no need to branch on ``requires_grad``.

    Args:
        tensor: the tensor to convert.

    Returns:
        A NumPy array holding the tensor's values.
    """
    return tensor.detach().cpu().numpy()
def
check_sha1
(
filename
,
sha1_hash
):
"""Check whether the sha1 hash of the file content matches the expected hash.
Parameters
...
...
@@ -202,7 +208,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
return
fname
def
get_model_file
(
model_dir
):
def
get_model_file
(
model_
name
,
model_backend
,
model_
dir
):
r
"""Return location for the downloaded models on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
...
...
@@ -210,6 +216,8 @@ def get_model_file(model_dir):
Parameters
----------
model_name : str
model_backend : str
model_dir : str, default $CNOCR_HOME
Location for keeping the model parameters.
...
...
@@ -222,14 +230,12 @@ def get_model_file(model_dir):
par_dir
=
os
.
path
.
dirname
(
model_dir
)
os
.
makedirs
(
par_dir
,
exist_ok
=
True
)
zip_file_path
=
model_dir
+
'.zip'
if
(
model_name
,
model_backend
)
not
in
AVAILABLE_MODELS
:
raise
NotImplementedError
(
'%s is not a downloadable model'
%
model_name
)
url
=
AVAILABLE_MODELS
[(
model_name
,
model_backend
)][
1
]
zip_file_path
=
os
.
path
.
join
(
par_dir
,
os
.
path
.
basename
(
url
))
if
not
os
.
path
.
exists
(
zip_file_path
):
model_name
=
os
.
path
.
basename
(
model_dir
)
if
model_name
not
in
AVAILABLE_MODELS
:
raise
NotImplementedError
(
'%s is not an available downloaded model'
%
model_name
)
url
=
AVAILABLE_MODELS
[
model_name
][
1
]
download
(
url
,
path
=
zip_file_path
,
overwrite
=
True
)
with
zipfile
.
ZipFile
(
zip_file_path
)
as
zf
:
zf
.
extractall
(
par_dir
)
...
...
tests/test_cnocr.py
浏览文件 @
5d338906
...
...
@@ -19,7 +19,10 @@
import
os
import
sys
import
logging
import
time
import
pytest
import
numpy
as
np
from
PIL
import
Image
import
Levenshtein
...
...
@@ -28,13 +31,15 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys
.
path
.
insert
(
1
,
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
from
cnocr
import
CnOcr
from
cnocr.utils
import
read_img
from
cnocr.utils
import
set_logger
,
read_img
from
cnocr.consts
import
NUMBERS
,
AVAILABLE_MODELS
from
cnocr.line_split
import
line_split
logger
=
set_logger
(
log_level
=
logging
.
INFO
)
root_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
example_dir
=
os
.
path
.
join
(
root_dir
,
'docs/examples'
)
CNOCR
=
CnOcr
(
model_name
=
'densenet
-s
-fc'
,
model_epoch
=
None
)
CNOCR
=
CnOcr
(
model_name
=
'densenet
_lite_136
-fc'
,
model_epoch
=
None
)
SINGLE_LINE_CASES
=
[
(
'20457890_2399557098.jpg'
,
[
'就会哈哈大笑。3.0'
]),
...
...
@@ -110,8 +115,7 @@ def cal_score(preds, expected):
@
pytest
.
mark
.
parametrize
(
'img_fp, expected'
,
CASES
)
def
test_ocr
(
img_fp
,
expected
):
ocr
=
CNOCR
root_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
img_fp
=
os
.
path
.
join
(
root_dir
,
'examples'
,
img_fp
)
img_fp
=
os
.
path
.
join
(
example_dir
,
img_fp
)
pred
=
ocr
.
ocr
(
img_fp
)
print
(
'
\n
'
)
...
...
@@ -132,8 +136,7 @@ def test_ocr(img_fp, expected):
@
pytest
.
mark
.
parametrize
(
'img_fp, expected'
,
SINGLE_LINE_CASES
)
def
test_ocr_for_single_line
(
img_fp
,
expected
):
ocr
=
CNOCR
root_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
img_fp
=
os
.
path
.
join
(
root_dir
,
'examples'
,
img_fp
)
img_fp
=
os
.
path
.
join
(
example_dir
,
img_fp
)
pred
=
ocr
.
ocr_for_single_line
(
img_fp
)
print
(
'
\n
'
)
print_preds
([
pred
])
...
...
@@ -165,8 +168,7 @@ def test_ocr_for_single_line(img_fp, expected):
@
pytest
.
mark
.
parametrize
(
'img_fp, expected'
,
MULTIPLE_LINE_CASES
)
def
test_ocr_for_single_lines
(
img_fp
,
expected
):
ocr
=
CNOCR
root_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)))
img_fp
=
os
.
path
.
join
(
root_dir
,
'examples'
,
img_fp
)
img_fp
=
os
.
path
.
join
(
example_dir
,
img_fp
)
img
=
read_img
(
img_fp
)
if
img
.
mean
()
<
145
:
# 把黑底白字的图片对调为白底黑字
img
=
255
-
img
...
...
@@ -186,26 +188,37 @@ def test_ocr_for_single_lines(img_fp, expected):
def
test_cand_alphabet
():
img_fp
=
os
.
path
.
join
(
example_dir
,
'hybrid.png'
)
ocr
=
CnOcr
(
cand_alphabet
=
NUMBERS
)
pred
=
ocr
.
ocr
(
img_fp
)
pred
=
[
''
.
join
(
line_p
)
for
line_p
,
_
in
pred
]
ocr
=
CnOcr
(
'densenet_lite_136-fc'
,
cand_alphabet
=
NUMBERS
)
p
t_p
red
=
ocr
.
ocr
(
img_fp
)
pred
=
[
''
.
join
(
line_p
)
for
line_p
,
_
in
p
t_p
red
]
print
(
"Predicted Chars:"
,
pred
)
assert
len
(
pred
)
==
1
and
pred
[
0
]
==
'012345678'
ocr
=
CnOcr
(
'densenet_lite_136-fc'
,
model_backend
=
'onnx'
,
cand_alphabet
=
NUMBERS
)
onnx_pred
=
ocr
.
ocr
(
img_fp
)
pred
=
[
''
.
join
(
line_p
)
for
line_p
,
_
in
onnx_pred
]
print
(
"Predicted Chars:"
,
pred
)
assert
len
(
pred
)
==
1
and
pred
[
0
]
==
'012345678'
INSTANCE_ID
=
0
assert
pt_pred
[
0
][
0
]
==
onnx_pred
[
0
][
0
]
assert
abs
(
pt_pred
[
0
][
1
]
-
onnx_pred
[
0
][
1
])
<
1e-5
@
pytest
.
mark
.
parametrize
(
'model_name'
,
AVAILABLE_MODELS
.
keys
())
def
test_multiple_instances
(
model_name
):
global
INSTANCE_ID
print
(
'test multiple instances for model_name: %s'
%
model_name
)
img_fp
=
os
.
path
.
join
(
example_dir
,
'hybrid.png'
)
INSTANCE_ID
+=
1
print
(
'instance id: %d'
%
INSTANCE_ID
)
cnocr1
=
CnOcr
(
model_name
,
name
=
'instance-%d'
%
INSTANCE_ID
)
print_preds
(
cnocr1
.
ocr
(
img_fp
))
INSTANCE_ID
+=
1
print
(
'instance id: %d'
%
INSTANCE_ID
)
cnocr2
=
CnOcr
(
model_name
,
name
=
'instance-%d'
%
INSTANCE_ID
,
cand_alphabet
=
NUMBERS
)
print_preds
(
cnocr2
.
ocr
(
img_fp
))
@pytest.mark.parametrize('img_fp, expected', SINGLE_LINE_CASES)
def test_onnx(img_fp, expected):
    """Check that the PyTorch and ONNX backends agree on a single-line image.

    Both backends are run with the 'densenet_lite_136-fc' model on the same
    image. The first elements of the two results must be equal and the second
    elements must differ by less than 1e-5. The elapsed time of each
    prediction is printed for information only.

    Args:
        img_fp: image file name, resolved relative to ``example_dir``.
        expected: supplied by the parametrization; not used by this test,
            which only compares the two backends with each other.
    """
    img_fp = os.path.join(example_dir, img_fp)

    pt_ocr = CnOcr('densenet_lite_136-fc', model_backend='pytorch')
    # perf_counter is the clock intended for measuring short durations;
    # time.time() is wall-clock time and may jump if the system clock changes.
    start_time = time.perf_counter()
    pt_preds = pt_ocr.ocr_for_single_line(img_fp)
    end_time = time.perf_counter()
    print(f'\npytorch time cost {end_time - start_time}', pt_preds)

    onnx_ocr = CnOcr('densenet_lite_136-fc', model_backend='onnx')
    start_time = time.perf_counter()
    onnx_preds = onnx_ocr.ocr_for_single_line(img_fp)
    end_time = time.perf_counter()
    print(f'onnx time cost {end_time - start_time}', onnx_preds, '\n\n')

    assert pt_preds[0] == onnx_preds[0]
    assert abs(pt_preds[1] - onnx_preds[1]) < 1e-5
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录