提交 d411fb81 编写于 作者: B breezedeus

refactoring

上级 f7a105df
此差异已折叠。
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Helper classes for multiprocess captcha image generation
This module also provides script for saving captcha images to file using CLI.
"""
from __future__ import print_function
import random
from captcha.image import ImageCaptcha
import cv2
from .multiproc_data import MPData
import numpy as np
class CaptchaGen(object):
    """
    Generates a captcha image
    """
    def __init__(self, h, w, font_paths):
        """
        Parameters
        ----------
        h: int
            Height of the generated images
        w: int
            Width of the generated images
        font_paths: list of str
            List of all fonts in ttf format
        """
        self.captcha = ImageCaptcha(fonts=font_paths)
        self.h = h
        self.w = w

    def image(self, captcha_str):
        """
        Generate a greyscale captcha image representing number string

        Parameters
        ----------
        captcha_str: str
            string of characters for captcha image

        Returns
        -------
        numpy.ndarray
            Generated greyscale image in np.ndarray float type with values normalized to [0, 1]
        """
        img = self.captcha.generate(captcha_str)
        # np.frombuffer is the supported way to view raw bytes as uint8;
        # binary-mode np.fromstring is deprecated.
        img = np.frombuffer(img.getvalue(), dtype='uint8')
        img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
        # NOTE(review): the size is passed as (h, w) and the result is transposed
        # right after; presumably the returned array is (self.h, self.w) - confirm.
        img = cv2.resize(img, (self.h, self.w))
        img = img.transpose(1, 0)
        # Scale 8-bit pixel values to floats in [0, 1]
        img = np.multiply(img, 1 / 255.0)
        return img
class DigitCaptcha(object):
    """
    Random digit-captcha source exposing a ``shape`` property and a ``get()`` method
    """
    def __init__(self, font_paths, h, w, num_digit_min, num_digit_max):
        """
        Parameters
        ----------
        font_paths: list of str
            Paths to the ttf font files used for rendering
        h: int
            height of the generated image
        w: int
            width of the generated image
        num_digit_min: int
            smallest number of digits a generated captcha may contain
        num_digit_max: int
            largest number of digits a generated captcha may contain
        """
        self.num_digit_min = num_digit_min
        self.num_digit_max = num_digit_max
        self.captcha = CaptchaGen(h=h, w=w, font_paths=font_paths)

    @property
    def shape(self):
        """
        Shape of the generated image data

        Returns
        -------
        tuple(int, int)
            ``(h, w)`` as stored on the underlying generator
        """
        generator = self.captcha
        return (generator.h, generator.w)

    def get(self):
        """
        Produce one freshly generated sample

        Returns
        -------
        (numpy.ndarray, str)
            The captcha image and the digit string it shows
        """
        return self._gen_sample()

    @staticmethod
    def get_rand(num_digit_min, num_digit_max):
        """
        Generates a random string of decimal digits whose length is drawn
        uniformly from [num_digit_min, num_digit_max]

        Returns
        -------
        str
        """
        length = random.randint(num_digit_min, num_digit_max)
        return "".join(str(random.randint(0, 9)) for _ in range(length))

    def _gen_sample(self):
        """
        Generate a random captcha image sample

        Returns
        -------
        (numpy.ndarray, str)
            Tuple of image (numpy ndarray) and the digit string used to generate it
        """
        digits = self.get_rand(self.num_digit_min, self.num_digit_max)
        picture = self.captcha.image(digits)
        return picture, digits
class MPDigitCaptcha(DigitCaptcha):
    """
    Digit-captcha source whose samples are produced by background processes
    """
    def __init__(self, font_paths, h, w, num_digit_min, num_digit_max, num_processes, max_queue_size):
        """
        Parameters
        ----------
        font_paths: list of str
            Paths to the ttf font files used for rendering
        h: int
            height of the generated image
        w: int
            width of the generated image
        num_digit_min: int
            smallest number of digits a generated captcha may contain
        num_digit_max: int
            largest number of digits a generated captcha may contain
        num_processes: int
            Number of processes to spawn
        max_queue_size: int
            Maximum images in queue before processes wait
        """
        super(MPDigitCaptcha, self).__init__(
            font_paths=font_paths,
            h=h,
            w=w,
            num_digit_min=num_digit_min,
            num_digit_max=num_digit_max,
        )
        # Worker processes call the inherited sample generator.
        self.mp_data = MPData(
            num_processes=num_processes,
            max_queue_size=max_queue_size,
            fn=self._gen_sample,
        )

    def start(self):
        """
        Launch the worker processes
        """
        self.mp_data.start()

    def get(self):
        """
        Fetch the next sample from the worker queue

        Returns
        -------
        (numpy.ndarray, str)
            Whatever ``_gen_sample`` produced in a worker: image and digit string
        """
        return self.mp_data.get()

    def reset(self):
        """
        Stop all worker processes and reset the generator
        """
        self.mp_data.reset()
if __name__ == '__main__':
    import argparse

    def main():
        """Render one captcha from the CLI arguments and write it to an image file."""
        parser = argparse.ArgumentParser()
        parser.add_argument("font_path", help="Path to ttf font file")
        parser.add_argument("output", help="Output filename including extension (e.g. 'sample.jpg')")
        parser.add_argument("--num", help="Up to 4 digit number [Default: random]")
        args = parser.parse_args()

        captcha = ImageCaptcha(fonts=[args.font_path])
        # Fall back to a random 3-4 digit string when --num is not given
        captcha_str = args.num if args.num else DigitCaptcha.get_rand(3, 4)
        img = captcha.generate(captcha_str)
        # np.frombuffer replaces the deprecated binary-mode np.fromstring
        img = np.frombuffer(img.getvalue(), dtype='uint8')
        img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE)
        cv2.imwrite(args.output, img)
        print("Captcha image with digits {} written to {}".format([int(c) for c in captcha_str], args.output))

    main()
from __future__ import print_function
import os
from PIL import Image
import numpy as np
import mxnet as mx
import random
from .multiproc_data import MPData
class SimpleBatch(object):
    """
    Minimal data-batch container exposing data/label arrays, their names and
    the (name, shape) descriptions derived from them.
    """
    def __init__(self, data_names, data, label_names=None, label=None):
        """
        Parameters
        ----------
        data_names: list of str
            Name of each entry in ``data``
        data: list
            Data arrays; each must expose ``.shape``
        label_names: list of str, optional
            Name of each entry in ``label``; defaults to a fresh empty list
        label: list, optional
            Label arrays; each must expose ``.shape``; defaults to a fresh empty list
        """
        self._data = data
        # Build the defaults per instance: a list created in the signature
        # would be shared (and mutable) across every SimpleBatch.
        self._label = [] if label is None else label
        self._data_names = data_names
        self._label_names = [] if label_names is None else label_names
        self.pad = 0
        self.index = None  # TODO: what is index?

    @property
    def data(self):
        return self._data

    @property
    def label(self):
        return self._label

    @property
    def data_names(self):
        return self._data_names

    @property
    def label_names(self):
        return self._label_names

    @property
    def provide_data(self):
        """List of (name, shape) pairs, one per data array."""
        return [(n, x.shape) for n, x in zip(self._data_names, self._data)]

    @property
    def provide_label(self):
        """List of (name, shape) pairs, one per label array."""
        return [(n, x.shape) for n, x in zip(self._label_names, self._label)]
# class ImageIter(mx.io.DataIter):
#
# """
# Iterator class for generating captcha image data
# """
# def __init__(self, data_root, data_list, batch_size, data_shape, num_label, name=None):
# """
# Parameters
# ----------
# data_root: str
# root directory of images
# data_list: str
# a .txt file stores the image name and corresponding labels for each line
# batch_size: int
# name: str
# """
# super(ImageIter, self).__init__()
# self.batch_size = batch_size
# self.data_shape = data_shape
# self.num_label = num_label
#
# self.data_root = data_root
# self.dataset_lst_file = open(data_list)
#
# self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))]
# self.provide_label = [('label', (self.batch_size, self.num_label))]
# self.name = name
#
# def __iter__(self):
# data = []
# label = []
# cnt = 0
# for m_line in self.dataset_lst_file:
# img_lst = m_line.strip().split(' ')
# img_path = os.path.join(self.data_root, img_lst[0])
#
# cnt += 1
# img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
# img = np.array(img).reshape((1, self.data_shape[1], self.data_shape[0]))
# data.append(img)
#
# ret = np.zeros(self.num_label, int)
# for idx in range(1, len(img_lst)):
# ret[idx-1] = int(img_lst[idx])
#
# label.append(ret)
# if cnt % self.batch_size == 0:
# data_all = [mx.nd.array(data)]
# label_all = [mx.nd.array(label)]
# data_names = ['data']
# label_names = ['label']
# data.clear()
# label.clear()
# yield SimpleBatch(data_names, data_all, label_names, label_all)
# continue
#
#
# def reset(self):
# if self.dataset_lst_file.seekable():
# self.dataset_lst_file.seek(0)
class ImageIterLstm(mx.io.DataIter):
    """
    Iterator that reads images listed in a text file and yields batches that
    also carry the LSTM init-state arrays
    """
    def __init__(self, data_root, data_list, batch_size, data_shape, num_label, lstm_init_states, name=None):
        """
        Parameters
        ----------
        data_root: str
            root directory of images
        data_list: str
            a .txt file stores the image name and corresponding labels for each line
        batch_size: int
        data_shape: sequence of two ints
            size passed to the image resize; indexed as [0] and [1] when building shapes
        num_label: int
            length of the zero-padded label vector produced for each image
        lstm_init_states: list of tuple(str, tuple)
            name ([0]) and shape ([1]) of each LSTM init state
        name: str
        """
        super(ImageIterLstm, self).__init__()
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.num_label = num_label
        self.init_states = lstm_init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states]
        self.data_root = data_root
        # Read the list inside a context manager so the file handle is closed
        with open(data_list) as list_file:
            self.dataset_lines = list_file.readlines()
        self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + lstm_init_states
        self.provide_label = [('label', (self.batch_size, self.num_label))]
        self.name = name

    def __iter__(self):
        """
        Yield one SimpleBatch per ``batch_size`` lines of the list file.
        Lines left over after the last full batch are not yielded.
        """
        init_state_names = [x[0] for x in self.init_states]
        data = []
        label = []
        cnt = 0
        for m_line in self.dataset_lines:
            # Line format: "<relative image path> <label> <label> ..."
            img_lst = m_line.strip().split(' ')
            img_path = os.path.join(self.data_root, img_lst[0])
            cnt += 1
            img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
            img = np.array(img).reshape((1, self.data_shape[1], self.data_shape[0]))  # res: [1, height, width]
            data.append(img)

            # Labels fill the front of a zero vector of length num_label
            ret = np.zeros(self.num_label, int)
            for idx in range(1, len(img_lst)):
                ret[idx - 1] = int(img_lst[idx])
            label.append(ret)

            if cnt % self.batch_size == 0:
                data_all = [mx.nd.array(data)] + self.init_state_arrays
                label_all = [mx.nd.array(label)]
                data_names = ['data'] + init_state_names
                label_names = ['label']
                # Start fresh accumulators for the next batch
                data = []
                label = []
                yield SimpleBatch(data_names, data_all, label_names, label_all)

    def reset(self):
        """Shuffle the sample lines in place so the next pass uses a new order."""
        random.shuffle(self.dataset_lines)
