Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Just_Paranoid
CnOCR
提交
2f8a95fe
CnOCR
项目概览
Just_Paranoid
/
CnOCR
与 Fork 源项目一致
Fork自
Cloud IDE / CnOCR
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
CnOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
2f8a95fe
编写于
5月 14, 2022
作者:
B
breezedeus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor for onnx exportation
上级
f9f74d81
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
22 addition
and
11 deletion
+22
-11
cnocr/models/ocr_model.py
cnocr/models/ocr_model.py
+22
-11
未找到文件。
cnocr/models/ocr_model.py
浏览文件 @
2f8a95fe
...
...
@@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.
# Credits: adapted from https://github.com/mindee/doctr
from
collections
import
OrderedDict
from
typing
import
Tuple
,
Dict
,
Any
,
Optional
,
List
,
Union
from
copy
import
deepcopy
...
...
@@ -177,7 +177,7 @@ class OcrModel(nn.Module):
input_lengths
:
torch
.
Tensor
,
target
:
Optional
[
List
[
str
]]
=
None
,
candidates
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
return_logits
:
bool
=
Fals
e
,
return_logits
:
bool
=
Tru
e
,
return_preds
:
bool
=
False
,
)
->
Dict
[
str
,
Any
]:
"""
...
...
@@ -191,7 +191,9 @@ class OcrModel(nn.Module):
:return: 预测结果
"""
features
=
self
.
encoder
(
x
)
input_lengths
=
input_lengths
//
self
.
encoder
.
compress_ratio
input_lengths
=
torch
.
div
(
input_lengths
,
self
.
encoder
.
compress_ratio
,
rounding_mode
=
'floor'
)
# B x C x H x W --> B x C*H x W --> B x W x C*H
c
,
h
,
w
=
features
.
shape
[
1
],
features
.
shape
[
2
],
features
.
shape
[
3
]
features_seq
=
torch
.
reshape
(
features
,
shape
=
(
-
1
,
h
*
c
,
w
))
...
...
@@ -200,20 +202,24 @@ class OcrModel(nn.Module):
logits
=
self
.
_decode
(
features_seq
,
input_lengths
)
logits
=
self
.
linear
(
logits
)
logits
=
self
.
_mask_by_candidates
(
logits
,
candidates
)
logits
=
self
.
mask_by_candidates
(
logits
,
candidates
,
self
.
vocab
,
self
.
letter2id
)
out
:
Dict
[
str
,
Any
]
=
{}
out
:
Ordered
Dict
[
str
,
Any
]
=
{}
if
return_logits
:
out
[
"logits"
]
=
logits
out
[
'output_lengths'
]
=
input_lengths
if
target
is
None
or
return_preds
:
# Post-process boxes
out
[
"preds"
]
=
self
.
postprocessor
(
logits
,
input_lengths
)
if
self
.
postprocessor
is
not
None
:
out
[
"preds"
]
=
self
.
postprocessor
(
logits
,
input_lengths
)
if
target
is
not
None
:
out
[
'loss'
]
=
self
.
_compute_loss
(
logits
,
target
,
input_lengths
)
return
out
return
dict
(
out
)
def
_decode
(
self
,
features_seq
,
input_lengths
):
if
not
isinstance
(
self
.
decoder
,
(
nn
.
LSTM
,
nn
.
GRU
)):
...
...
@@ -232,18 +238,23 @@ class OcrModel(nn.Module):
)
return
logits
def
_mask_by_candidates
(
self
,
logits
:
torch
.
Tensor
,
candidates
:
Optional
[
Union
[
str
,
List
[
str
]]]
@
classmethod
def
mask_by_candidates
(
cls
,
logits
:
torch
.
Tensor
,
candidates
:
Optional
[
Union
[
str
,
List
[
str
]]],
vocab
:
List
[
str
],
letter2id
:
Dict
[
str
,
int
],
):
if
candidates
is
None
:
return
logits
_candidates
=
[
self
.
letter2id
[
word
]
for
word
in
candidates
]
_candidates
=
[
letter2id
[
word
]
for
word
in
candidates
]
_candidates
.
sort
()
_candidates
=
torch
.
tensor
(
_candidates
,
dtype
=
torch
.
int64
)
candidates
=
torch
.
zeros
(
(
len
(
self
.
vocab
)
+
1
,),
dtype
=
torch
.
bool
,
device
=
logits
.
device
(
len
(
vocab
)
+
1
,),
dtype
=
torch
.
bool
,
device
=
logits
.
device
)
candidates
[
_candidates
]
=
True
candidates
[
-
1
]
=
True
# 间隔符号/填充符号,必须为真
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录