Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
yinnxinn
chineseocr
提交
9e6f83cb
C
chineseocr
项目概览
yinnxinn
/
chineseocr
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
chineseocr
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
9e6f83cb
编写于
7月 18, 2019
作者:
L
lywen
提交者:
GitHub
7月 18, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #302 from wsxqyws/dev_lstm
支持 pytorch lstm层转换为keras lstm层
上级
a017ba21
897ca5b2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
84 addition
and
16 deletion
+84
-16
crnn/crnn_keras.py
crnn/crnn_keras.py
+2
-4
crnn/network_keras.py
crnn/network_keras.py
+13
-3
tools/pytorch_to_keras.py
tools/pytorch_to_keras.py
+69
-9
未找到文件。
crnn/crnn_keras.py
浏览文件 @
9e6f83cb
...
...
@@ -2,6 +2,7 @@
from
crnn.utils
import
strLabelConverter
,
resizeNormalize
from
crnn.network_keras
import
keras_crnn
as
CRNN
from
config
import
LSTMFLAG
import
tensorflow
as
tf
graph
=
tf
.
get_default_graph
()
##解决web.py 相关报错问题
...
...
@@ -11,7 +12,7 @@ import numpy as np
def
crnnSource
():
alphabet
=
keys
.
alphabetChinese
##中英文模型
converter
=
strLabelConverter
(
alphabet
)
model
=
CRNN
(
32
,
1
,
len
(
alphabet
)
+
1
,
256
,
1
,
lstmFlag
=
False
)
model
=
CRNN
(
32
,
1
,
len
(
alphabet
)
+
1
,
256
,
1
,
lstmFlag
=
LSTMFLAG
)
model
.
load_weights
(
ocrModelKeras
)
return
model
,
converter
...
...
@@ -37,6 +38,3 @@ def crnnOcr(image):
preds
=
np
.
argmax
(
preds
,
axis
=
2
).
reshape
((
-
1
,))
sim_pred
=
converter
.
decode
(
preds
)
return
sim_pred
crnn/network_keras.py
浏览文件 @
9e6f83cb
from
keras.layers
import
Conv2D
,
BatchNormalization
,
MaxPool2D
,
Input
,
Permute
,
Reshape
,
Dense
,
LeakyReLU
,
Activation
from
keras.layers
import
(
Conv2D
,
BatchNormalization
,
MaxPool2D
,
Input
,
Permute
,
Reshape
,
Dense
,
LeakyReLU
,
Activation
,
Bidirectional
,
LSTM
,
TimeDistributed
)
from
keras.models
import
Model
from
keras.layers
import
ZeroPadding2D
from
keras.activations
import
relu
...
...
@@ -68,7 +68,17 @@ def keras_crnn(imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False,lstmFlag=True):
x
=
Permute
((
3
,
2
,
1
))(
x
)
x
=
Reshape
((
-
1
,
512
))(
x
)
out
=
Dense
(
nclass
,
name
=
'linear'
)(
x
)
out
=
None
if
lstmFlag
:
x
=
Bidirectional
(
LSTM
(
nh
,
return_sequences
=
True
,
use_bias
=
True
,
recurrent_activation
=
'sigmoid'
))(
x
)
x
=
TimeDistributed
(
Dense
(
nh
))(
x
)
x
=
Bidirectional
(
LSTM
(
nh
,
return_sequences
=
True
,
use_bias
=
True
,
recurrent_activation
=
'sigmoid'
))(
x
)
out
=
TimeDistributed
(
Dense
(
nclass
))(
x
)
else
:
out
=
Dense
(
nclass
,
name
=
'linear'
)(
x
)
out
=
Reshape
((
-
1
,
1
,
nclass
),
name
=
'out'
)(
out
)
return
Model
(
imgInput
,
out
)
\ No newline at end of file
return
Model
(
imgInput
,
out
)
tools/pytorch_to_keras.py
浏览文件 @
9e6f83cb
...
...
@@ -12,6 +12,8 @@ def parser():
parser
=
argparse
.
ArgumentParser
(
description
=
"pytorch dense ocr to keras ocr"
)
parser
.
add_argument
(
'-weights_path'
,
help
=
'models/ocr-dense.pth'
)
parser
.
add_argument
(
'-output_path'
,
help
=
'models/ocr-dense-keras.h5'
)
parser
.
add_argument
(
'-lstm'
,
default
=
False
,
action
=
'store_true'
,
help
=
'translate lstm layer'
)
return
parser
.
parse_args
()
def
set_cnn_weight
(
name
,
keramodel
,
torchmodelDict
):
...
...
@@ -68,7 +70,62 @@ def set_dense_weight(name,keramodel,torchmodelDict):
if
weight
is
not
None
and
bias
is
not
None
:
weight
=
np
.
transpose
(
weight
)
keramodel
.
get_layer
(
name
).
set_weights
([
weight
,
bias
])
def set_lstm_weight(name, kerasmodel, torchmodelDict):
    """Copy one PyTorch bidirectional-LSTM block into the matching Keras layers.

    Every key of ``torchmodelDict`` that contains ``name`` is inspected and the
    eight LSTM parameters (forward and reverse direction) plus the two
    ``embedding`` (linear projection) parameters are picked out by key suffix.

    Args:
        name: Prefix of the PyTorch block, e.g. ``'rnn.0'`` or ``'rnn.1'``.
            ``'rnn.0'`` is written to ``bidirectional_1`` /
            ``time_distributed_1``; any other value is written to
            ``bidirectional_2`` / ``time_distributed_2``.
        kerasmodel: Keras model exposing ``get_layer(layer_name)`` whose
            result exposes ``set_weights(list)``.
        torchmodelDict: Mapping from PyTorch parameter names to tensors
            (NumPy arrays are accepted as well).

    Raises:
        KeyError: If any of the ten expected parameters for ``name`` is
            absent from ``torchmodelDict``.
    """
    rnn_suffixes = (
        'rnn.weight_ih_l0',
        'rnn.weight_hh_l0',
        'rnn.bias_ih_l0',
        'rnn.bias_hh_l0',
        'rnn.weight_ih_l0_reverse',
        'rnn.weight_hh_l0_reverse',
        'rnn.bias_ih_l0_reverse',
        'rnn.bias_hh_l0_reverse',
    )
    # Linear projection that follows the LSTM (TimeDistributed(Dense) in Keras).
    linear_suffixes = ('embedding.weight', 'embedding.bias')
    wanted = rnn_suffixes + linear_suffixes

    def _as_numpy(value):
        # Torch tensors are converted explicitly; anything else is handed to
        # NumPy as-is so plain arrays also work.
        if hasattr(value, 'detach'):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    # No suffix in `wanted` is a suffix of another, so at most one can match a
    # given key and the lookup order does not matter.
    found = {}
    for key in torchmodelDict:
        if name not in key:
            continue
        for suffix in wanted:
            if key.endswith(suffix):
                found[suffix] = _as_numpy(torchmodelDict[key])
                break

    missing = [suffix for suffix in wanted if suffix not in found]
    if missing:
        # Fail with the parameter names instead of an AttributeError on None.
        raise KeyError(
            "state dict has no '{}' parameters ending with: {}".format(
                name, ', '.join(missing)))

    def _direction(tag):
        # One direction as [input kernel, recurrent kernel, bias]. The kernels
        # are transposed and the two PyTorch bias vectors are summed into a
        # single bias.
        # NOTE(review): presumably Keras keeps only one bias per LSTM -- confirm.
        return [
            found['rnn.weight_ih_l0' + tag].transpose(1, 0),
            found['rnn.weight_hh_l0' + tag].transpose(1, 0),
            found['rnn.bias_ih_l0' + tag] + found['rnn.bias_hh_l0' + tag],
        ]

    rnn_weights = _direction('') + _direction('_reverse')
    linear_weights = [
        found['embedding.weight'].transpose(1, 0),
        found['embedding.bias'],
    ]

    # NOTE(review): these look like auto-generated Keras layer names; they
    # presumably only match when the model is the first one built in the
    # process -- confirm against keras_crnn.
    layer_index = '1' if name == 'rnn.0' else '2'
    kerasmodel.get_layer('bidirectional_' + layer_index).set_weights(rnn_weights)
    kerasmodel.get_layer('time_distributed_' + layer_index).set_weights(linear_weights)
