提交 113d790f 编写于 作者: B breezedeus

多线程读取图片文件

上级 b7039110
......@@ -6,6 +6,8 @@ import numpy as np
import mxnet as mx
import random
from .multiproc_data import MPData
class SimpleBatch(object):
def __init__(self, data_names, data, label_names=list(), label=list()):
......@@ -170,11 +172,84 @@ class ImageIterLstm(mx.io.DataIter):
random.shuffle(self.dataset_lines)
class MPOcrImages(object):
"""
Handles multi-process captcha image generation
"""
def __init__(self, data_root, data_list, data_shape, num_label, num_processes, max_queue_size):
"""
Parameters
----------
data_shape: [width, height]
num_processes: int
Number of processes to spawn
max_queue_size: int
Maximum images in queue before processes wait
"""
self.data_shape = data_shape
self.num_label = num_label
self.data_root = data_root
self.dataset_lines = open(data_list).readlines()
self.mp_data = MPData(num_processes, max_queue_size, self._gen_sample)
def _gen_sample(self):
m_line = random.choice(self.dataset_lines)
img_lst = m_line.strip().split(' ')
img_path = os.path.join(self.data_root, img_lst[0])
img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
img = np.array(img)
# print(img.shape)
img = np.transpose(img, (1, 0)) # res: [1, width, height]
# if len(img.shape) == 2:
# img = np.expand_dims(np.transpose(img, (1, 0)), axis=0) # res: [1, width, height]
labels = np.zeros(self.num_label, int)
for idx in range(1, len(img_lst)):
labels[idx - 1] = int(img_lst[idx])
return img, labels
@property
def size(self):
return len(self.dataset_lines)
@property
def shape(self):
return self.data_shape
def start(self):
"""
Starts the processes
"""
self.mp_data.start()
def get(self):
"""
Get an image from the queue
Returns
-------
np.ndarray
A captcha image, normalized to [0, 1]
"""
return self.mp_data.get()
def reset(self):
"""
Resets the generator by stopping all processes
"""
self.mp_data.reset()
class OCRIter(mx.io.DataIter):
"""
Iterator class for generating captcha image data
"""
def __init__(self, count, batch_size, lstm_init_states, captcha, name):
def __init__(self, count, batch_size, lstm_init_states, captcha, num_label, name):
"""
Parameters
----------
......@@ -189,12 +264,12 @@ class OCRIter(mx.io.DataIter):
"""
super(OCRIter, self).__init__()
self.batch_size = batch_size
self.count = count
self.count = count if count > 0 else captcha.size
self.init_states = lstm_init_states
self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states]
data_shape = captcha.shape
self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + lstm_init_states
self.provide_label = [('label', (self.batch_size, 4))]
self.provide_label = [('label', (self.batch_size, num_label))]
self.mp_captcha = captcha
self.name = name
......@@ -204,12 +279,12 @@ class OCRIter(mx.io.DataIter):
data = []
label = []
for i in range(self.batch_size):
img, num = self.mp_captcha.get()
img, labels = self.mp_captcha.get()
# print(img.shape)
img = np.expand_dims(np.transpose(img, (1, 0)), axis=0) # size: [1, channel, height, width]
# import pdb; pdb.set_trace()
data.append(img)
label.append(self._get_label(num))
label.append(labels)
data_all = [mx.nd.array(data)] + self.init_state_arrays
label_all = [mx.nd.array(label)]
data_names = ['data'] + init_state_names
......@@ -217,12 +292,3 @@ class OCRIter(mx.io.DataIter):
data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
yield data_batch
@classmethod
def _get_label(cls, buf):
ret = np.zeros(4)
for i in range(len(buf)):
ret[i] = 1 + int(buf[i])
if len(buf) == 3:
ret[3] = 0
return ret
......@@ -7,7 +7,7 @@ class Hyperparams(object):
"""
def __init__(self):
# Training hyper parameters
self._train_epoch_size = 30000
self._train_epoch_size = 0
self._eval_epoch_size = 3000
self._num_epoch = 20
self._learning_rate = 0.001
......@@ -17,7 +17,7 @@ class Hyperparams(object):
self._loss_type = "ctc" # ["warpctc" "ctc"]
self._batch_size = 128
self._num_classes = 5990
self._num_classes = 6425 # 5990
self._img_width = 280
self._img_height = 32
......
......@@ -9,7 +9,7 @@ from data_utils.captcha_generator import MPDigitCaptcha
from hyperparams.hyperparams import Hyperparams
from hyperparams.hyperparams2 import Hyperparams as Hyperparams2
from data_utils.data_iter import ImageIterLstm, OCRIter
from data_utils.data_iter import ImageIterLstm, MPOcrImages, OCRIter
from symbols.crnn import crnn_no_lstm, crnn_lstm
from fit.ctc_metrics import CtcMetrics
from fit.fit import fit
......@@ -80,9 +80,11 @@ def run_captcha(args):
data_names = ['data'] + [x[0] for x in init_states]
data_train = OCRIter(
hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='train')
hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, num_label=hp.num_label,
name='train')
data_val = OCRIter(
hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='val')
hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, num_label=hp.num_label,
name='val')
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
......@@ -99,15 +101,37 @@ def run_cn_ocr(args):
network = crnn_lstm(hp)
mp_data_train = MPOcrImages(args.data_root, args.train_file, (hp.img_width, hp.img_height), hp.num_label,
num_processes=args.num_proc, max_queue_size=hp.batch_size * 2)
# img, num = mp_data_train.get()
# print(img.shape)
# print(mp_data_train.shape)
# import pdb; pdb.set_trace()
# import numpy as np
# import cv2
# img = np.transpose(img, (1, 0))
# cv2.imwrite('xxx.png', img * 255)
# import pdb; pdb.set_trace()
mp_data_test = MPOcrImages(args.data_root, args.test_file, (hp.img_width, hp.img_height), hp.num_label,
num_processes=args.num_proc, max_queue_size=hp.batch_size * 2)
mp_data_train.start()
mp_data_test.start()
init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
init_states = init_c + init_h
data_names = ['data'] + [x[0] for x in init_states]
data_train = ImageIterLstm(
args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
data_val = ImageIterLstm(
args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")
data_train = OCRIter(
hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_data_train, num_label=hp.num_label,
name='train')
data_val = OCRIter(
hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_data_test, num_label=hp.num_label,
name='val')
# data_train = ImageIterLstm(
# args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
# data_val = ImageIterLstm(
# args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
......@@ -116,6 +140,9 @@ def run_cn_ocr(args):
fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names)
mp_data_train.reset()
mp_data_test.start()
if __name__ == '__main__':
args = parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册