Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Just_Paranoid
CnOCR
提交
b360e4da
CnOCR
项目概览
Just_Paranoid
/
CnOCR
与 Fork 源项目一致
Fork自
Cloud IDE / CnOCR
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
CnOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b360e4da
编写于
8月 10, 2021
作者:
B
breezedeus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support for input candidates
上级
b08796f7
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
77 addition
and
19 deletion
+77
-19
cnocr/cn_ocr.py
cnocr/cn_ocr.py
+43
-16
cnocr/models/ocr_model.py
cnocr/models/ocr_model.py
+34
-3
未找到文件。
cnocr/cn_ocr.py
浏览文件 @
b360e4da
...
...
@@ -31,8 +31,11 @@ from cnocr.utils import (
get_model_file
,
read_charset
,
check_model_name
,
check_context
,
read_img
,
load_model_params
,
rescale_img
,
pad_img_seq
,
check_context
,
read_img
,
load_model_params
,
rescale_img
,
pad_img_seq
,
)
from
.data_utils.aug
import
NormalizeAug
from
.line_split
import
line_split
...
...
@@ -43,7 +46,9 @@ logger = logging.getLogger(__name__)
def
gen_model
(
model_name
,
vocab
):
check_model_name
(
model_name
)
if
not
model_name
.
startswith
(
'densenet-s'
):
logger
.
warning
(
'only "densenet-s" is supported now, use "densenet-s-fc" by default'
)
logger
.
warning
(
'only "densenet-s" is supported now, use "densenet-s-fc" by default'
)
model_name
=
'densenet-s-fc'
model
=
OcrModel
.
from_name
(
model_name
,
vocab
)
return
model
...
...
@@ -81,11 +86,9 @@ class CnOcr(object):
root
=
os
.
path
.
join
(
root
,
MODEL_VERSION
)
self
.
_model_dir
=
os
.
path
.
join
(
root
,
self
.
_model_name
)
# self._assert_and_prepare_model_files()
self
.
_vocab
,
self
.
_inv_alph_dict
=
read_charset
(
VOCAB_FP
)
self
.
_vocab
,
self
.
_letter2id
=
read_charset
(
VOCAB_FP
)
self
.
_cand
_alph_idx
=
None
self
.
_cand
idates
=
None
self
.
set_cand_alphabet
(
cand_alphabet
)
self
.
context
=
context
...
...
@@ -111,9 +114,13 @@ class CnOcr(object):
def
_get_module
(
self
,
context
):
from
glob
import
glob
fps
=
glob
(
'%s/%s*.ckpt'
%
(
self
.
_model_dir
,
self
.
_model_file_prefix
))
if
len
(
fps
)
>
1
:
raise
ValueError
(
'multiple ckpt files are found in %s, not sure which one should be used'
%
self
.
_model_dir
)
raise
ValueError
(
'multiple ckpt files are found in %s, not sure which one should be used'
%
self
.
_model_dir
)
elif
len
(
fps
)
<
1
:
raise
FileNotFoundError
(
'no ckpt file is found in %s'
%
self
.
_model_dir
)
...
...
@@ -131,12 +138,24 @@ class CnOcr(object):
:return: None
"""
if
cand_alphabet
is
None
:
self
.
_cand
_alph_idx
=
None
self
.
_cand
idates
=
None
else
:
self
.
_cand_alph_idx
=
[
self
.
_inv_alph_dict
[
word
]
for
word
in
cand_alphabet
]
self
.
_cand_alph_idx
.
sort
()
cand_alphabet
=
[
word
if
word
!=
' '
else
'<space>'
for
word
in
cand_alphabet
]
excluded
=
set
(
[
word
for
word
in
cand_alphabet
if
word
not
in
self
.
_letter2id
]
)
if
excluded
:
logger
.
warning
(
'chars in candidates are not in the vocab, ignoring them: %s'
%
excluded
)
candidates
=
[
word
for
word
in
cand_alphabet
if
word
in
self
.
_letter2id
]
self
.
_candidates
=
None
if
len
(
candidates
)
==
0
else
candidates
logger
.
info
(
'candidate chars: %s'
%
self
.
_candidates
)
def
ocr
(
self
,
img_fp
:
Union
[
str
,
Path
,
torch
.
Tensor
,
np
.
ndarray
])
->
List
[
Tuple
[
List
[
str
],
float
]]:
def
ocr
(
self
,
img_fp
:
Union
[
str
,
Path
,
torch
.
Tensor
,
np
.
ndarray
]
)
->
List
[
Tuple
[
List
[
str
],
float
]]:
"""
:param img_fp: image file path; or color image mx.nd.NDArray or np.ndarray,
with shape (height, width, 3), and the channels should be RGB formatted.
...
...
@@ -162,7 +181,9 @@ class CnOcr(object):
line_chars_list
=
self
.
ocr_for_single_lines
(
line_img_list
)
return
line_chars_list
def
ocr_for_single_line
(
self
,
img_fp
:
Union
[
str
,
Path
,
torch
.
Tensor
,
np
.
ndarray
])
->
Tuple
[
List
[
str
],
float
]:
def
ocr_for_single_line
(
self
,
img_fp
:
Union
[
str
,
Path
,
torch
.
Tensor
,
np
.
ndarray
]
)
->
Tuple
[
List
[
str
],
float
]:
"""
Recognize characters from an image with only one-line characters.
:param img_fp: image file path; or image mx.nd.NDArray or np.ndarray,
...
...
@@ -181,7 +202,9 @@ class CnOcr(object):
res
=
self
.
ocr_for_single_lines
([
img
])
return
res
[
0
]
def
ocr_for_single_lines
(
self
,
img_list
:
List
[
Union
[
torch
.
Tensor
,
np
.
ndarray
]])
->
List
[
Tuple
[
List
[
str
],
float
]]:
def
ocr_for_single_lines
(
self
,
img_list
:
List
[
Union
[
torch
.
Tensor
,
np
.
ndarray
]]
)
->
List
[
Tuple
[
List
[
str
],
float
]]:
"""
Batch recognize characters from a list of one-line-characters images.
:param img_list: list of images, in which each element should be a line image array,
...
...
@@ -206,7 +229,9 @@ class CnOcr(object):
return
res
def
_preprocess_img_array
(
self
,
img
:
Union
[
torch
.
Tensor
,
np
.
ndarray
])
->
torch
.
Tensor
:
def
_preprocess_img_array
(
self
,
img
:
Union
[
torch
.
Tensor
,
np
.
ndarray
]
)
->
torch
.
Tensor
:
"""
:param img: image array with type torch.Tensor or np.ndarray,
with shape [height, width] or [channel, height, width].
...
...
@@ -228,5 +253,7 @@ class CnOcr(object):
img_lengths
=
torch
.
tensor
([
img
.
shape
[
2
]
for
img
in
img_list
])
imgs
=
pad_img_seq
(
img_list
)
with
torch
.
no_grad
():
out
=
self
.
_mod
(
imgs
,
img_lengths
,
return_preds
=
True
)
out
=
self
.
_mod
(
imgs
,
img_lengths
,
candidates
=
self
.
_candidates
,
return_preds
=
True
)
return
out
cnocr/models/ocr_model.py
浏览文件 @
b360e4da
# coding: utf-8
from
typing
import
Tuple
,
Dict
,
Any
,
Optional
,
List
from
typing
import
Tuple
,
Dict
,
Any
,
Optional
,
List
,
Union
from
copy
import
deepcopy
import
numpy
as
np
...
...
@@ -147,9 +147,20 @@ class OcrModel(nn.Module):
x
:
torch
.
Tensor
,
input_lengths
:
torch
.
Tensor
,
target
:
Optional
[
List
[
str
]]
=
None
,
return_model_output
:
bool
=
False
,
candidates
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
return_logits
:
bool
=
False
,
return_preds
:
bool
=
False
,
)
->
Dict
[
str
,
Any
]:
"""
:param x: [B, 1, H, W]; 一组padding后的图片
:param input_lengths: shape: [B];每张图片padding前的真实长度(宽度)
:param target: 真实的字符串
:param candidates: None or candidate strs; 允许的候选字符集合
:param return_logits: 是否返回预测的logits值
:param return_preds: 是否返回预测的字符串
:return: 预测结果
"""
features
=
self
.
encoder
(
x
)
input_lengths
=
input_lengths
//
self
.
encoder
.
compress_ratio
# B x C x H x W --> B x C*H x W --> B x W x C*H
...
...
@@ -160,9 +171,10 @@ class OcrModel(nn.Module):
logits
=
self
.
_decode
(
features_seq
,
input_lengths
)
logits
=
self
.
linear
(
logits
)
logits
=
self
.
_mask_by_candidates
(
logits
,
candidates
)
out
:
Dict
[
str
,
Any
]
=
{}
if
return_
model_output
:
if
return_
logits
:
out
[
"logits"
]
=
logits
if
target
is
None
or
return_preds
:
...
...
@@ -191,6 +203,25 @@ class OcrModel(nn.Module):
)
return
logits
def
_mask_by_candidates
(
self
,
logits
:
torch
.
Tensor
,
candidates
:
Optional
[
Union
[
str
,
List
[
str
]]]
):
if
candidates
is
None
:
return
logits
_candidates
=
[
self
.
letter2id
[
word
]
for
word
in
candidates
]
_candidates
.
sort
()
_candidates
=
torch
.
tensor
(
_candidates
,
dtype
=
torch
.
int64
)
candidates
=
torch
.
zeros
(
(
len
(
self
.
vocab
)
+
1
,),
dtype
=
torch
.
bool
,
device
=
logits
.
device
)
candidates
[
_candidates
]
=
True
candidates
[
-
1
]
=
True
# 间隔符号/填充符号,必须为真
candidates
=
candidates
.
unsqueeze
(
0
).
unsqueeze
(
0
)
# 1 x 1 x (vocab_size+1)
logits
.
masked_fill_
(
~
candidates
,
-
100.0
)
return
logits
def
_compute_loss
(
self
,
model_output
:
torch
.
Tensor
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录