if
__name__
==
'__main__'
:
import
os
import
sys
...
...
@@ -81,10 +138,11 @@ if __name__=='__main__':
from
collections
import
OrderedDict
from
crnn.keys
import
alphabetChinese
from
crnn.network_keras
import
keras_crnn
##ocrModel='models/ocr-dense.pth'##目前只支持 dense ocr
##ocrModel='models/ocr-dense.pth' #dense ocr
##ocrModel='models/ocr-lstm.pth' #lstm ocr
ocrModel
=
args
.
weights_path
##torch模型权重
output_path
=
args
.
output_path
##keras 模型权重输出
kerasModel
=
keras_crnn
(
32
,
1
,
len
(
alphabetChinese
)
+
1
,
256
,
1
,
lstmFlag
=
False
)
kerasModel
=
keras_crnn
(
32
,
1
,
len
(
alphabetChinese
)
+
1
,
256
,
1
,
lstmFlag
=
args
.
lstm
)
state_dict
=
torch
.
load
(
ocrModel
,
map_location
=
lambda
storage
,
loc
:
storage
)
new_state_dict
=
OrderedDict
()
...
...
@@ -96,17 +154,19 @@ if __name__=='__main__':
cnn
=
[
'cnn.conv0'
,
'cnn.conv1'
,
'cnn.conv2'
,
'cnn.conv3'
,
'cnn.conv4'
,
'cnn.conv5'
,
'cnn.conv6'
]
BN
=
[
'cnn.batchnorm2'
,
'cnn.batchnorm4'
,
'cnn.batchnorm6'
]
linear
=
[
'linear'
]
lstm
=
[
'rnn.0'
,
'rnn.1'
]
##CNN 层
for
cn
in
cnn
:
set_cnn_weight
(
cn
,
kerasModel
,
new_state_dict
)
##BN 层
for
bn
in
BN
:
set_bn_weight
(
bn
,
kerasModel
,
new_state_dict
)
## linear 层
for
lr
in
linear
:
set_dense_weight
(
lr
,
kerasModel
,
new_state_dict
)
if
args
.
lstm
:
for
l
in
lstm
:
set_lstm_weight
(
l
,
kerasModel
,
new_state_dict
)
else
:
## linear 层
for
lr
in
linear
:
set_dense_weight
(
lr
,
kerasModel
,
new_state_dict
)
kerasModel
.
save_weights
(
output_path
)
##保存keras权重
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录