class MPOcrImages(object):
    """
    Handles multi-process Chinese OCR image generation
    """
    def __init__(self, data_root, data_list, data_shape, num_label, num_processes, max_queue_size):
        """
        Parameters
        ----------
        data_root: str
            root directory of images
        data_list: str
            text file with one "<image path> <label> <label> ..." entry per line
        data_shape: [width, height]
        num_label: int
            length of the zero-padded label vector produced for each image
        num_processes: int
            Number of processes to spawn
        max_queue_size: int
            Maximum images in queue before processes wait
        """
        self.data_shape = data_shape
        self.num_label = num_label
        self.data_root = data_root
        # Read the list inside a context manager so the file handle is closed
        with open(data_list) as list_file:
            self.dataset_lines = list_file.readlines()
        self.mp_data = MPData(num_processes, max_queue_size, self._gen_sample)

    def _gen_sample(self):
        """
        Load one randomly chosen image together with its label vector

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            The resized greyscale image with its two axes swapped, and an
            integer vector of length ``num_label`` holding the line's labels
            followed by zeros
        """
        m_line = random.choice(self.dataset_lines)
        img_lst = m_line.strip().split(' ')
        img_path = os.path.join(self.data_root, img_lst[0])
        img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
        img = np.array(img)
        # Swap the two image axes; the result is 2-D (no channel axis is added here)
        img = np.transpose(img, (1, 0))  # res: [width, height]

        labels = np.zeros(self.num_label, int)
        for idx in range(1, len(img_lst)):
            labels[idx - 1] = int(img_lst[idx])
        return img, labels

    @property
    def size(self):
        """Number of lines read from the data list."""
        return len(self.dataset_lines)

    @property
    def shape(self):
        """The ``data_shape`` given at construction."""
        return self.data_shape

    def start(self):
        """
        Starts the processes
        """
        self.mp_data.start()

    def get(self):
        """
        Get one sample from the queue

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            An (image, labels) pair as produced by ``_gen_sample``
        """
        return self.mp_data.get()

    def reset(self):
        """
        Resets the generator by stopping all processes
        """
        self.mp_data.reset()
