Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
yinnxinn
chineseocr
提交
60383ca1
C
chineseocr
项目概览
yinnxinn
/
chineseocr
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
chineseocr
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
60383ca1
编写于
10月 09, 2018
作者:
L
lywen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
新增 dense ocr
上级
f449fe61
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
49 addition
and
38 deletion
+49
-38
crnn/crnn.py
crnn/crnn.py
+10
-8
crnn/keys.py
crnn/keys.py
+2
-1
crnn/models/crnn.py
crnn/models/crnn.py
+37
-29
未找到文件。
crnn/crnn.py
浏览文件 @
60383ca1
...
...
@@ -8,19 +8,21 @@ from crnn import util
from
crnn
import
dataset
from
crnn.models
import
crnn
as
crnn
from
crnn
import
keys
#from conf import crnnModelPath
#from conf import GPU
GPU
=
False
from
collections
import
OrderedDict
from
config
import
ocrModel
from
config
import
ocrModel
,
LSTMFLAG
,
GPU
from
config
import
chinsesModel
def
crnnSource
():
alphabet
=
keys
.
alphabet
if
chinsesModel
:
alphabet
=
keys
.
alphabetChinese
else
:
alphabet
=
keys
.
alphabetEnglish
converter
=
util
.
strLabelConverter
(
alphabet
)
if
torch
.
cuda
.
is_available
()
and
GPU
:
model
=
crnn
.
CRNN
(
32
,
1
,
len
(
alphabet
)
+
1
,
256
,
1
).
cuda
()
model
=
crnn
.
CRNN
(
32
,
1
,
len
(
alphabet
)
+
1
,
256
,
1
,
lstmFlag
=
LSTMFLAG
).
cuda
()
##LSTMFLAG=True crnn 否则 dense ocr
else
:
model
=
crnn
.
CRNN
(
32
,
1
,
len
(
alphabet
)
+
1
,
256
,
1
).
cpu
()
model
=
crnn
.
CRNN
(
32
,
1
,
len
(
alphabet
)
+
1
,
256
,
1
,
lstmFlag
=
LSTMFLAG
).
cpu
()
state_dict
=
torch
.
load
(
ocrModel
,
map_location
=
lambda
storage
,
loc
:
storage
)
new_state_dict
=
OrderedDict
()
for
k
,
v
in
state_dict
.
items
():
...
...
crnn/keys.py
浏览文件 @
60383ca1
此差异已折叠。
点击以展开。
crnn/models/crnn.py
浏览文件 @
60383ca1
import
torch.nn
as
nn
from
.
import
utils
class
BidirectionalLSTM
(
nn
.
Module
):
def
__init__
(
self
,
nIn
,
nHidden
,
nOut
,
ngpu
):
def
__init__
(
self
,
nIn
,
nHidden
,
nOut
):
super
(
BidirectionalLSTM
,
self
).
__init__
()
self
.
ngpu
=
ngpu
self
.
rnn
=
nn
.
LSTM
(
nIn
,
nHidden
,
bidirectional
=
True
)
self
.
embedding
=
nn
.
Linear
(
nHidden
*
2
,
nOut
)
def
forward
(
self
,
input
):
recurrent
,
_
=
utils
.
data_parallel
(
self
.
rnn
,
input
,
self
.
ngpu
)
# [T, b, h * 2]
recurrent
,
_
=
self
.
rnn
(
input
)
T
,
b
,
h
=
recurrent
.
size
()
t_rec
=
recurrent
.
view
(
T
*
b
,
h
)
output
=
utils
.
data_parallel
(
self
.
embedding
,
t_rec
,
self
.
ngpu
)
# [T * b, nOut]
output
=
self
.
embedding
(
t_rec
)
# [T * b, nOut]
output
=
output
.
view
(
T
,
b
,
-
1
)
return
output
class
CRNN
(
nn
.
Module
):
def
__init__
(
self
,
imgH
,
nc
,
nclass
,
nh
,
ngpu
,
n_rnn
=
2
,
leakyRelu
=
False
):
def
__init__
(
self
,
imgH
,
nc
,
nclass
,
nh
,
n_rnn
=
2
,
leakyRelu
=
False
,
lstmFlag
=
True
):
"""
是否加入lstm特征层
"""
super
(
CRNN
,
self
).
__init__
()
self
.
ngpu
=
ngpu
assert
imgH
%
16
==
0
,
'imgH has to be a multiple of 16'
ks
=
[
3
,
3
,
3
,
3
,
3
,
3
,
2
]
ps
=
[
1
,
1
,
1
,
1
,
1
,
1
,
0
]
ss
=
[
1
,
1
,
1
,
1
,
1
,
1
,
1
]
nm
=
[
64
,
128
,
256
,
256
,
512
,
512
,
512
]
self
.
lstmFlag
=
lstmFlag
cnn
=
nn
.
Sequential
()
...
...
@@ -57,31 +55,41 @@ class CRNN(nn.Module):
cnn
.
add_module
(
'pooling{0}'
.
format
(
1
),
nn
.
MaxPool2d
(
2
,
2
))
# 128x8x32
convRelu
(
2
,
True
)
convRelu
(
3
)
cnn
.
add_module
(
'pooling{0}'
.
format
(
2
),
nn
.
MaxPool2d
((
2
,
2
),
(
2
,
1
),
(
0
,
1
)))
# 256x4x16
cnn
.
add_module
(
'pooling{0}'
.
format
(
2
),
nn
.
MaxPool2d
((
2
,
2
),
(
2
,
1
),
(
0
,
1
)))
# 256x4x16
convRelu
(
4
,
True
)
convRelu
(
5
)
cnn
.
add_module
(
'pooling{0}'
.
format
(
3
),
nn
.
MaxPool2d
((
2
,
2
),
(
2
,
1
),
(
0
,
1
)))
# 512x2x16
cnn
.
add_module
(
'pooling{0}'
.
format
(
3
),
nn
.
MaxPool2d
((
2
,
2
),
(
2
,
1
),
(
0
,
1
)))
# 512x2x16
convRelu
(
6
,
True
)
# 512x1x16
self
.
cnn
=
cnn
self
.
rnn
=
nn
.
Sequential
(
BidirectionalLSTM
(
512
,
nh
,
nh
,
ngpu
),
BidirectionalLSTM
(
nh
,
nh
,
nclass
,
ngpu
)
)
if
self
.
lstmFlag
:
self
.
rnn
=
nn
.
Sequential
(
BidirectionalLSTM
(
512
,
nh
,
nh
),
BidirectionalLSTM
(
nh
,
nh
,
nclass
))
else
:
self
.
linear
=
nn
.
Linear
(
nh
*
2
,
nclass
)
def
forward
(
self
,
input
):
# conv features
conv
=
utils
.
data_parallel
(
self
.
cnn
,
input
,
self
.
ngpu
)
conv
=
self
.
cnn
(
input
)
b
,
c
,
h
,
w
=
conv
.
size
()
assert
h
==
1
,
"the height of conv must be 1"
conv
=
conv
.
squeeze
(
2
)
conv
=
conv
.
permute
(
2
,
0
,
1
)
# [w, b, c]
# rnn features
output
=
utils
.
data_parallel
(
self
.
rnn
,
conv
,
self
.
ngpu
)
if
self
.
lstmFlag
:
# rnn features
output
=
self
.
rnn
(
conv
)
else
:
T
,
b
,
h
=
conv
.
size
()
t_rec
=
conv
.
contiguous
().
view
(
T
*
b
,
h
)
output
=
self
.
linear
(
t_rec
)
# [T * b, nOut]
output
=
output
.
view
(
T
,
b
,
-
1
)
return
output
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录