Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
GoAI
attention_ocr.pytorch
提交
f31b54e2
A
attention_ocr.pytorch
项目概览
GoAI
/
attention_ocr.pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
attention_ocr.pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f31b54e2
编写于
10月 15, 2017
作者:
X
xiaohang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
the first attention version, which seems works but very slow
上级
11ae8de2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
150 addition
and
20 deletion
+150
-20
data/create_mnt_list.py
data/create_mnt_list.py
+2
-0
main.py
main.py
+13
-16
models/crnn.py
models/crnn.py
+64
-3
utils.py
utils.py
+71
-1
未找到文件。
data/create_mnt_list.py
浏览文件 @
f31b54e2
...
...
@@ -5,6 +5,7 @@ train_fp = open('data/train_list.txt', 'w')
for
line
in
lines
:
imgpath
=
line
.
strip
().
split
(
' '
)[
0
]
label
=
imgpath
.
split
(
'/'
)[
-
1
].
split
(
'_'
)[
1
].
lower
()
label
=
label
+
'$'
label
=
':'
.
join
(
label
)
imgpath
=
'data/mnt/ramdisk/max/90kDICT32px/%s'
%
imgpath
output
=
' '
.
join
([
imgpath
,
label
])
...
...
@@ -20,6 +21,7 @@ test_fp = open('data/test_list.txt', 'w')
for
line
in
lines
:
imgpath
=
line
.
strip
().
split
(
' '
)[
0
]
label
=
imgpath
.
split
(
'/'
)[
-
1
].
split
(
'_'
)[
1
].
lower
()
label
=
label
+
'$'
label
=
':'
.
join
(
label
)
imgpath
=
'data/mnt/ramdisk/max/90kDICT32px/%s'
%
imgpath
output
=
' '
.
join
([
imgpath
,
label
])
...
...
main.py
浏览文件 @
f31b54e2
...
...
@@ -29,7 +29,7 @@ parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. de
parser
.
add_argument
(
'--cuda'
,
action
=
'store_true'
,
help
=
'enables cuda'
)
parser
.
add_argument
(
'--ngpu'
,
type
=
int
,
default
=
1
,
help
=
'number of GPUs to use'
)
parser
.
add_argument
(
'--crnn'
,
default
=
''
,
help
=
"path to crnn (to continue training)"
)
parser
.
add_argument
(
'--alphabet'
,
type
=
str
,
default
=
'0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z'
)
parser
.
add_argument
(
'--alphabet'
,
type
=
str
,
default
=
'0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z
:$
'
)
parser
.
add_argument
(
'--sep'
,
type
=
str
,
default
=
':'
)
parser
.
add_argument
(
'--experiment'
,
default
=
None
,
help
=
'Where to store samples and models'
)
parser
.
add_argument
(
'--displayInterval'
,
type
=
int
,
default
=
500
,
help
=
'Interval to be displayed'
)
...
...
@@ -75,8 +75,8 @@ test_dataset = dataset.listDataset(list_file =opt.vallist, transform=dataset.res
nclass
=
len
(
opt
.
alphabet
.
split
(
opt
.
sep
))
+
1
nc
=
1
converter
=
utils
.
strLabelConverter
(
opt
.
alphabet
,
opt
.
sep
)
criterion
=
CTC
Loss
()
converter
=
utils
.
strLabelConverter
ForAttention
(
opt
.
alphabet
,
opt
.
sep
)
criterion
=
torch
.
nn
.
CrossEntropy
Loss
()
# custom weights initialization called on crnn
...
...
@@ -97,13 +97,14 @@ if opt.crnn != '':
print
(
crnn
)
image
=
torch
.
FloatTensor
(
opt
.
batchSize
,
3
,
opt
.
imgH
,
opt
.
imgH
)
text
=
torch
.
Int
Tensor
(
opt
.
batchSize
*
5
)
text
=
torch
.
Long
Tensor
(
opt
.
batchSize
*
5
)
length
=
torch
.
IntTensor
(
opt
.
batchSize
)
if
opt
.
cuda
:
crnn
.
cuda
()
crnn
=
torch
.
nn
.
DataParallel
(
crnn
,
device_ids
=
range
(
opt
.
ngpu
))
image
=
image
.
cuda
()
text
=
text
.
cuda
()
criterion
=
criterion
.
cuda
()
image
=
Variable
(
image
)
...
...
@@ -149,24 +150,21 @@ def val(net, dataset, criterion, max_iter=100):
utils
.
loadData
(
text
,
t
)
utils
.
loadData
(
length
,
l
)
preds
=
crnn
(
image
)
preds_size
=
Variable
(
torch
.
IntTensor
([
preds
.
size
(
0
)]
*
batch_size
))
cost
=
criterion
(
preds
,
text
,
preds_size
,
length
)
/
batch_size
preds
=
crnn
(
image
,
length
)
cost
=
criterion
(
preds
,
text
)
loss_avg
.
add
(
cost
)
_
,
preds
=
preds
.
max
(
2
)
#preds = preds.squeeze(2)
preds
=
preds
.
transpose
(
1
,
0
).
contiguous
().
view
(
-
1
)
sim_preds
=
converter
.
decode
(
preds
.
data
,
preds_size
.
data
,
raw
=
False
)
preds
=
preds
.
view
(
-
1
)
sim_preds
=
converter
.
decode
(
preds
.
data
,
length
.
data
)
for
pred
,
target
in
zip
(
sim_preds
,
cpu_texts
):
target
=
''
.
join
(
target
.
split
(
opt
.
sep
))
if
pred
==
target
:
n_correct
+=
1
raw_preds
=
converter
.
decode
(
preds
.
data
,
preds_size
.
data
,
raw
=
True
)[:
opt
.
n_test_disp
]
for
raw_pred
,
pred
,
gt
in
zip
(
raw_preds
,
sim_preds
,
cpu_texts
):
for
pred
,
gt
in
zip
(
sim_preds
,
cpu_texts
):
gt
=
''
.
join
(
gt
.
split
(
opt
.
sep
))
print
(
'%-20s
=> %-20s, gt: %-20s'
%
(
raw_pred
,
pred
,
gt
))
print
(
'%-20s
, gt: %-20s'
%
(
pred
,
gt
))
accuracy
=
n_correct
/
float
(
max_iter
*
opt
.
batchSize
)
print
(
'Test loss: %f, accuray: %f'
%
(
loss_avg
.
val
(),
accuracy
))
...
...
@@ -181,9 +179,8 @@ def trainBatch(net, criterion, optimizer):
utils
.
loadData
(
text
,
t
)
utils
.
loadData
(
length
,
l
)
preds
=
crnn
(
image
)
preds_size
=
Variable
(
torch
.
IntTensor
([
preds
.
size
(
0
)]
*
batch_size
))
cost
=
criterion
(
preds
,
text
,
preds_size
,
length
)
/
batch_size
preds
=
crnn
(
image
,
length
)
cost
=
criterion
(
preds
,
text
)
crnn
.
zero_grad
()
cost
.
backward
()
optimizer
.
step
()
...
...
models/crnn.py
浏览文件 @
f31b54e2
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.autograd
import
Variable
class
BidirectionalLSTM
(
nn
.
Module
):
...
...
@@ -19,6 +22,62 @@ class BidirectionalLSTM(nn.Module):
return
output
class
AttentionCell
(
nn
.
Module
):
def
__init__
(
self
,
input_size
,
hidden_size
):
super
(
AttentionCell
,
self
).
__init__
()
self
.
i2h
=
nn
.
Linear
(
input_size
,
hidden_size
,
bias
=
False
)
self
.
h2h
=
nn
.
Linear
(
hidden_size
,
hidden_size
)
self
.
score
=
nn
.
Linear
(
hidden_size
,
1
,
bias
=
False
)
self
.
rnn
=
nn
.
GRUCell
(
input_size
,
hidden_size
)
self
.
hidden_size
=
hidden_size
self
.
input_size
=
input_size
def
forward
(
self
,
prev_hidden
,
feats
):
nT
=
feats
.
size
(
0
)
nB
=
feats
.
size
(
1
)
assert
(
nB
==
1
)
nC
=
feats
.
size
(
2
)
hidden_size
=
self
.
hidden_size
input_size
=
self
.
input_size
feats_proj
=
self
.
i2h
(
feats
.
view
(
-
1
,
nC
))
prev_hidden_proj
=
self
.
h2h
(
prev_hidden
).
view
(
1
,
nB
,
hidden_size
).
expand
(
nT
,
nB
,
hidden_size
).
contiguous
().
view
(
-
1
,
hidden_size
)
emition
=
self
.
score
(
F
.
tanh
(
feats_proj
+
prev_hidden_proj
).
view
(
-
1
,
hidden_size
)).
view
(
nT
,
nB
).
transpose
(
0
,
1
)
alpha
=
F
.
softmax
(
emition
)
# nB * nT
context
=
(
feats
*
alpha
.
transpose
(
0
,
1
).
contiguous
().
view
(
nT
,
nB
,
1
).
expand
(
nT
,
nB
,
nC
)).
sum
(
0
).
squeeze
(
0
)
cur_hidden
=
self
.
rnn
(
context
,
prev_hidden
)
return
cur_hidden
,
alpha
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_classes
):
super
(
Attention
,
self
).
__init__
()
self
.
attention_cell
=
AttentionCell
(
input_size
,
hidden_size
)
self
.
input_size
=
input_size
self
.
hidden_size
=
hidden_size
self
.
generator
=
nn
.
Linear
(
hidden_size
,
num_classes
)
def
forward
(
self
,
feats
,
text_length
):
nT
=
feats
.
size
(
0
)
nB
=
feats
.
size
(
1
)
nC
=
feats
.
size
(
2
)
hidden_size
=
self
.
hidden_size
input_size
=
self
.
input_size
assert
(
input_size
==
nC
)
assert
(
nB
==
text_length
.
numel
())
num_labels
=
text_length
.
data
.
sum
()
output_hiddens
=
Variable
(
torch
.
zeros
(
num_labels
,
hidden_size
).
type_as
(
feats
.
data
))
k
=
0
for
j
in
range
(
nB
):
sub_feats
=
feats
[:,
j
,:].
contiguous
().
view
(
nT
,
1
,
nC
)
#feats.index_select(1, Variable(torch.LongTensor([j]).type_as(feats.data)))
sub_hidden
=
Variable
(
torch
.
zeros
(
1
,
hidden_size
).
type_as
(
feats
.
data
))
for
i
in
range
(
text_length
.
data
[
j
]):
sub_hidden
,
sub_alpha
=
self
.
attention_cell
(
sub_hidden
,
sub_feats
)
output_hiddens
[
k
]
=
sub_hidden
.
view
(
-
1
)
k
=
k
+
1
probs
=
self
.
generator
(
output_hiddens
)
return
probs
class
CRNN
(
nn
.
Module
):
...
...
@@ -71,9 +130,10 @@ class CRNN(nn.Module):
#self.cnn = cnn
self
.
rnn
=
nn
.
Sequential
(
BidirectionalLSTM
(
512
,
nh
,
nh
),
BidirectionalLSTM
(
nh
,
nh
,
nclass
))
BidirectionalLSTM
(
nh
,
nh
,
nh
))
self
.
attention
=
Attention
(
nh
,
nh
/
2
,
nclass
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
,
length
):
# conv features
conv
=
self
.
cnn
(
input
)
b
,
c
,
h
,
w
=
conv
.
size
()
...
...
@@ -82,6 +142,7 @@ class CRNN(nn.Module):
conv
=
conv
.
permute
(
2
,
0
,
1
)
# [w, b, c]
# rnn features
output
=
self
.
rnn
(
conv
)
rnn
=
self
.
rnn
(
conv
)
output
=
self
.
attention
(
rnn
,
length
)
return
output
utils.py
浏览文件 @
f31b54e2
...
...
@@ -6,8 +6,78 @@ import torch.nn as nn
from
torch.autograd
import
Variable
import
collections
class
strLabelConverterForAttention
(
object
):
"""Convert between str and label.
NOTE:
Insert `EOS` to the alphabet for attention.
Args:
alphabet (str): set of the possible characters.
ignore_case (bool, default=True): whether or not to ignore all of the case.
"""
def
__init__
(
self
,
alphabet
,
sep
):
self
.
sep
=
sep
self
.
alphabet
=
alphabet
.
split
(
sep
)
self
.
dict
=
{}
for
i
,
item
in
enumerate
(
self
.
alphabet
):
# NOTE: 0 is reserved for 'blank' required by wrap_ctc
self
.
dict
[
item
]
=
i
def
encode
(
self
,
text
):
"""Support batch or single str.
Args:
text (str or list of str): texts to convert.
Returns:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
"""
if
isinstance
(
text
,
str
):
text
=
text
.
split
(
self
.
sep
)
text
=
[
self
.
dict
[
item
]
for
item
in
text
]
length
=
[
len
(
text
)]
elif
isinstance
(
text
,
collections
.
Iterable
):
length
=
[
len
(
s
.
split
(
self
.
sep
))
for
s
in
text
]
text
=
self
.
sep
.
join
(
text
)
text
,
_
=
self
.
encode
(
text
)
return
(
torch
.
LongTensor
(
text
),
torch
.
LongTensor
(
length
))
def
decode
(
self
,
t
,
length
):
"""Decode encoded texts back into strs.
Args:
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
torch.IntTensor [n]: length of each text.
Raises:
AssertionError: when the texts and its length does not match.
Returns:
text (str or list of str): texts to convert.
"""
if
length
.
numel
()
==
1
:
length
=
length
[
0
]
assert
t
.
numel
()
==
length
,
"text with length: {} does not match declared length: {}"
.
format
(
t
.
numel
(),
length
)
if
raw
:
return
''
.
join
([
self
.
alphabet
[
i
]
for
i
in
t
])
else
:
# batch mode
assert
t
.
numel
()
==
length
.
sum
(),
"texts with length: {} does not match declared length: {}"
.
format
(
t
.
numel
(),
length
.
sum
())
texts
=
[]
index
=
0
for
i
in
range
(
length
.
numel
()):
l
=
length
[
i
]
texts
.
append
(
self
.
decode
(
t
[
index
:
index
+
l
],
torch
.
LongTensor
([
l
])))
index
+=
l
return
texts
class
strLabelConverter
(
object
):
class
strLabelConverter
ForCTC
(
object
):
"""Convert between str and label.
NOTE:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录