class OCRIter(mx.io.DataIter):
    """
    Data iterator that assembles batches from a sample generator
    """
    def __init__(self, count, batch_size, lstm_init_states, captcha, num_label, name):
        """
        Parameters
        ----------
        count: int
            Number of batches to produce for one epoch; when not positive,
            ``captcha.size // batch_size`` is used instead
        batch_size: int
        lstm_init_states: list of tuple(str, tuple)
            A list of tuples with [0] name and [1] shape of each LSTM init state
        captcha: object
            Sample generator providing ``.shape`` and ``.get()`` (and ``.size``
            when ``count`` is not positive)
        num_label: int
            Number of label entries per sample
        name: str
        """
        super(OCRIter, self).__init__()
        self.batch_size = batch_size
        if count > 0:
            self.count = count
        else:
            self.count = captcha.size // batch_size
        self.init_states = lstm_init_states
        self.init_state_arrays = [mx.nd.zeros(state[1]) for state in lstm_init_states]
        sample_shape = captcha.shape
        image_desc = ('data', (batch_size, 1, sample_shape[1], sample_shape[0]))
        self.provide_data = [image_desc] + lstm_init_states
        self.provide_label = [('label', (self.batch_size, num_label))]
        self.mp_captcha = captcha
        self.name = name

    def __iter__(self):
        """Yield ``self.count`` batches of ``self.batch_size`` samples each."""
        state_names = [state[0] for state in self.init_states]
        for _ in range(self.count):
            images = []
            targets = []
            for _ in range(self.batch_size):
                picture, target = self.mp_captcha.get()
                # Swap the two axes back and prepend a channel axis
                picture = np.transpose(picture, (1, 0))[np.newaxis, ...]  # size: [1, height, width]
                images.append(picture)
                targets.append(target)
            batch_data = [mx.nd.array(images)] + self.init_state_arrays
            batch_label = [mx.nd.array(targets)]
            yield SimpleBatch(['data'] + state_names, batch_data, ['label'], batch_label)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
from ctypes import c_bool
import multiprocessing as mp
try:
from queue import Full as QFullExcept
from queue import Empty as QEmptyExcept
except ImportError as error:
raise error
# import numpy as np
class MPData(object):
    """
    Runs a sample-producing function in several processes and hands the
    results out through a bounded queue.

    Usage:
        - start() launches the workers
        - get() blocks until one sample is available
        - reset() stops the workers
    """
    def __init__(self, num_processes, max_queue_size, fn):
        """
        Parameters
        ----------
        num_processes: int
            Number of processes to spawn
        max_queue_size: int
            Maximum samples in the queue before processes wait
        fn: function
            function that generates samples, executed on separate processes.
        """
        self.queue = mp.Queue(maxsize=int(max_queue_size))
        # Unsynchronised shared flag telling the workers whether to keep going
        self.alive = mp.Value(c_bool, False, lock=False)
        self.num_proc = num_processes
        self.proc = list()
        self.fn = fn

    def start(self):
        """
        Starts the processes
        """
        self._init_proc()

    @staticmethod
    def _proc_loop(proc_id, alive, queue, fn):
        """
        Worker loop: keep producing samples and pushing them into the queue

        Parameters
        ----------
        proc_id: int
            Process id
        alive: multiprocessing.Value
            variable for signaling whether process should continue or not
        queue: multiprocessing.Queue
            queue for passing data back
        fn: function
            function object that returns a sample to be pushed into the queue
        """
        print("proc {} started".format(proc_id))
        try:
            while alive.value:
                sample = fn()
                # Retry while the queue is full, giving up once asked to stop
                while alive.value:
                    try:
                        queue.put(sample, timeout=0.5)
                    except QFullExcept:
                        continue
                    break
        except KeyboardInterrupt:
            print("W: interrupt received, stopping process {} ...".format(proc_id))
        print("Closing process {}".format(proc_id))
        queue.close()

    def _init_proc(self):
        """
        Start processes if not already started
        """
        if self.proc:
            return
        workers = []
        for idx in range(self.num_proc):
            worker_args = (idx, self.alive, self.queue, self.fn)
            workers.append(mp.Process(target=self._proc_loop, args=worker_args))
        self.proc = workers
        self.alive.value = True
        for worker in self.proc:
            worker.start()

    def get(self):
        """
        Get a datum from the queue, starting the workers first if needed

        Returns
        -------
        object
            One value returned by ``fn`` in a worker process
        """
        self._init_proc()
        return self.queue.get()

    def reset(self):
        """
        Resets the generator by stopping all processes
        """
        self.alive.value = False
        # Empty the queue before joining the workers
        drained = 0
        while True:
            try:
                self.queue.get(timeout=0.1)
            except QEmptyExcept:
                break
            drained += 1
        print("Queue size on reset: {}".format(drained))
        for worker in self.proc:
            worker.join()
        self.proc.clear()
import mxnet as mx
def _add_warp_ctc_loss(pred, seq_len, num_label, label):
    """ Adds Symbol.WarpCTC on top of pred symbol and returns the resulting symbol """
    # Flatten the label to 1-D and cast it to int32 before passing it on
    flat_label = mx.sym.Reshape(data=label, shape=(-1,))
    int_label = mx.sym.Cast(data=flat_label, dtype='int32')
    return mx.sym.WarpCTC(
        data=pred,
        label=int_label,
        label_length=num_label,
        input_length=seq_len,
    )
def _add_mxnet_ctc_loss(pred, seq_len, label):
    """
    Adds Symbol.contrib.ctc_loss on top of pred symbol and returns a group of
    (gradient-blocked softmax output, CTC loss)
    """
    # Split the leading axis of pred into (seq_len, remaining) for the loss
    sequence_pred = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))
    raw_loss = mx.sym.contrib.ctc_loss(data=sequence_pred, label=label)
    ctc_loss = mx.sym.MakeLoss(raw_loss)

    # Softmax branch is wrapped in BlockGrad, so no gradient flows through it
    probabilities = mx.symbol.SoftmaxActivation(data=pred)
    softmax_out = mx.sym.BlockGrad(mx.sym.MakeLoss(probabilities))
    return mx.sym.Group([softmax_out, ctc_loss])
