Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
GoAI
attention_ocr.pytorch
提交
11ae8de2
A
attention_ocr.pytorch
项目概览
GoAI
/
attention_ocr.pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
attention_ocr.pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
11ae8de2
编写于
10月 14, 2017
作者:
X
xiaohang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add main.py
上级
3baa25ed
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
220 addition
and
0 deletion
+220
-0
main.py
main.py
+220
-0
未找到文件。
main.py
0 → 100644
浏览文件 @
11ae8de2
from
__future__
import
print_function
import
argparse
import
random
import
torch
import
torch.backends.cudnn
as
cudnn
import
torch.optim
as
optim
import
torch.utils.data
from
torch.autograd
import
Variable
import
numpy
as
np
from
warpctc_pytorch
import
CTCLoss
import
os
import
utils
import
dataset
import
time
import
models.crnn
as
crnn
# Command-line configuration for CRNN + CTC training.
parser = argparse.ArgumentParser()
parser.add_argument('--trainlist', required=True, help='path to train_list')
parser.add_argument('--vallist', required=True, help='path to val_list')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for Critic, default=0.00005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
parser.add_argument('--alphabet', type=str,
                    default='0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z')
parser.add_argument('--sep', type=str, default=':')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)
# Resolve the output directory and create it.
if opt.experiment is None:
    opt.experiment = 'expr'
# Fixed: was os.system('mkdir {0}'.format(opt.experiment)) — shelling out is
# non-portable and unsafe if the path contains shell metacharacters.
os.makedirs(opt.experiment, exist_ok=True)

# Seed every RNG in play (python, numpy, torch) so a run is reproducible
# once the printed seed is re-supplied.
opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
# Training data. Unless --random_sample is given, batches are drawn with the
# project's random-sequential sampler (shuffle stays off in that case).
train_dataset = dataset.listDataset(list_file=opt.trainlist)
assert train_dataset
sampler = None if opt.random_sample else dataset.randomSequentialSampler(train_dataset, opt.batchSize)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batchSize,
    shuffle=False,
    sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))

# Validation data is resized/normalized to a fixed 100x32.
test_dataset = dataset.listDataset(
    list_file=opt.vallist,
    transform=dataset.resizeNormalize((100, 32)))

# One output class per alphabet symbol, plus one for the CTC blank.
nclass = len(opt.alphabet.split(opt.sep)) + 1
nc = 1  # single (grayscale) input channel

converter = utils.strLabelConverter(opt.alphabet, opt.sep)
criterion = CTCLoss()
# custom weights initialization called on crnn
def
weights_init
(
m
):
classname
=
m
.
__class__
.
__name__
if
classname
.
find
(
'Conv'
)
!=
-
1
:
m
.
weight
.
data
.
normal_
(
0.0
,
0.02
)
elif
classname
.
find
(
'BatchNorm'
)
!=
-
1
:
m
.
weight
.
data
.
normal_
(
1.0
,
0.02
)
m
.
bias
.
data
.
fill_
(
0
)
# Build the model (note: rebinds the module name `crnn` to the instance),
# apply custom init, and optionally resume from a checkpoint.
crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.crnn != '':
    print('loading pretrained model from %s' % opt.crnn)
    crnn.load_state_dict(torch.load(opt.crnn))
print(crnn)

# Reusable host-side buffers; utils.loadData resizes them to each batch, so
# these initial sizes are placeholders only.
# Fixed: was FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH) — 3 channels
# contradicts nc = 1, and imgH was repeated where imgW was intended.
image = torch.FloatTensor(opt.batchSize, nc, opt.imgH, opt.imgW)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda:
    crnn.cuda()
    crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
    image = image.cuda()
    criterion = criterion.cuda()

image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()

# setup optimizer (rmsprop unless --adam or --adadelta is passed)
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)
def val(net, dataset, criterion, max_iter=100):
    """Evaluate ``net`` on up to ``max_iter`` batches of ``dataset``.

    Freezes gradients, runs CTC loss and greedy decoding, prints up to
    ``opt.n_test_disp`` sample predictions from the last batch, then prints
    average loss and exact-match accuracy. Relies on the module-level
    ``opt``, ``image``/``text``/``length`` buffers and ``converter``.
    """
    print('Start val')

    # Fixed: originally iterated the global `crnn` — use the passed-in net.
    for p in net.parameters():
        p.requires_grad = False
    net.eval()

    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    n_correct = 0
    loss_avg = utils.averager()  # local averager; leaves the global one alone

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)  # fixed: Python 3 protocol (was val_iter.next())
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = net(image)  # fixed: was the global crnn, ignoring `net`
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        # Greedy decode: argmax over classes, flatten (T, B) for the converter.
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            # Labels are stored sep-joined; compare against the plain string.
            target = ''.join(target.split(opt.sep))
            if pred == target:
                n_correct += 1

    # Show a few raw/decoded predictions from the last batch for eyeballing.
    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        gt = ''.join(gt.split(opt.sep))
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    # NOTE(review): assumes every batch is full-sized; the last partial batch
    # slightly deflates accuracy — kept as-is to preserve reported numbers.
    accuracy = n_correct / float(max_iter * opt.batchSize)
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
def trainBatch(net, criterion, optimizer):
    """Run one optimization step on the next training batch.

    Pulls a batch from the module-level ``train_iter``, loads it into the
    shared ``image``/``text``/``length`` buffers, and returns the
    batch-averaged CTC cost.
    """
    data = next(train_iter)  # fixed: Python 3 protocol (was train_iter.next())
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = net(image)  # fixed: was the global crnn, ignoring `net`
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size

    net.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
# Main training loop: one pass over train_loader per epoch, with periodic
# logging, validation and checkpointing controlled by the *Interval options.
t0 = time.time()
for epoch in range(opt.niter):
    train_iter = iter(train_loader)
    i = 0
    while i < len(train_loader):
        # Re-enable training after val() froze the parameters.
        for p in crnn.parameters():
            p.requires_grad = True
        crnn.train()

        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)
        i += 1

        if i % opt.displayInterval == 0:
            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
            loss_avg.reset()
            t1 = time.time()
            print('time elapsed %d' % (t1 - t0))
            t0 = time.time()

        if i % opt.valInterval == 0:
            val(crnn, test_dataset, criterion)

        # do checkpointing
        if i % opt.saveInterval == 0:
            torch.save(crnn.state_dict(),
                       '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录