GoAI / attention_ocr.pytorch
Commit 3df8fc31
Authored Oct 31, 2017 by xiaohang

    add -lang .module. max_locs alpha emition

Parent: fd54339d
Showing 3 changed files with 44 additions and 37 deletions (+44 −37)

data/create_mnt_list.py    +15 −0
main.py                    +11 −3
models/crnn.py             +18 −34
data/create_mnt_list.py
@@ -28,3 +28,18 @@ for line in lines:
     print >> test_fp, output
 test_fp.close()
 
+with open('data/mnt/ramdisk/max/90kDICT32px/annotation_test.txt') as fp:
+    lines = fp.readlines()
+val_fp = open('data/val_list.txt', 'w')
+for line in lines:
+    imgpath = line.strip().split(' ')[0]
+    label = imgpath.split('/')[-1].split('_')[1].lower()
+    label = label + '$'
+    label = ':'.join(label)
+    imgpath = 'data/mnt/ramdisk/max/90kDICT32px/%s' % imgpath
+    output = ' '.join([imgpath, label])
+    print >> val_fp, output
+val_fp.close()
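Note: the new block mirrors the train-list code above it, but for annotation_test.txt: the label is taken from the file name, lowercased, terminated with '$', and written out colon-separated. A minimal sketch of that transformation, using a hypothetical annotation line (the path and index are invented):

# Sketch of the label construction performed by the new val-list lines,
# for a hypothetical annotation_test.txt entry (path/index are made up).
line = './2911/6/77_heretical_35885.jpg 35885\n'

imgpath = line.strip().split(' ')[0]                    # './2911/6/77_heretical_35885.jpg'
label = imgpath.split('/')[-1].split('_')[1].lower()    # 'heretical'
label = label + '$'                                     # end-of-sequence marker
label = ':'.join(label)                                 # 'h:e:r:e:t:i:c:a:l:$'
imgpath = 'data/mnt/ramdisk/max/90kDICT32px/%s' % imgpath
output = ' '.join([imgpath, label])
print(output)
# data/mnt/ramdisk/max/90kDICT32px/./2911/6/77_heretical_35885.jpg h:e:r:e:t:i:c:a:l:$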
main.py
@@ -14,6 +14,7 @@ import dataset
 import time
 import models.crnn as crnn
 print(crnn.__name__)
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--trainlist', required=True, help='path to train_list')
@@ -39,6 +40,7 @@ parser.add_argument('--saveInterval', type=int, default=10000, help='Interval to
 parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
 parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
 parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
+parser.add_argument('--lang', action='store_true', help='whether to use char language model')
 parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
 opt = parser.parse_args()
 print(opt)
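Note: the only new option here is --lang. Like the neighbouring flags it is a store_true switch, so it defaults to False and only flips to True when passed on the command line; the hunks below branch on opt.lang. A quick stand-alone check with hypothetical argv values:

# Behaviour of the new --lang switch in isolation (hypothetical argv).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lang', action='store_true', help='whether to use char language model')

print(parser.parse_args([]).lang)           # False -> preds = crnn(image, length)
print(parser.parse_args(['--lang']).lang)   # True  -> preds = crnn(image, length, text)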
@@ -150,7 +152,10 @@ def val(net, dataset, criterion, max_iter=100):
         utils.loadData(text, t)
         utils.loadData(length, l)
 
-        preds = crnn(image, length)
+        if opt.lang:
+            preds = crnn(image, length, text)
+        else:
+            preds = crnn(image, length)
         cost = criterion(preds, text)
         loss_avg.add(cost)
@@ -179,7 +184,10 @@ def trainBatch(net, criterion, optimizer):
     utils.loadData(text, t)
     utils.loadData(length, l)
 
-    preds = crnn(image, length)
+    if opt.lang:
+        preds = crnn(image, length, text)
+    else:
+        preds = crnn(image, length)
     cost = criterion(preds, text)
     crnn.zero_grad()
     cost.backward()
@@ -214,4 +222,4 @@ for epoch in range(opt.niter):
         # do checkpointing
         if i % opt.saveInterval == 0:
-            torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
+            torch.save(crnn.module.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
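Note: the checkpointing change saves crnn.module.state_dict() instead of crnn.state_dict(). This is the usual pattern when the model has been wrapped in torch.nn.DataParallel (presumably what this script does elsewhere; the wrapping is not visible in this hunk): the underlying network sits on the .module attribute, and saving from there keeps the checkpoint keys free of the 'module.' prefix. A minimal illustration with a stand-in model:

# Why saving .module.state_dict() matters for a DataParallel-wrapped model
# (nn.Linear is only a stand-in for the real CRNN here).
import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

print(list(wrapped.state_dict().keys()))          # ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict().keys()))   # ['weight', 'bias']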
models/crnn.py
@@ -31,8 +31,10 @@ class AttentionCell(nn.Module):
         self.rnn = nn.GRUCell(input_size, hidden_size)
         self.hidden_size = hidden_size
         self.input_size = input_size
+        self.processed_batches = 0
 
     def forward(self, prev_hidden, feats):
+        self.processed_batches = self.processed_batches + 1
         nT = feats.size(0)
         nB = feats.size(1)
         nC = feats.size(2)
@@ -43,6 +45,11 @@ class AttentionCell(nn.Module):
         prev_hidden_proj = self.h2h(prev_hidden).view(1, nB, hidden_size).expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size)
         emition = self.score(F.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT, nB).transpose(0, 1)
         alpha = F.softmax(emition)  # nB * nT
+        if self.processed_batches % 10000 == 0:
+            print('emition ', list(emition.data[0]))
+            print('alpha ', list(alpha.data[0]))
 
         context = (feats * alpha.transpose(0, 1).contiguous().view(nT, nB, 1).expand(nT, nB, nC)).sum(0).squeeze(0)
         cur_hidden = self.rnn(context, prev_hidden)
 
         return cur_hidden, alpha
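Note on shapes in this cell: feats is nT x nB x nC, emition and alpha come out nB x nT (one attention weight per feature step), and the context vector collapses to nB x nC. Below is a shape-only sketch under hypothetical sizes; the Linear layers stand in for self.h2h and self.score (and an assumed feature projection, here called i2h), and a current PyTorch is assumed, hence torch.tanh, an explicit softmax dim, and no trailing squeeze:

# Shape-only sketch of the AttentionCell.forward computation, hypothetical sizes.
import torch
import torch.nn as nn
import torch.nn.functional as F

nT, nB, nC, hidden_size = 26, 2, 512, 256           # feature steps, batch, channels, hidden
feats = torch.randn(nT, nB, nC)
prev_hidden = torch.randn(nB, hidden_size)

i2h = nn.Linear(nC, hidden_size)                     # assumed feature projection
h2h = nn.Linear(hidden_size, hidden_size)            # stand-in for self.h2h
score = nn.Linear(hidden_size, 1)                    # stand-in for self.score

feats_proj = i2h(feats.view(-1, nC))                                  # (nT*nB) x hidden
prev_hidden_proj = h2h(prev_hidden).view(1, nB, hidden_size) \
    .expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size)   # (nT*nB) x hidden

emition = score(torch.tanh(feats_proj + prev_hidden_proj)).view(nT, nB).transpose(0, 1)
alpha = F.softmax(emition, dim=1)                                     # nB x nT, rows sum to 1

context = (feats * alpha.transpose(0, 1).contiguous().view(nT, nB, 1).expand(nT, nB, nC)).sum(0)
print(emition.shape, alpha.shape, context.shape)
# torch.Size([2, 26]) torch.Size([2, 26]) torch.Size([2, 512])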
@@ -54,8 +61,10 @@ class Attention(nn.Module):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.generator = nn.Linear(hidden_size, num_classes)
+        self.processed_batches = 0
 
     def forward(self, feats, text_length):
+        self.processed_batches = self.processed_batches + 1
         nT = feats.size(0)
         nB = feats.size(1)
         nC = feats.size(2)
@@ -69,9 +78,18 @@ class Attention(nn.Module):
         output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
         hidden = Variable(torch.zeros(nB, hidden_size).type_as(feats.data))
+        max_locs = torch.zeros(num_steps, nB)
+        max_vals = torch.zeros(num_steps, nB)
         for i in range(num_steps):
             hidden, alpha = self.attention_cell(hidden, feats)
             output_hiddens[i] = hidden
+            if self.processed_batches % 500 == 0:
+                max_val, max_loc = alpha.data.max(1)
+                max_locs[i] = max_loc.cpu()
+                max_vals[i] = max_val.cpu()
+        if self.processed_batches % 500 == 0:
+            print('max_locs', list(max_locs[0:text_length.data[0], 0]))
+            print('max_vals', list(max_vals[0:text_length.data[0], 0]))
 
         new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
         b = 0
         start = 0
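Note: the new max_locs/max_vals bookkeeping records, for every decoding step of the first sample, where the attention peaked and how sharp the peak was; that is exactly what alpha.data.max(1) returns per row. A tiny illustration with a made-up attention matrix:

# What alpha.max(1) yields for a hypothetical nB=2, nT=4 attention matrix.
import torch

alpha = torch.tensor([[0.1, 0.6, 0.2, 0.1],
                      [0.7, 0.1, 0.1, 0.1]])
max_val, max_loc = alpha.max(1)    # per-row peak value and its column index
print(max_val)                     # tensor([0.6000, 0.7000])
print(max_loc)                     # tensor([1, 0])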
@@ -88,40 +106,6 @@ class CRNN(nn.Module):
         super(CRNN, self).__init__()
         assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
 
-        ks = [3, 3, 3, 3, 3, 3, 2]
-        ps = [1, 1, 1, 1, 1, 1, 0]
-        ss = [1, 1, 1, 1, 1, 1, 1]
-        nm = [64, 128, 256, 256, 512, 512, 512]
-
-        cnn = nn.Sequential()
-
-        def convRelu(i, batchNormalization=False):
-            nIn = nc if i == 0 else nm[i - 1]
-            nOut = nm[i]
-            cnn.add_module('conv{0}'.format(i), nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
-            if batchNormalization:
-                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
-            if leakyRelu:
-                cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True))
-            else:
-                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
-
-        convRelu(0)
-        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
-        convRelu(1)
-        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
-        convRelu(2, True)
-        convRelu(3)
-        cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
-        convRelu(4, True)
-        convRelu(5)
-        cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
-        convRelu(6, True)  # 512x1x16
-
+        self.cnn = nn.Sequential(
+            nn.Conv2d(nc, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),    # 64x16x50
+            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),   # 128x8x25
...
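Note: the removed convRelu builder is swapped for an explicit nn.Sequential; the inline comments suggest the new stack is sized for 32-pixel-high inputs (e.g. 1x32x100 after resizing). A quick shape check of just the two stages visible in the truncated hunk, under that assumption:

# Shape check for the first two conv/pool stages of the new self.cnn,
# assuming a grayscale 32x100 input (nc=1, H=32, W=100), as the
# '64x16x50' / '128x8x25' comments imply.
import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),    # 64 x 16 x 50
    nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),  # 128 x 8 x 25
)

x = torch.randn(1, 1, 32, 100)
print(stem(x).shape)   # torch.Size([1, 128, 8, 25])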