def add_ctc_loss(pred, seq_len, num_label, loss_type):
    """
    Adds CTC loss on top of pred symbol and returns the resulting symbol.

    ``loss_type`` selects the implementation: 'warpctc' or 'ctc'; any other
    value trips the assertion.
    """
    label = mx.sym.Variable('label')
    if loss_type == 'warpctc':
        print("Using WarpCTC Loss")
        return _add_warp_ctc_loss(pred, seq_len, num_label, label)

    print("Using MXNet CTC Loss")
    assert loss_type == 'ctc'
    return _add_mxnet_ctc_loss(pred, seq_len, label)
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Contains a class for calculating CTC eval metrics"""
from __future__ import print_function
import numpy as np
class CtcMetrics(object):
    """
    CTC evaluation metrics.

    ``pred`` passed to the metrics is indexed as ``pred[k * batch_size + i]``
    for time step ``k`` and sample ``i``; class 0 is treated as the blank.
    """
    def __init__(self, seq_len):
        """
        Parameters
        ----------
        seq_len: int
            Number of time steps per sample in ``pred``
        """
        self.seq_len = seq_len

    @staticmethod
    def ctc_label(p):
        """
        Iterates through p, identifying non-zero and non-repeating values, and returns them in a list

        Parameters
        ----------
        p: list of int

        Returns
        -------
        list of int
        """
        ret = []
        prev = 0  # an implicit blank precedes the sequence
        for cur in p:
            if cur != 0 and cur != prev:
                ret.append(cur)
            prev = cur
        return ret

    @staticmethod
    def _remove_blank(l):
        """ Returns the values of l that come before its first zero, as a new list """
        ret = []
        for value in l:
            if value == 0:
                break
            ret.append(value)
        return ret

    @staticmethod
    def _lcs(p, l):
        """ Calculates the Longest Common Subsequence between p and l (both list of int) and returns its length"""
        # Either sequence being empty means there is nothing in common
        if len(p) == 0 or len(l) == 0:
            return 0
        # table[i, j] holds the LCS length of l[:i] and p[:j]
        table = np.zeros((len(l) + 1, len(p) + 1), dtype=np.int32)
        for i, l_val in enumerate(l, start=1):
            for j, p_val in enumerate(p, start=1):
                if l_val == p_val:
                    table[i, j] = table[i - 1, j - 1] + 1
                else:
                    table[i, j] = max(table[i - 1, j], table[i, j - 1])
        return int(table[-1, -1])

    def _decode(self, pred, batch_size, i):
        """ Greedy-decodes sample i: argmax per time step, then collapse repeats and drop blanks """
        path = [np.argmax(pred[k * batch_size + i]) for k in range(self.seq_len)]
        return self.ctc_label(path)

    def accuracy(self, label, pred):
        """ Simple accuracy measure: number of 100% accurate predictions divided by total number """
        batch_size = label.shape[0]
        hit = 0.
        for i in range(batch_size):
            l = self._remove_blank(label[i])
            p = self._decode(pred, batch_size, i)
            if len(p) == len(l) and all(p_val == int(l_val) for p_val, l_val in zip(p, l)):
                hit += 1.0
        return hit / batch_size

    def accuracy_lcs(self, label, pred):
        """
        Longest Common Subsequence accuracy measure: calculate accuracy of each prediction as LCS/length.

        A sample whose label is empty scores 1.0 when the prediction is empty
        as well and 0.0 otherwise (LCS/length is undefined for length 0).
        """
        batch_size = label.shape[0]
        hit = 0.
        for i in range(batch_size):
            l = self._remove_blank(label[i])
            p = self._decode(pred, batch_size, i)
            if len(l) == 0:
                hit += 1.0 if len(p) == 0 else 0.0
            else:
                hit += self._lcs(p, l) * 1.0 / len(l)
        return hit / batch_size
import logging
import os
import mxnet as mx
def _load_model(args, rank=0):
if 'load_epoch' not in args or args.load_epoch is None:
return (None, None, None)
assert args.prefix is not None
model_prefix = args.prefix
if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
model_prefix += "-%d" % (rank)
sym, arg_params, aux_params = mx.model.load_checkpoint(
model_prefix, args.load_epoch)
logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
return (sym, arg_params, aux_params)
def fit(network, data_train, data_val, metrics, args, hp, data_names=None):
    """
    Train ``network`` with an MXNet Module, checkpointing after every epoch.

    Parameters
    ----------
    network: mx.sym.Symbol
        Network to train
    data_train, data_val:
        Training and evaluation data iterators
    metrics:
        Object whose ``accuracy`` method is used as the evaluation metric
    args:
        Options providing ``gpu``, ``cpu``, ``prefix`` and ``load_epoch``
    hp:
        Hyperparameters providing ``num_epoch``, ``learning_rate`` and ``batch_size``
    data_names: list of str, optional
        Names of the data inputs; defaults to ``["data"]``
    """
    if args.gpu:
        contexts = [mx.context.gpu(i) for i in range(args.gpu)]
    else:
        contexts = [mx.context.cpu(i) for i in range(args.cpu)]

    sym, arg_params, aux_params = _load_model(args)
    if sym is not None:
        # A resumed checkpoint must describe exactly the network being trained
        assert sym.tojson() == network.tojson()

    # A prefix without a directory part has dirname '' and os.makedirs('')
    # raises, so only create the directory when there is one to create.
    prefix_dir = os.path.dirname(args.prefix)
    if prefix_dir and not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)

    module = mx.mod.Module(
        symbol=network,
        data_names=["data"] if data_names is None else data_names,
        label_names=['label'],
        context=contexts)

    module.fit(train_data=data_train,
               eval_data=data_val,
               begin_epoch=args.load_epoch if args.load_epoch else 0,
               num_epoch=hp.num_epoch,
               # use metrics.accuracy or metrics.accuracy_lcs
               eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
               optimizer='AdaDelta',
               optimizer_params={'learning_rate': hp.learning_rate,
                                 # 'momentum': hp.momentum,
                                 'wd': 0.00001,
                                 },
               initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
               arg_params=arg_params,
               aux_params=aux_params,
               batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
               epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
               )
\ No newline at end of file
from __future__ import print_function
from collections import namedtuple
import mxnet as mx
# State carried between LSTM steps: fields ``c`` and ``h``
LSTMState = namedtuple(
    "LSTMState",
    ["c", "h"],
)
# Weight and bias symbols of one LSTM layer
LSTMParam = namedtuple(
    "LSTMParam",
    [
        "i2h_weight",
        "i2h_bias",
        "h2h_weight",
        "h2h_bias",
    ],
)
LSTMModel = namedtuple(
    "LSTMModel",
    [
        "rnn_exec",
        "symbol",
        "init_states",
        "last_states",
        "forward_state",
        "backward_state",
        "seq_data",
        "seq_labels",
        "seq_outputs",
        "param_blocks",
    ],
)
def init_states(batch_size, num_lstm_layer, num_hidden):
    """
    Returns name and shape of init states of LSTM network

    Two entries per layer are produced for each state kind (``num_lstm_layer * 2``
    in total), all ``c`` states first, then all ``h`` states.

    Parameters
    ----------
    batch_size: int
    num_lstm_layer: int
    num_hidden: int

    Returns
    -------
    list of tuple(str, tuple(int, int))
        ``(name, (batch_size, num_hidden))`` for every init state
    """
    state_shape = (batch_size, num_hidden)
    state_ids = range(num_lstm_layer * 2)
    cell_states = [('l%d_init_c' % idx, state_shape) for idx in state_ids]
    hidden_states = [('l%d_init_h' % idx, state_shape) for idx in state_ids]
    return cell_states + hidden_states
def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """
    Build the symbols of one LSTM cell step.

    Returns the next ``LSTMState`` computed from ``indata`` and ``prev_state``
    using the weights and biases in ``param``. ``seqidx`` and ``layeridx`` only
    feed into the symbol names.
    """
    tag = (seqidx, layeridx)
    gate_units = num_hidden * 4  # four gate blocks of num_hidden units each

    input_proj = mx.sym.FullyConnected(data=indata,
                                       weight=param.i2h_weight,
                                       bias=param.i2h_bias,
                                       num_hidden=gate_units,
                                       name="t%d_l%d_i2h" % tag)
    hidden_proj = mx.sym.FullyConnected(data=prev_state.h,
                                        weight=param.h2h_weight,
                                        bias=param.h2h_bias,
                                        num_hidden=gate_units,
                                        name="t%d_l%d_h2h" % tag)
    pre_activation = input_proj + hidden_proj
    parts = mx.sym.split(pre_activation, num_outputs=4,
                         name="t%d_l%d_slice" % tag)

    # Statement order is kept as-is: unnamed symbols get auto-generated names
    in_gate = mx.sym.Activation(parts[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(parts[1], act_type="tanh")
    forget_gate = mx.sym.Activation(parts[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(parts[3], act_type="sigmoid")

    cell = (forget_gate * prev_state.c) + (in_gate * in_transform)
    output = out_gate * mx.sym.Activation(cell, act_type="tanh")
    return LSTMState(c=cell, h=output)
def lstm(net, num_lstm_layer, num_hidden, seq_length):
    """
    Stack a bidirectional LSTM on top of ``net``.

    ``net`` is split into ``seq_length`` slices along axis 3. The slices are
    run through ``num_lstm_layer`` forward cells in increasing order and
    through the same number of backward cells in decreasing order. Per time
    step the two hidden outputs are concatenated along dim 1, and all time
    steps are then concatenated along dim 0.

    Even-numbered state/parameter slots belong to the forward direction,
    odd-numbered ones to the backward direction.
    """
    last_states = []
    forward_param = []
    backward_param = []
    for idx in range(num_lstm_layer * 2):
        last_states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % idx),
                                     h=mx.sym.Variable("l%d_init_h" % idx)))
        layer_param = LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % idx),
                                i2h_bias=mx.sym.Variable("l%d_i2h_bias" % idx),
                                h2h_weight=mx.sym.Variable("l%d_h2h_weight" % idx),
                                h2h_bias=mx.sym.Variable("l%d_h2h_bias" % idx))
        if idx % 2 == 0:
            forward_param.append(layer_param)
        else:
            backward_param.append(layer_param)

    slices_net = mx.sym.split(data=net, axis=3, num_outputs=seq_length, squeeze_axis=1)  # bz x features x 1 x time_step

    def run_stack(step, slot_offset, params):
        """Feed time step ``step`` through all layers of one direction."""
        hidden = mx.sym.flatten(data=slices_net[step])
        for layer in range(num_lstm_layer):
            slot = 2 * layer + slot_offset
            next_state = _lstm(num_hidden, indata=hidden, prev_state=last_states[slot],
                               param=params[layer], seqidx=step, layeridx=layer)
            hidden = next_state.h
            last_states[slot] = next_state
        return hidden

    # Forward pass is built first, then the backward pass, as in the original
    forward_hidden = [run_stack(step, 0, forward_param) for step in range(seq_length)]

    backward_hidden = [run_stack(step, 1, backward_param) for step in reversed(range(seq_length))]
    backward_hidden.reverse()  # back to increasing time order

    hidden_all = [mx.sym.concat(fwd, bwd, dim=1)
                  for fwd, bwd in zip(forward_hidden, backward_hidden)]
    return mx.sym.concat(*hidden_all, dim=0)
from __future__ import print_function
class CnHyperparams(object):
    """
    Hyperparameters for LSTM network

    Values are stored on underscore-prefixed attributes and exposed through
    read-only properties.
    """
    def __init__(self):
        # Training hyper parameters
        self._train_epoch_size = 2560000
        self._eval_epoch_size = 3000
        self._num_epoch = 20
        self._learning_rate = 0.001
        self._momentum = 0.9
        self._bn_mom = 0.9
        self._workspace = 512
        self._loss_type = "ctc"  # ["warpctc" "ctc"]
        self._batch_size = 128
        self._num_classes = 6425  # original note: should be 6426 ... 5990
        self._img_width = 280
        self._img_height = 32

        # DenseNet hyper parameters
        self._depth = 161
        self._growrate = 32
        self._reduction = 0.5

        # LSTM hyper parameters
        self._num_hidden = 100
        self._num_lstm_layer = 2
        # self._seq_length = 35
        self._seq_length = self._img_width // 8
        self._num_label = 10
        self._drop_out = 0.5

    # Read-only accessors, one per hyperparameter
    train_epoch_size = property(lambda self: self._train_epoch_size)
    eval_epoch_size = property(lambda self: self._eval_epoch_size)
    num_epoch = property(lambda self: self._num_epoch)
    learning_rate = property(lambda self: self._learning_rate)
    momentum = property(lambda self: self._momentum)
    bn_mom = property(lambda self: self._bn_mom)
    workspace = property(lambda self: self._workspace)
    loss_type = property(lambda self: self._loss_type)
    batch_size = property(lambda self: self._batch_size)
    num_classes = property(lambda self: self._num_classes)
    img_width = property(lambda self: self._img_width)
    img_height = property(lambda self: self._img_height)
    depth = property(lambda self: self._depth)
    growrate = property(lambda self: self._growrate)
    reduction = property(lambda self: self._reduction)
    num_hidden = property(lambda self: self._num_hidden)
    num_lstm_layer = property(lambda self: self._num_lstm_layer)
    seq_length = property(lambda self: self._seq_length)
    num_label = property(lambda self: self._num_label)
    dropout = property(lambda self: self._drop_out)
from __future__ import print_function
class Hyperparams(object):
    """
    Hyperparameters for LSTM network

    Values are stored on underscore-prefixed attributes and exposed through
    read-only properties.
    """
    def __init__(self):
        # Training hyper parameters
        self._train_epoch_size = 30000
        self._eval_epoch_size = 3000
        self._num_epoch = 20
        self._learning_rate = 0.001
        self._momentum = 0.9
        self._bn_mom = 0.9
        self._workspace = 512
        self._loss_type = "ctc"  # ["warpctc" "ctc"]
        self._batch_size = 128
        self._num_classes = 11
        self._img_width = 100
        self._img_height = 32

        # DenseNet hyper parameters
        self._depth = 161
        self._growrate = 32
        self._reduction = 0.5

        # LSTM hyper parameters
        self._num_hidden = 100
        self._num_lstm_layer = 2
        self._seq_length = self._img_width // 8
        self._num_label = 4
        self._drop_out = 0.5

    # Read-only accessors, one per hyperparameter
    train_epoch_size = property(lambda self: self._train_epoch_size)
    eval_epoch_size = property(lambda self: self._eval_epoch_size)
    num_epoch = property(lambda self: self._num_epoch)
    learning_rate = property(lambda self: self._learning_rate)
    momentum = property(lambda self: self._momentum)
    bn_mom = property(lambda self: self._bn_mom)
    workspace = property(lambda self: self._workspace)
    loss_type = property(lambda self: self._loss_type)
    batch_size = property(lambda self: self._batch_size)
    num_classes = property(lambda self: self._num_classes)
    img_width = property(lambda self: self._img_width)
    img_height = property(lambda self: self._img_height)
    depth = property(lambda self: self._depth)
    growrate = property(lambda self: self._growrate)
    reduction = property(lambda self: self._reduction)
    num_hidden = property(lambda self: self._num_hidden)
    num_lstm_layer = property(lambda self: self._num_lstm_layer)
    seq_length = property(lambda self: self._seq_length)
    num_label = property(lambda self: self._num_label)
    dropout = property(lambda self: self._drop_out)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick Haffner.
Gradient-based learning applied to document recognition.
Proceedings of the IEEE (1998)
"""
import mxnet as mx
from ..fit.ctc_loss import add_ctc_loss
from ..fit.lstm import lstm
def crnn_no_lstm(hp):
    """
    Build a convolution-only CRNN symbol (no recurrent layers).

    Parameters
    ----------
    hp:
        Hyperparameters providing ``dropout``, ``num_classes``, ``seq_length``,
        ``num_label`` and ``loss_type``

    Returns
    -------
    mx.sym.Symbol
        With a truthy ``hp.loss_type``: the CTC loss symbol from ``add_ctc_loss``;
        otherwise a softmax over the per-time-step class scores
    """
    # input
    data = mx.sym.Variable('data')

    kernel_size = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
    padding_size = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
    # Filter count doubles per layer starting at 64 and is capped at 512
    layer_size = [min(32*2**(i+1), 512) for i in range(len(kernel_size))]

    def convRelu(i, input_data, bn=True):
        """Convolution i, optionally followed by BatchNorm, then LeakyReLU."""
        layer = mx.symbol.Convolution(name='conv-%d' % i, data=input_data, kernel=kernel_size[i], pad=padding_size[i],
                                      num_filter=layer_size[i])
        if bn:
            layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
        layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d' % i)
        return layer

    net = convRelu(0, data)  # bz x f x H x W
    # Named max_pool/avg_pool so the builtin ``max`` is not shadowed
    max_pool = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2))
    avg_pool = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
    net = max_pool - avg_pool  # bz x f x H/2 x W/2
    net = convRelu(1, net)
    net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2))  # bz x f x H/4 x W/4
    net = convRelu(2, net, True)
    net = convRelu(3, net)
    net = mx.sym.Pooling(data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2))  # bz x f x H/8 x W/8
    net = convRelu(4, net, True)
    net = convRelu(5, net)
    # Averages over a window of 4 rows; presumably yields bz x f x 1 x T for a 32-pixel-high input
    net = mx.symbol.Pooling(data=net, kernel=(4, 1), pool_type='avg', name='pool1')
    if hp.dropout > 0:
        net = mx.symbol.Dropout(data=net, p=hp.dropout)

    # Rows of the prediction matrix must be time-major (row = t * bz + b):
    # the CTC loss reshapes the leading axis into (seq_len, -1) and the
    # metrics index pred[k * batch_size + i]. Putting the time axis before
    # the batch axis here makes the flatten below produce that order.
    net = mx.sym.transpose(data=net, axes=[1, 3, 0, 2])  # f x T x bz x 1
    net = mx.sym.flatten(data=net)  # f x (T x bz)
    hidden_concat = mx.sym.transpose(data=net, axes=[1, 0])  # (T x bz) x f

    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes)  # (T x bz) x num_classes

    if hp.loss_type:
        # Training mode, add loss
        return add_ctc_loss(pred, hp.seq_length, hp.num_label, hp.loss_type)
    else:
        # Inference mode, add softmax
        return mx.sym.softmax(data=pred, name='softmax')
