Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
身在江湖的郭大侠
CnOCR
提交
113d790f
CnOCR
项目概览
身在江湖的郭大侠
/
CnOCR
与 Fork 源项目一致
Fork自
Cloud IDE / CnOCR
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
CnOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
113d790f
编写于
3月 10, 2019
作者:
B
breezedeus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
多线程读取图片文件
上级
b7039110
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
116 addition
and
23 deletion
+116
-23
data_utils/data_iter.py
data_utils/data_iter.py
+80
-14
hyperparams/hyperparams.py
hyperparams/hyperparams.py
+2
-2
train.py
train.py
+34
-7
未找到文件。
data_utils/data_iter.py
浏览文件 @
113d790f
...
...
@@ -6,6 +6,8 @@ import numpy as np
import
mxnet
as
mx
import
random
from
.multiproc_data
import
MPData
class
SimpleBatch
(
object
):
def
__init__
(
self
,
data_names
,
data
,
label_names
=
list
(),
label
=
list
()):
...
...
@@ -170,11 +172,84 @@ class ImageIterLstm(mx.io.DataIter):
random
.
shuffle
(
self
.
dataset_lines
)
class
MPOcrImages
(
object
):
"""
Handles multi-process captcha image generation
"""
def
__init__
(
self
,
data_root
,
data_list
,
data_shape
,
num_label
,
num_processes
,
max_queue_size
):
"""
Parameters
----------
data_shape: [width, height]
num_processes: int
Number of processes to spawn
max_queue_size: int
Maximum images in queue before processes wait
"""
self
.
data_shape
=
data_shape
self
.
num_label
=
num_label
self
.
data_root
=
data_root
self
.
dataset_lines
=
open
(
data_list
).
readlines
()
self
.
mp_data
=
MPData
(
num_processes
,
max_queue_size
,
self
.
_gen_sample
)
def
_gen_sample
(
self
):
m_line
=
random
.
choice
(
self
.
dataset_lines
)
img_lst
=
m_line
.
strip
().
split
(
' '
)
img_path
=
os
.
path
.
join
(
self
.
data_root
,
img_lst
[
0
])
img
=
Image
.
open
(
img_path
).
resize
(
self
.
data_shape
,
Image
.
BILINEAR
).
convert
(
'L'
)
img
=
np
.
array
(
img
)
# print(img.shape)
img
=
np
.
transpose
(
img
,
(
1
,
0
))
# res: [1, width, height]
# if len(img.shape) == 2:
# img = np.expand_dims(np.transpose(img, (1, 0)), axis=0) # res: [1, width, height]
labels
=
np
.
zeros
(
self
.
num_label
,
int
)
for
idx
in
range
(
1
,
len
(
img_lst
)):
labels
[
idx
-
1
]
=
int
(
img_lst
[
idx
])
return
img
,
labels
@
property
def
size
(
self
):
return
len
(
self
.
dataset_lines
)
@
property
def
shape
(
self
):
return
self
.
data_shape
def
start
(
self
):
"""
Starts the processes
"""
self
.
mp_data
.
start
()
def
get
(
self
):
"""
Get an image from the queue
Returns
-------
np.ndarray
A captcha image, normalized to [0, 1]
"""
return
self
.
mp_data
.
get
()
def
reset
(
self
):
"""
Resets the generator by stopping all processes
"""
self
.
mp_data
.
reset
()
class
OCRIter
(
mx
.
io
.
DataIter
):
"""
Iterator class for generating captcha image data
"""
def
__init__
(
self
,
count
,
batch_size
,
lstm_init_states
,
captcha
,
name
):
def
__init__
(
self
,
count
,
batch_size
,
lstm_init_states
,
captcha
,
n
um_label
,
n
ame
):
"""
Parameters
----------
...
...
@@ -189,12 +264,12 @@ class OCRIter(mx.io.DataIter):
"""
super
(
OCRIter
,
self
).
__init__
()
self
.
batch_size
=
batch_size
self
.
count
=
count
self
.
count
=
count
if
count
>
0
else
captcha
.
size
self
.
init_states
=
lstm_init_states
self
.
init_state_arrays
=
[
mx
.
nd
.
zeros
(
x
[
1
])
for
x
in
lstm_init_states
]
data_shape
=
captcha
.
shape
self
.
provide_data
=
[(
'data'
,
(
batch_size
,
1
,
data_shape
[
1
],
data_shape
[
0
]))]
+
lstm_init_states
self
.
provide_label
=
[(
'label'
,
(
self
.
batch_size
,
4
))]
self
.
provide_label
=
[(
'label'
,
(
self
.
batch_size
,
num_label
))]
self
.
mp_captcha
=
captcha
self
.
name
=
name
...
...
@@ -204,12 +279,12 @@ class OCRIter(mx.io.DataIter):
data
=
[]
label
=
[]
for
i
in
range
(
self
.
batch_size
):
img
,
num
=
self
.
mp_captcha
.
get
()
img
,
labels
=
self
.
mp_captcha
.
get
()
# print(img.shape)
img
=
np
.
expand_dims
(
np
.
transpose
(
img
,
(
1
,
0
)),
axis
=
0
)
# size: [1, channel, height, width]
# import pdb; pdb.set_trace()
data
.
append
(
img
)
label
.
append
(
self
.
_get_label
(
num
)
)
label
.
append
(
labels
)
data_all
=
[
mx
.
nd
.
array
(
data
)]
+
self
.
init_state_arrays
label_all
=
[
mx
.
nd
.
array
(
label
)]
data_names
=
[
'data'
]
+
init_state_names
...
...
@@ -217,12 +292,3 @@ class OCRIter(mx.io.DataIter):
data_batch
=
SimpleBatch
(
data_names
,
data_all
,
label_names
,
label_all
)
yield
data_batch
@
classmethod
def
_get_label
(
cls
,
buf
):
ret
=
np
.
zeros
(
4
)
for
i
in
range
(
len
(
buf
)):
ret
[
i
]
=
1
+
int
(
buf
[
i
])
if
len
(
buf
)
==
3
:
ret
[
3
]
=
0
return
ret
hyperparams/hyperparams.py
浏览文件 @
113d790f
...
...
@@ -7,7 +7,7 @@ class Hyperparams(object):
"""
def
__init__
(
self
):
# Training hyper parameters
self
.
_train_epoch_size
=
3000
0
self
.
_train_epoch_size
=
0
self
.
_eval_epoch_size
=
3000
self
.
_num_epoch
=
20
self
.
_learning_rate
=
0.001
...
...
@@ -17,7 +17,7 @@ class Hyperparams(object):
self
.
_loss_type
=
"ctc"
# ["warpctc" "ctc"]
self
.
_batch_size
=
128
self
.
_num_classes
=
5990
self
.
_num_classes
=
6425
#
5990
self
.
_img_width
=
280
self
.
_img_height
=
32
...
...
train.py
浏览文件 @
113d790f
...
...
@@ -9,7 +9,7 @@ from data_utils.captcha_generator import MPDigitCaptcha
from
hyperparams.hyperparams
import
Hyperparams
from
hyperparams.hyperparams2
import
Hyperparams
as
Hyperparams2
from
data_utils.data_iter
import
ImageIterLstm
,
OCRIter
from
data_utils.data_iter
import
ImageIterLstm
,
MPOcrImages
,
OCRIter
from
symbols.crnn
import
crnn_no_lstm
,
crnn_lstm
from
fit.ctc_metrics
import
CtcMetrics
from
fit.fit
import
fit
...
...
@@ -80,9 +80,11 @@ def run_captcha(args):
data_names
=
[
'data'
]
+
[
x
[
0
]
for
x
in
init_states
]
data_train
=
OCRIter
(
hp
.
train_epoch_size
//
hp
.
batch_size
,
hp
.
batch_size
,
init_states
,
captcha
=
mp_captcha
,
name
=
'train'
)
hp
.
train_epoch_size
//
hp
.
batch_size
,
hp
.
batch_size
,
init_states
,
captcha
=
mp_captcha
,
num_label
=
hp
.
num_label
,
name
=
'train'
)
data_val
=
OCRIter
(
hp
.
eval_epoch_size
//
hp
.
batch_size
,
hp
.
batch_size
,
init_states
,
captcha
=
mp_captcha
,
name
=
'val'
)
hp
.
eval_epoch_size
//
hp
.
batch_size
,
hp
.
batch_size
,
init_states
,
captcha
=
mp_captcha
,
num_label
=
hp
.
num_label
,
name
=
'val'
)
head
=
'%(asctime)-15s %(message)s'
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
format
=
head
)
...
...
@@ -99,15 +101,37 @@ def run_cn_ocr(args):
network
=
crnn_lstm
(
hp
)
mp_data_train
=
MPOcrImages
(
args
.
data_root
,
args
.
train_file
,
(
hp
.
img_width
,
hp
.
img_height
),
hp
.
num_label
,
num_processes
=
args
.
num_proc
,
max_queue_size
=
hp
.
batch_size
*
2
)
# img, num = mp_data_train.get()
# print(img.shape)
# print(mp_data_train.shape)
# import pdb; pdb.set_trace()
# import numpy as np
# import cv2
# img = np.transpose(img, (1, 0))
# cv2.imwrite('xxx.png', img * 255)
# import pdb; pdb.set_trace()
mp_data_test
=
MPOcrImages
(
args
.
data_root
,
args
.
test_file
,
(
hp
.
img_width
,
hp
.
img_height
),
hp
.
num_label
,
num_processes
=
args
.
num_proc
,
max_queue_size
=
hp
.
batch_size
*
2
)
mp_data_train
.
start
()
mp_data_test
.
start
()
init_c
=
[(
'l%d_init_c'
%
l
,
(
hp
.
batch_size
,
hp
.
num_hidden
))
for
l
in
range
(
hp
.
num_lstm_layer
*
2
)]
init_h
=
[(
'l%d_init_h'
%
l
,
(
hp
.
batch_size
,
hp
.
num_hidden
))
for
l
in
range
(
hp
.
num_lstm_layer
*
2
)]
init_states
=
init_c
+
init_h
data_names
=
[
'data'
]
+
[
x
[
0
]
for
x
in
init_states
]
data_train
=
ImageIterLstm
(
args
.
data_root
,
args
.
train_file
,
hp
.
batch_size
,
(
hp
.
img_width
,
hp
.
img_height
),
hp
.
num_label
,
init_states
,
name
=
"train"
)
data_val
=
ImageIterLstm
(
args
.
data_root
,
args
.
test_file
,
hp
.
batch_size
,
(
hp
.
img_width
,
hp
.
img_height
),
hp
.
num_label
,
init_states
,
name
=
"val"
)
data_train
=
OCRIter
(
hp
.
train_epoch_size
//
hp
.
batch_size
,
hp
.
batch_size
,
init_states
,
captcha
=
mp_data_train
,
num_label
=
hp
.
num_label
,
name
=
'train'
)
data_val
=
OCRIter
(
hp
.
train_epoch_size
//
hp
.
batch_size
,
hp
.
batch_size
,
init_states
,
captcha
=
mp_data_test
,
num_label
=
hp
.
num_label
,
name
=
'val'
)
# data_train = ImageIterLstm(
# args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
# data_val = ImageIterLstm(
# args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")
head
=
'%(asctime)-15s %(message)s'
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
format
=
head
)
...
...
@@ -116,6 +140,9 @@ def run_cn_ocr(args):
fit
(
network
=
network
,
data_train
=
data_train
,
data_val
=
data_val
,
metrics
=
metrics
,
args
=
args
,
hp
=
hp
,
data_names
=
data_names
)
mp_data_train
.
reset
()
mp_data_test
.
start
()
if
__name__
==
'__main__'
:
args
=
parse_args
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录