Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
GoAI
attention_ocr.pytorch
提交
b6672e43
A
attention_ocr.pytorch
项目概览
GoAI
/
attention_ocr.pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
attention_ocr.pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b6672e43
编写于
10月 14, 2017
作者:
X
xiaohang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add listdataset
上级
e79773da
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
51 addition
and
5 deletion
+51
-5
crnn_main.py
crnn_main.py
+16
-4
dataset.py
dataset.py
+33
-0
run.sh
run.sh
+2
-1
未找到文件。
crnn_main.py
浏览文件 @
b6672e43
...
...
@@ -11,11 +11,13 @@ from warpctc_pytorch import CTCLoss
import
os
import
utils
import
dataset
import
time
import
models.crnn
as
crnn
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--trainroot'
,
required
=
True
,
help
=
'path to dataset'
)
parser
.
add_argument
(
'--trainroot'
,
default
=
""
,
help
=
'path to dataset'
)
parser
.
add_argument
(
'--trainlist'
,
default
=
""
,
help
=
'path to train_list'
)
parser
.
add_argument
(
'--valroot'
,
required
=
True
,
help
=
'path to dataset'
)
parser
.
add_argument
(
'--workers'
,
type
=
int
,
help
=
'number of data loading workers'
,
default
=
2
)
parser
.
add_argument
(
'--batchSize'
,
type
=
int
,
default
=
64
,
help
=
'input batch size'
)
...
...
@@ -56,7 +58,13 @@ cudnn.benchmark = True
if
torch
.
cuda
.
is_available
()
and
not
opt
.
cuda
:
print
(
"WARNING: You have a CUDA device, so you should probably run with --cuda"
)
train_dataset
=
dataset
.
lmdbDataset
(
root
=
opt
.
trainroot
)
if
opt
.
trainroot
!=
""
:
train_dataset
=
dataset
.
lmdbDataset
(
root
=
opt
.
trainroot
)
elif
opt
.
trainlist
!=
""
:
train_dataset
=
dataset
.
listDataset
(
list_file
=
opt
.
trainlist
)
else
:
print
(
"no train data, exit"
)
exit
(
0
)
assert
train_dataset
if
not
opt
.
random_sample
:
sampler
=
dataset
.
randomSequentialSampler
(
train_dataset
,
opt
.
batchSize
)
...
...
@@ -64,7 +72,7 @@ else:
sampler
=
None
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
opt
.
batchSize
,
shuffle
=
Tru
e
,
sampler
=
sampler
,
shuffle
=
Fals
e
,
sampler
=
sampler
,
num_workers
=
int
(
opt
.
workers
),
collate_fn
=
dataset
.
alignCollate
(
imgH
=
opt
.
imgH
,
imgW
=
opt
.
imgW
,
keep_ratio
=
opt
.
keep_ratio
))
test_dataset
=
dataset
.
lmdbDataset
(
...
...
@@ -153,7 +161,7 @@ def val(net, dataset, criterion, max_iter=100):
loss_avg
.
add
(
cost
)
_
,
preds
=
preds
.
max
(
2
)
preds
=
preds
.
squeeze
(
2
)
#
preds = preds.squeeze(2)
preds
=
preds
.
transpose
(
1
,
0
).
contiguous
().
view
(
-
1
)
sim_preds
=
converter
.
decode
(
preds
.
data
,
preds_size
.
data
,
raw
=
False
)
for
pred
,
target
in
zip
(
sim_preds
,
cpu_texts
):
...
...
@@ -186,6 +194,7 @@ def trainBatch(net, criterion, optimizer):
return
cost
t0
=
time
.
time
()
for
epoch
in
range
(
opt
.
niter
):
train_iter
=
iter
(
train_loader
)
i
=
0
...
...
@@ -202,6 +211,9 @@ for epoch in range(opt.niter):
print
(
'[%d/%d][%d/%d] Loss: %f'
%
(
epoch
,
opt
.
niter
,
i
,
len
(
train_loader
),
loss_avg
.
val
()))
loss_avg
.
reset
()
t1
=
time
.
time
()
print
(
'time elapsed %d'
%
(
t1
-
t0
))
t0
=
time
.
time
()
if
i
%
opt
.
valInterval
==
0
:
val
(
crnn
,
test_dataset
,
criterion
)
...
...
dataset.py
浏览文件 @
b6672e43
...
...
@@ -13,6 +13,39 @@ from PIL import Image
import
numpy
as
np
class
listDataset
(
Dataset
):
def
__init__
(
self
,
list_file
=
None
,
transform
=
None
,
target_transform
=
None
):
with
open
(
list_file
)
as
fp
:
self
.
lines
=
fp
.
readlines
()
self
.
nSamples
=
len
(
self
.
lines
)
self
.
transform
=
transform
self
.
target_transform
=
target_transform
def
__len__
(
self
):
return
self
.
nSamples
def
__getitem__
(
self
,
index
):
assert
index
<=
len
(
self
),
'index range error'
index
+=
1
imgpath
=
self
.
lines
[
index
].
strip
()
try
:
img
=
Image
.
open
(
imgpath
).
convert
(
'L'
)
except
IOError
:
print
(
'Corrupted image for %d'
%
index
)
return
self
[
index
+
1
]
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
label
=
imgpath
.
split
(
'/'
)[
-
1
].
split
(
'_'
)[
1
].
lower
()
if
self
.
target_transform
is
not
None
:
label
=
self
.
target_transform
(
label
)
return
(
img
,
label
)
class
lmdbDataset
(
Dataset
):
def
__init__
(
self
,
root
=
None
,
transform
=
None
,
target_transform
=
None
):
...
...
run.sh
浏览文件 @
b6672e43
nohup
python crnn_main.py
--trainroot
../../PyTorch/crnn/tool/data/train_lmdb/
--valroot
../../PyTorch/crnn/tool/data/test_lmdb/
--cuda
--adam
--lr
=
0.001
>
log_adam.txt &
#python main.py --trainroot ../PyTorch/crnn/tool/data/train_lmdb/ --valroot ../PyTorch/crnn/tool/data/test_lmdb/ --cuda --adam --lr=0.001
python main.py
--trainlist
train_list.txt
--valroot
../PyTorch/crnn/tool/data/test_lmdb/
--cuda
--adam
--lr
=
0.001
# train_list could be annotation_train.txt
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录