def crnn_lstm(hp):
    """
    Build the CRNN symbol with recurrent layers: six convolution blocks
    (each a 3x3 conv followed by a 1x1 conv), an LSTM stack, and a fully
    connected classifier named ``'pred_fc'``.

    Parameters
    ----------
    hp: object
        Hyper-parameters; this function reads ``dropout``, ``num_lstm_layer``,
        ``num_hidden``, ``seq_length``, ``num_classes``, ``loss_type`` and
        ``num_label``.

    Returns
    -------
    The result of ``add_ctc_loss`` when ``hp.loss_type`` is truthy (training),
    otherwise a softmax symbol named ``'softmax'`` (inference).
    """
    # input
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')  # created as in the original; not referenced below

    kernel_size = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
    padding_size = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
    # 64, 128, 256, then capped at 512 filters
    layer_size = [min(32 * 2 ** (i + 1), 512) for i in range(len(kernel_size))]

    def convRelu(i, input_data, bn=True):
        """3x3 conv -> [BN] -> LeakyReLU -> 1x1 conv -> [BN] -> LeakyReLU for block ``i``."""
        layer = mx.symbol.Convolution(name='conv-%d' % i, data=input_data, kernel=kernel_size[i],
                                      pad=padding_size[i], num_filter=layer_size[i])
        if bn:
            layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
        layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d' % i)
        layer = mx.symbol.Convolution(name='conv-%d-1x1' % i, data=layer, kernel=(1, 1), pad=(0, 0),
                                      num_filter=layer_size[i])
        if bn:
            layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d-1x1' % i)
        layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-1x1' % i)
        return layer

    # Shape comments below assume a 32 x 280 input, as in the original comments.
    net = convRelu(0, data)  # bz x f x 32 x 280
    pool_max = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2))
    pool_avg = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
    net = pool_max - pool_avg  # bz x f x 16 x 140
    # Conv block 1 must consume the pooled difference (same order as crnn_no_lstm).
    # Previously it ran before `max - avg`, so its output was overwritten and the
    # block never became part of the network.
    net = convRelu(1, net)
    net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2))  # res: bz x f x 8 x 70

    net = convRelu(2, net, True)
    net = convRelu(3, net)
    net = mx.sym.Pooling(data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2))  # res: bz x f x 4 x 35

    net = convRelu(4, net, True)
    net = convRelu(5, net)
    net = mx.symbol.Pooling(data=net, kernel=(4, 1), pool_type='avg', name='pool1')  # res: bz x f x 1 x 35

    if hp.dropout > 0:
        net = mx.symbol.Dropout(data=net, p=hp.dropout)

    hidden_concat = lstm(net, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden, seq_length=hp.seq_length)

    # 'pred_fc' is the layer prediction code looks up as 'pred_fc_output'.
    # NOTE(review): presumably (bz x seq_length) x num_classes - confirm against lstm().
    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes, name='pred_fc')

    if hp.loss_type:
        # Training mode, add loss
        return add_ctc_loss(pred, hp.seq_length, hp.num_label, hp.loss_type)
    else:
        # Inference mode, add softmax
        return mx.sym.softmax(data=pred, name='softmax')
from ..hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
if __name__ == '__main__':
    # Debug aid: print the inferred output shape of every internal symbol
    # of the LSTM-free network.
    hp = Hyperparams()
    # crnn_no_lstm takes no LSTM init-state inputs, so only the data and
    # label shapes have to be supplied for shape inference.
    init_states = {
        'data': (hp.batch_size, 1, hp.img_height, hp.img_width),
        'label': (hp.batch_size, hp.num_label),
    }

    symbol = crnn_no_lstm(hp)
    internals = symbol.get_internals()
    _, out_shapes, _ = internals.infer_shape(**init_states)
    shape_dict = dict(zip(internals.list_outputs(), out_shapes))
    for item, shape in shape_dict.items():
        print(item, shape)
#click==6.7
numpy==1.14.0
pillow==5.3.0
mxnet==1.3.1
gluoncv==0.3.0
#opencv-python==3.4.4.19
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" An example of predicting CAPTCHA image data with an LSTM network pre-trained with a CTC loss"""
from __future__ import print_function
import argparse
from cnocr.fit.ctc_metrics import CtcMetrics
# from PIL import Image
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.hyperparams.hyperparams2 import Hyperparams as Hyperparams2
from cnocr.fit.lstm import init_states
import mxnet as mx
import numpy as np
from cnocr.data_utils.data_iter import SimpleBatch
from cnocr.symbols.crnn import crnn_lstm
def read_captcha_img(path, hp):
    """
    Reads image specified by path into numpy.ndarray

    Parameters
    ----------
    path: str
        Path of the image file.
    hp: object
        Hyper-parameters; ``img_height`` and ``img_width`` are read.

    Returns
    -------
    numpy.ndarray
        float32 array of shape [1, img_height, img_width] with values scaled by 1/255.

    Raises
    ------
    IOError
        If the file cannot be read as an image.
    """
    import cv2
    tgt_h, tgt_w = hp.img_height, hp.img_width
    raw = cv2.imread(path, 0)  # flag 0: load as a single-channel greyscale image
    if raw is None:
        # cv2.imread reports failure by returning None rather than raising,
        # which would otherwise surface as an obscure error inside cv2.resize.
        raise IOError('Cannot read image file: %s' % path)
    # cv2.resize takes dsize as (width, height), so this yields tgt_w rows and
    # tgt_h columns; the transpose below turns that into (tgt_h, tgt_w).
    # NOTE(review): presumably mirrors the resize/transpose used when the
    # training captchas are generated - confirm before changing.
    img = cv2.resize(raw, (tgt_h, tgt_w)).astype(np.float32) / 255
    img = np.expand_dims(img.transpose(1, 0), 0)  # res: [1, height, width]
    return img
def read_ocr_img(path, hp):
    """
    Read an image as greyscale and rescale it to the model input height,
    keeping the aspect ratio.

    Side effect: sets ``hp._seq_length`` to ``new_width // 8``.

    Parameters
    ----------
    path: str
        Path of the image file.
    hp: object
        Hyper-parameters; ``img_height`` is read and ``_seq_length`` is written.

    Returns
    -------
    numpy.ndarray
        Array of shape [1, img_height, new_width].
    """
    img = mx.image.imread(path, 0)  # flag 0: greyscale; keeps a trailing channel axis
    # Force true division: this module only imports print_function from
    # __future__, so an int / int ratio would be truncated under Python 2.
    scale = float(hp.img_height) / img.shape[0]
    new_width = int(scale * img.shape[1])
    # NOTE(review): presumably width // 8 because the CNN halves the width three
    # times (three 2x2 poolings) - confirm against the network definition.
    hp._seq_length = new_width // 8
    img = mx.image.imresize(img, new_width, hp.img_height).asnumpy()
    img = np.squeeze(img, axis=2)  # drop the channel axis
    return np.expand_dims(img, 0)
def lstm_init_states(batch_size, hp):
    """
    Returns a tuple of names and zero arrays for LSTM init states

    The (name, shape) entries come from ``init_states``, called with
    ``hp.num_lstm_layer`` and ``hp.num_hidden``; every state is zero-filled.
    """
    state_specs = init_states(batch_size=batch_size, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden)
    init_names = []
    init_arrays = []
    for spec in state_specs:
        init_names.append(spec[0])
        init_arrays.append(mx.nd.zeros(spec[1]))
    return init_names, init_arrays
def load_module(prefix, epoch, data_names, data_shapes, network=None):
    """
    Loads the model from checkpoint specified by prefix and epoch, binds it
    to an executor, and sets its parameters and returns a mx.mod.Module

    When ``network`` is given it replaces the symbol stored in the checkpoint;
    the checkpoint's parameters are used in either case.
    """
    checkpoint_sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    base_sym = checkpoint_sym if network is None else network

    # We don't need CTC loss for prediction, just a simple softmax will suffice.
    # Take the output of the layer just before the loss layer ('pred_fc') and put softmax on top.
    fc_output = base_sym.get_internals()['pred_fc_output']
    softmax_sym = mx.sym.softmax(data=fc_output)

    module = mx.mod.Module(symbol=softmax_sym, context=mx.cpu(), data_names=data_names, label_names=None)
    module.bind(for_training=False, data_shapes=data_shapes)
    module.set_params(arg_params, aux_params, allow_missing=False)
    return module
def read_charset(charset_fp):
    """
    Load the character set: one entry per line, where the line index is the class id.

    Parameters
    ----------
    charset_fp: str
        Path of a UTF-8 text file with one character (or token such as
        ``<space>``) per line.

    Returns
    -------
    (list of str, dict)
        ``alphabet`` maps id -> character; the dict maps character -> id.
        If the charset contains ``<space>``, a literal ``' '`` maps to the same id.
    """
    import io  # local import: io.open accepts `encoding` on both Python 2 and 3

    # Element 0 is a reserved id, used by CTC to separate characters; it does
    # not correspond to a meaningful character.
    # Decode explicitly as UTF-8 instead of relying on the platform default encoding.
    with io.open(charset_fp, encoding='utf-8') as fp:
        alphabet = [line.rstrip('\n') for line in fp]
    print('Alphabet size: %d' % len(alphabet))
    inv_alph_dict = {_char: idx for idx, _char in enumerate(alphabet)}
    if '<space>' in inv_alph_dict:
        # A literal space is stored as the '<space>' token in the file.
        inv_alph_dict[' '] = inv_alph_dict['<space>']
    return alphabet, inv_alph_dict
def main():
    """
    Command-line entry point: load a checkpoint, run the LSTM CRNN on a single
    image and print the decoded prediction.

    With ``--charset_file`` the predicted ids are mapped to characters;
    without it they are printed as digits (id minus one).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", help="use which kind of dataset, captcha or cn_ocr",
                        choices=['captcha', 'cn_ocr'], type=str, default='cn_ocr')
    # Required: without a path both image readers fail later with an obscure error.
    parser.add_argument("--file", help="Path to the image file to predict on", required=True)
    parser.add_argument("--prefix", help="Checkpoint prefix [Default './models/model']", default='./models/model')
    parser.add_argument("--epoch", help="Checkpoint epoch [Default 100]", type=int, default=100)
    parser.add_argument('--charset_file', type=str, help='存储了每个字对应哪个id的关系.')
    args = parser.parse_args()

    if args.dataset == 'cn_ocr':
        hp = Hyperparams()
        img = read_ocr_img(args.file, hp)
    else:
        hp = Hyperparams2()
        img = read_captcha_img(args.file, hp)

    # A batch of one image plus zero-filled LSTM init states.
    init_state_names, init_state_arrays = lstm_init_states(batch_size=1, hp=hp)
    sample = SimpleBatch(
        data_names=['data'] + init_state_names,
        data=[mx.nd.array([img])] + init_state_arrays)

    # Build the network after reading the image: read_ocr_img updates hp._seq_length.
    network = crnn_lstm(hp)
    mod = load_module(args.prefix, args.epoch, sample.data_names, sample.provide_data, network=network)

    mod.forward(sample)
    prob = mod.get_outputs()[0].asnumpy()

    # Best class per step, then CTC decoding via CtcMetrics.ctc_label.
    prediction = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist())

    if args.charset_file:
        alphabet, _ = read_charset(args.charset_file)
        res = [alphabet[p] for p in prediction]
        print("Predicted Chars:", res)
    else:
        # Predictions are 1 to 10 for digits 0 to 9 respectively (prediction 0 means no-digit)
        prediction = [p - 1 for p in prediction]
        print("Digits:", prediction)


if __name__ == '__main__':
    main()
#!/usr/bin/env bash
# -*- coding: utf-8 -*-

# Run from the directory containing this script, since train_ocr.py is
# referenced by a relative path. Quote the path so directories with spaces
# work, and stop if the cd fails instead of running from the wrong place.
cd "$(dirname "$0")" || exit 1

# Train the Chinese OCR CRNN model
python train_ocr.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr
# coding: utf-8
from __future__ import print_function
import argparse
import logging
import os
import mxnet as mx
from cnocr.data_utils.captcha_generator import MPDigitCaptcha
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
from cnocr.hyperparams.hyperparams2 import Hyperparams as Hyperparams2
from cnocr.data_utils.data_iter import ImageIterLstm, MPOcrImages, OCRIter
from cnocr.symbols.crnn import crnn_no_lstm, crnn_lstm
from cnocr.fit.ctc_metrics import CtcMetrics
from cnocr.fit.fit import fit
def parse_args():
    """
    Parse the command line arguments of the training script.

    Returns
    -------
    argparse.Namespace
        With attributes ``dataset``, ``data_root``, ``train_file``,
        ``test_file``, ``cpu``, ``gpu``, ``load_epoch``, ``prefix``, ``loss``,
        ``num_proc`` and ``font_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset",
                        help="use which kind of dataset, captcha or cn_ocr",
                        choices=['captcha', 'cn_ocr'],
                        type=str, default='captcha')
    # NOTE(review): the three path defaults below point into one developer's
    # home directory; they are kept unchanged so existing invocations behave the same.
    parser.add_argument("--data_root", help="Path to image files", type=str,
                        default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator')
    parser.add_argument("--train_file", help="Path to train txt file", type=str,
                        default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/train.txt')
    parser.add_argument("--test_file", help="Path to test txt file", type=str,
                        default='/Users/king/Documents/WhatIHaveDone/Test/text_renderer/output/wechat_simulator/test.txt')
    # The help texts below state the defaults that are actually configured.
    parser.add_argument("--cpu",
                        help="Number of CPUs for training [Default 2]. Ignored if --gpu is specified.",
                        type=int, default=2)
    parser.add_argument("--gpu", help="Number of GPUs for training [Default: not set]", type=int)
    parser.add_argument('--load_epoch', type=int,
                        help='load the model on an epoch using the model-load-prefix')
    parser.add_argument("--prefix", help="Checkpoint prefix [Default './models/model']", default='./models/model')
    parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc')
    parser.add_argument("--num_proc", help="Number CAPTCHA generating processes [Default 4]", type=int, default=4)
    parser.add_argument("--font_path", help="Path to ttf font file or directory containing ttf files")
    return parser.parse_args()
def get_fonts(path):
    """
    Collect font file paths.

    If ``path`` is a directory, return the joined paths of its entries whose
    names end in ``.ttf`` or ``.ttc``; otherwise return ``[path]`` unchanged.
    """
    if not os.path.isdir(path):
        return [path]
    font_suffixes = ('.ttf', '.ttc')
    return [
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if entry.endswith(font_suffixes)
    ]
def run_captcha(args):
    """
    Train the LSTM CRNN on digit captchas produced on the fly by a
    multi-process generator, then reset the generator.

    Parameters
    ----------
    args: argparse.Namespace
        Parsed command line; ``font_path`` and ``num_proc`` are read here and
        the whole namespace is forwarded to ``fit``.
    """
    hp = Hyperparams2()
    network = crnn_lstm(hp)

    # Start a multiprocessor captcha image generator
    # NOTE(review): h receives img_width and w receives img_height, exactly as
    # in the original call - confirm against MPDigitCaptcha before changing.
    mp_captcha = MPDigitCaptcha(
        font_paths=get_fonts(args.font_path), h=hp.img_width, w=hp.img_height,
        num_digit_min=3, num_digit_max=4, num_processes=args.num_proc, max_queue_size=hp.batch_size * 2)
    mp_captcha.start()

    # One cell state and one hidden state per LSTM layer index, all c-states first.
    num_states = hp.num_lstm_layer * 2
    state_shape = (hp.batch_size, hp.num_hidden)
    init_states = [('l%d_init_c' % idx, state_shape) for idx in range(num_states)]
    init_states += [('l%d_init_h' % idx, state_shape) for idx in range(num_states)]
    data_names = ['data'] + [state[0] for state in init_states]

    # Training and validation batches are drawn from the same generator.
    data_train = OCRIter(
        hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha,
        num_label=hp.num_label, name='train')
    data_val = OCRIter(
        hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha,
        num_label=hp.num_label, name='val')

    log_format = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=log_format)

    metrics = CtcMetrics(hp.seq_length)

    fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp,
        data_names=data_names)

    mp_captcha.reset()
def run_cn_ocr(args):
    """
    Train the LSTM CRNN on image files listed in the train/test index files,
    loaded by multi-process readers, then reset both readers.

    Parameters
    ----------
    args: argparse.Namespace
        Parsed command line; ``data_root``, ``train_file``, ``test_file`` and
        ``num_proc`` are read here and the whole namespace is forwarded to ``fit``.
    """
    hp = Hyperparams()
    network = crnn_lstm(hp)

    mp_data_train = MPOcrImages(args.data_root, args.train_file, (hp.img_width, hp.img_height), hp.num_label,
                                num_processes=args.num_proc, max_queue_size=hp.batch_size * 2)
    # The test reader gets half as many worker processes, but at least one.
    mp_data_test = MPOcrImages(args.data_root, args.test_file, (hp.img_width, hp.img_height), hp.num_label,
                               num_processes=max(args.num_proc // 2, 1), max_queue_size=hp.batch_size * 2)
    mp_data_train.start()
    mp_data_test.start()

    # One cell state and one hidden state per LSTM layer index, all c-states first.
    init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    init_states = init_c + init_h
    data_names = ['data'] + [x[0] for x in init_states]

    data_train = OCRIter(
        hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_data_train,
        num_label=hp.num_label, name='train')
    data_val = OCRIter(
        hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_data_test,
        num_label=hp.num_label, name='val')

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    metrics = CtcMetrics(hp.seq_length)

    fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp,
        data_names=data_names)

    # Reset both readers once training is done. The test reader was previously
    # start()-ed a second time here instead of being reset like the train reader.
    mp_data_train.reset()
    mp_data_test.reset()
if __name__ == '__main__':
    # Dispatch to the training routine selected by --dataset.
    args = parse_args()
    run_training = run_captcha if args.dataset == 'captcha' else run_cn_ocr
    run_training(args)
#!/usr/bin/env python3
import os
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from subprocess import check_call
# Directory containing this setup script.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Runtime dependencies, pinned to the minor versions the package was developed against.
required = [
    'numpy>=1.14.0,<1.15.0',
    'pillow>=5.3.0',
    'mxnet>=1.3.1,<1.4.0',
    'gluoncv>=0.3.0,<0.4.0',
]

setup(
    name='cnocr',
    version='0.1',
    description="Package for Chinese OCR, which can be used after installed without training yourself OCR model",
    author='breezedeus',
    author_email='breezedeus@163.com',
    license='Apache 2.0',
    url='https://github.com/breezedeus/cnocr',
    platforms=["all"],
    packages=find_packages(),
    # entry_points={'console_scripts': ['chitchatbot=chitchatbot.cli:main'],
    #               'plus.ein.botlet': ['chitchatbot=chitchatbot:ChitchatBot'],
    #               'plus.ein.botlet.parser': ['chitchatbot=chitchatbot:Spec']},
    include_package_data=True,
    install_requires=required,
    zip_safe=False,
    classifiers=[
        'Development Status :: 4 - Beta',
        'Operating System :: OS Independent',
        'Intended Audience :: Developers',
        # Registered trove classifier for the Apache licence; the previous
        # 'Apache 2.0 License' spelling is not a registered classifier.
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python',
        'Programming Language :: Python :: Implementation',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Topic :: Software Development :: Libraries',
    ],
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册