Commit f2006ba9 · Author: W Wei Tang

Code for CRNN

Training code for Chinese text recognition
Parent: f7d4636f
# data_utils/data_iter.py (path inferred from "from data_utils.data_iter import ..." in train.py)
from __future__ import print_function
import os
import random

import numpy as np
import mxnet as mx
from PIL import Image


def write_txt_file():
    """Scans root_path/images/<dir>/<file> and writes a shuffled 80/20 train/test split."""
    root_path = "D:/Data/VOCtrainval_11-May-2012/test/"
    dirs = os.listdir(os.path.join(root_path, "images"))
    content = []
    for d in dirs:
        files = os.listdir(os.path.join(root_path, "images", d))
        for f in files:
            content.append(d + "/" + f + " " + d + "\n")
    random.shuffle(content)
    train_f = open(os.path.join(root_path, "train.txt"), "w")
    test_f = open(os.path.join(root_path, "test.txt"), "w")
    for i, c in enumerate(content):
        if i < 0.8 * len(content):
            train_f.write(c)
        else:
            test_f.write(c)
    train_f.close()
    test_f.close()


def write_mx_lst(data_type="train"):
    """Converts a "<image> <label> ..." txt file into an MXNet .lst file (index, tab-separated labels, path)."""
    txt_file = "D:/BaiduNetdiskDownload/Synthetic_Chinese_String_Dataset/"
    in_f = open(os.path.join(txt_file, data_type + ".txt"), "r")
    out_f = open(os.path.join(txt_file, data_type + ".lst"), "w")
    lines = in_f.readlines()
    random.shuffle(lines)
    for idx, line in enumerate(lines):
        new_line = str(idx) + "\t"
        lst = line.strip().split(" ")
        for i in range(len(lst) - 1):
            new_line = new_line + lst[i + 1] + "\t"
        new_line = new_line + "images/" + lst[0] + "\n"
        out_f.write(new_line)
    in_f.close()
    out_f.close()
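
# Worked example of the conversion above (hypothetical input line; the format
# matches the parsing in write_mx_lst):
#   input  txt line : "20436343_1683447152.jpg 120 265 1033 1038 22 655 335 88 321 7"
#   output .lst line: "0\t120\t265\t1033\t1038\t22\t655\t335\t88\t321\t7\timages/20436343_1683447152.jpg"
# i.e. a running index, the tab-separated integer labels, then the image path,
# which is the layout MXNet's im2rec tooling expects.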

class SimpleBatch(object):
    def __init__(self, data_names, data, label_names=list(), label=list()):
        self._data = data
        self._label = label
        self._data_names = data_names
        self._label_names = label_names

        self.pad = 0
        self.index = None  # unused here; kept so the object matches the mx.io.DataBatch interface

    @property
    def data(self):
        return self._data

    @property
    def label(self):
        return self._label

    @property
    def data_names(self):
        return self._data_names

    @property
    def label_names(self):
        return self._label_names

    @property
    def provide_data(self):
        return [(n, x.shape) for n, x in zip(self._data_names, self._data)]

    @property
    def provide_label(self):
        return [(n, x.shape) for n, x in zip(self._label_names, self._label)]

class ImageIter(mx.io.DataIter):
    """
    Iterator class for generating batches of text-line image data
    """
    def __init__(self, data_root, data_list, batch_size, data_shape, num_label, name=None):
        """
        Parameters
        ----------
        data_root: str
            root directory of images
        data_list: str
            a .txt file listing an image name and its integer labels on each line
        batch_size: int
        data_shape: tuple of int
            (width, height) to which every image is resized
        num_label: int
            maximum number of labels per image (shorter labels are zero-padded)
        name: str
        """
        super(ImageIter, self).__init__()
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.num_label = num_label

        self.data_root = data_root
        self.dataset_lst_file = open(data_list)

        self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))]
        self.provide_label = [('label', (self.batch_size, self.num_label))]
        self.name = name

    def __iter__(self):
        data = []
        label = []
        cnt = 0
        for m_line in self.dataset_lst_file:
            img_lst = m_line.strip().split(' ')
            img_path = os.path.join(self.data_root, img_lst[0])

            cnt += 1
            img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
            img = np.array(img).reshape((1, self.data_shape[1], self.data_shape[0]))
            data.append(img)

            ret = np.zeros(self.num_label, int)
            for idx in range(1, len(img_lst)):
                ret[idx - 1] = int(img_lst[idx])
            label.append(ret)

            if cnt % self.batch_size == 0:
                data_all = [mx.nd.array(data)]
                label_all = [mx.nd.array(label)]
                data_names = ['data']
                label_names = ['label']
                data.clear()
                label.clear()
                yield SimpleBatch(data_names, data_all, label_names, label_all)
        # note: a trailing partial batch (fewer than batch_size samples) is dropped

    def reset(self):
        if self.dataset_lst_file.seekable():
            self.dataset_lst_file.seek(0)
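
# Minimal usage sketch (hypothetical paths; shapes follow provide_data above):
#   it = ImageIter("/data/images", "/data/train.txt", batch_size=128,
#                  data_shape=(280, 32), num_label=10, name="train")
#   for batch in it:
#       batch.data[0].shape   # (128, 1, 32, 280)
#       batch.label[0].shape  # (128, 10)
#   it.reset()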

class ImageIterLstm(mx.io.DataIter):
    """
    Iterator class for generating batches of text-line image data,
    plus the zero-valued initial LSTM states the symbol expects
    """
    def __init__(self, data_root, data_list, batch_size, data_shape, num_label, lstm_init_states, name=None):
        """
        Parameters
        ----------
        data_root: str
            root directory of images
        data_list: str
            a .txt file listing an image name and its integer labels on each line
        batch_size: int
        lstm_init_states: list of (str, shape) tuples
            names and shapes of the LSTM initial-state inputs
        name: str
        """
        super(ImageIterLstm, self).__init__()
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.num_label = num_label

        self.init_states = lstm_init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states]

        self.data_root = data_root
        self.dataset_lines = open(data_list).readlines()

        self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + lstm_init_states
        self.provide_label = [('label', (self.batch_size, self.num_label))]
        self.name = name

    def __iter__(self):
        init_state_names = [x[0] for x in self.init_states]
        data = []
        label = []
        cnt = 0
        for m_line in self.dataset_lines:
            img_lst = m_line.strip().split(' ')
            img_path = os.path.join(self.data_root, img_lst[0])

            cnt += 1
            img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
            img = np.array(img).reshape((1, self.data_shape[1], self.data_shape[0]))
            data.append(img)

            ret = np.zeros(self.num_label, int)
            for idx in range(1, len(img_lst)):
                ret[idx - 1] = int(img_lst[idx])
            label.append(ret)

            if cnt % self.batch_size == 0:
                data_all = [mx.nd.array(data)] + self.init_state_arrays
                label_all = [mx.nd.array(label)]
                data_names = ['data'] + init_state_names
                label_names = ['label']
                data = []
                label = []
                yield SimpleBatch(data_names, data_all, label_names, label_all)
        # note: a trailing partial batch is dropped

    def reset(self):
        # lines are held in memory, so reshuffle between epochs instead of seeking
        random.shuffle(self.dataset_lines)

# Legacy captcha-OCR code, kept commented out:
#
# def get_label(buf):
#     ret = np.zeros(10)
#     for i in range(len(buf)):
#         ret[i] = 1 + int(buf[i])
#     if len(buf) == 9:
#         ret[3] = 0
#     return ret
#
#
# class OCRIter(mx.io.DataIter):
#     """
#     Iterator class for generating captcha image data
#     """
#     def __init__(self, count, batch_size, captcha, name):
#         """
#         Parameters
#         ----------
#         count: int
#             Number of batches to produce for one epoch
#         batch_size: int
#         captcha: MPCaptcha
#             Captcha image generator. Can be MPCaptcha or any other class providing .shape and .get() interface
#         name: str
#         """
#         super(OCRIter, self).__init__()
#         self.batch_size = batch_size
#         self.count = count
#         self.data_shape = captcha.shape
#         print(self.data_shape)
#         self.provide_data = [('data', (batch_size, 1, self.data_shape[0], self.data_shape[1]))]
#         self.provide_label = [('label', (self.batch_size, 10))]
#         self.mp_captcha = captcha
#         self.name = name
#
#     def __iter__(self):
#         for k in range(self.count):
#             data = []
#             label = []
#             for i in range(self.batch_size):
#                 img, num = self.mp_captcha.get()
#                 img = np.array(img).reshape((1, self.data_shape[0], self.data_shape[1]))
#                 data.append(img)
#                 label.append(get_label(num))
#             data_all = [mx.nd.array(data)]
#             label_all = [mx.nd.array(label)]
#             data_names = ['data']
#             label_names = ['label']
#             data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
#             yield data_batch

if __name__ == "__main__":
    write_mx_lst("test")

# fit/ctc_loss.py (path inferred from "from fit.ctc_loss import add_ctc_loss" in symbols/crnn.py)
import mxnet as mx


def _add_warp_ctc_loss(pred, seq_len, num_label, label):
    """ Adds mx.sym.WarpCTC on top of pred symbol and returns the resulting symbol """
    label = mx.sym.Reshape(data=label, shape=(-1,))
    label = mx.sym.Cast(data=label, dtype='int32')
    return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)


def _add_mxnet_ctc_loss(pred, seq_len, label):
    """ Adds mx.sym.contrib.ctc_loss on top of pred symbol and returns the resulting symbol """
    pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))

    loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
    ctc_loss = mx.sym.MakeLoss(loss)

    softmax_class = mx.symbol.SoftmaxActivation(data=pred)
    softmax_loss = mx.sym.MakeLoss(softmax_class)
    softmax_loss = mx.sym.BlockGrad(softmax_loss)
    return mx.sym.Group([softmax_loss, ctc_loss])


def add_ctc_loss(pred, seq_len, num_label, loss_type):
    """ Adds CTC loss on top of pred symbol and returns the resulting symbol """
    label = mx.sym.Variable('label')
    if loss_type == 'warpctc':
        print("Using WarpCTC Loss")
        sm = _add_warp_ctc_loss(pred, seq_len, num_label, label)
    else:
        print("Using MXNet CTC Loss")
        assert loss_type == 'ctc'
        sm = _add_mxnet_ctc_loss(pred, seq_len, label)
    return sm
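
# Usage sketch (shapes only; 'pred' comes from the CRNN symbols in symbols/crnn.py):
#   pred is expected to be time-major with shape (seq_len * batch_size, num_classes);
#   'label' is a (batch_size, num_label) variable, zero-padded per sample.
#   sm = add_ctc_loss(pred, seq_len=35, num_label=10, loss_type='warpctc')
# Note that loss_type='warpctc' requires MXNet built with the Baidu warp-ctc
# plugin; loss_type='ctc' uses the built-in mx.sym.contrib.ctc_loss instead.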
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
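# fit/ctc_metrics.py (path inferred from "from fit.ctc_metrics import CtcMetrics" in train.py)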
"""Contains a class for calculating CTC eval metrics"""
from __future__ import print_function
import numpy as np

class CtcMetrics(object):
    def __init__(self, seq_len):
        self.seq_len = seq_len

    @staticmethod
    def ctc_label(p):
        """
        Iterates through p, collecting values that are neither blank (0) nor a
        repeat of the previous value, i.e. standard CTC best-path collapsing
        Parameters
        ----------
        p: list of int
        Returns
        -------
        list of int
        """
        ret = []
        p1 = [0] + p
        for i, _ in enumerate(p):
            c1 = p1[i]
            c2 = p1[i + 1]
            if c2 == 0 or c2 == c1:
                continue
            ret.append(c2)
        return ret
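
    # Worked example: ctc_label([0, 1, 1, 0, 2, 2, 2, 3]) -> [1, 2, 3]
    # (blanks separate genuine repeats, so [1, 0, 1] decodes to [1, 1])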

    @staticmethod
    def _remove_blank(l):
        """ Truncates the label at the first blank (0), i.e. strips the zero padding, and returns a new list """
        ret = []
        for i, _ in enumerate(l):
            if l[i] == 0:
                break
            ret.append(l[i])
        return ret

    @staticmethod
    def _lcs(p, l):
        """ Calculates the Longest Common Subsequence between p and l (both list of int) and returns its length """
        # dynamic programming over the 0/1 match matrix
        if len(p) == 0:
            return 0
        P = np.array(list(p)).reshape((1, len(p)))
        L = np.array(list(l)).reshape((len(l), 1))
        M = np.int32(P == L)
        for i in range(M.shape[0]):
            for j in range(M.shape[1]):
                up = 0 if i == 0 else M[i - 1, j]
                left = 0 if j == 0 else M[i, j - 1]
                M[i, j] = max(up, left, M[i, j] if (i == 0 or j == 0) else M[i, j] + M[i - 1, j - 1])
        return M.max()
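
    # Worked example: _lcs([1, 2, 3, 4], [1, 3, 4]) == 3 (the common subsequence 1, 3, 4)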

    def accuracy(self, label, pred):
        """ Simple accuracy measure: number of 100% accurate predictions divided by total number """
        hit = 0.
        total = 0.
        batch_size = label.shape[0]
        for i in range(batch_size):
            l = self._remove_blank(label[i])
            p = []
            for k in range(self.seq_len):
                # pred is time-major: row k * batch_size + i is time step k of sample i
                p.append(np.argmax(pred[k * batch_size + i]))
            p = self.ctc_label(p)
            if len(p) == len(l):
                match = True
                for k, _ in enumerate(p):
                    if p[k] != int(l[k]):
                        match = False
                        break
                if match:
                    hit += 1.0
            total += 1.0
        assert total == batch_size
        return hit / total

    def accuracy_lcs(self, label, pred):
        """ Longest Common Subsequence accuracy measure: calculate accuracy of each prediction as LCS/length """
        hit = 0.
        total = 0.
        batch_size = label.shape[0]
        for i in range(batch_size):
            l = self._remove_blank(label[i])
            p = []
            for k in range(self.seq_len):
                p.append(np.argmax(pred[k * batch_size + i]))
            p = self.ctc_label(p)
            hit += self._lcs(p, l) * 1.0 / len(l)
            total += 1.0
        assert total == batch_size
        return hit / total

# fit/fit.py (path inferred from "from fit.fit import fit" in train.py)
import logging
import os

import mxnet as mx


def _load_model(args, rank=0):
    if 'load_epoch' not in args or args.load_epoch is None:
        return (None, None, None)
    assert args.prefix is not None
    model_prefix = args.prefix
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
        model_prefix += "-%d" % (rank,)
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
    logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
    return (sym, arg_params, aux_params)


def fit(network, data_train, data_val, metrics, args, hp, data_names=None):
    if args.gpu:
        contexts = [mx.context.gpu(i) for i in range(args.gpu)]
    else:
        contexts = [mx.context.cpu(i) for i in range(args.cpu)]

    sym, arg_params, aux_params = _load_model(args)
    if sym is not None:
        assert sym.tojson() == network.tojson()

    module = mx.mod.Module(
        symbol=network,
        data_names=["data"] if data_names is None else data_names,
        label_names=['label'],
        context=contexts)

    module.fit(train_data=data_train,
               eval_data=data_val,
               begin_epoch=args.load_epoch if args.load_epoch else 0,
               num_epoch=hp.num_epoch,
               # use metrics.accuracy or metrics.accuracy_lcs
               eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
               optimizer='AdaDelta',
               optimizer_params={'learning_rate': hp.learning_rate,
                                 # 'momentum': hp.momentum,
                                 'wd': 0.00001,
                                 },
               initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
               arg_params=arg_params,
               aux_params=aux_params,
               batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
               epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
               )

# fit/lstm.py (path inferred from "from fit.lstm import lstm" in symbols/crnn.py)
from __future__ import print_function
from collections import namedtuple

import mxnet as mx

LSTMState = namedtuple("LSTMState", ["c", "h"])
LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
                                     "h2h_weight", "h2h_bias"])
LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
                                     "init_states", "last_states", "forward_state", "backward_state",
                                     "seq_data", "seq_labels", "seq_outputs",
                                     "param_blocks"])


def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
    """LSTM cell symbol"""
    i2h = mx.sym.FullyConnected(data=indata,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden * 4,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    gates = i2h + h2h
    slice_gates = mx.sym.split(gates, num_outputs=4,
                               name="t%d_l%d_slice" % (seqidx, layeridx))
    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
    return LSTMState(c=next_c, h=next_h)
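
# Note on conventions (descriptive comment, not from the original source): the
# 4*num_hidden projection is sliced in the order [input gate, candidate,
# forget gate, output gate], so any pretrained i2h/h2h weights must follow the
# same ordering. The update is the standard LSTM recurrence
#   c_t = f * c_{t-1} + i * tanh(W_g x_t + U_g h_{t-1})    (elementwise *)
#   h_t = o * tanh(c_t)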

def lstm(net, num_lstm_layer, num_hidden, seq_length):
    last_states = []
    forward_param = []
    backward_param = []
    for i in range(num_lstm_layer * 2):
        last_states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % i), h=mx.sym.Variable("l%d_init_h" % i)))
        if i % 2 == 0:
            forward_param.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                           i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                           h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                           h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        else:
            backward_param.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                            i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                            h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                            h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))

    # net: bz x features x 1 x seq_length; split along the time axis
    slices_net = mx.sym.split(data=net, axis=3, num_outputs=seq_length, squeeze_axis=1)

    forward_hidden = []
    for seqidx in range(seq_length):
        hidden = mx.sym.flatten(data=slices_net[seqidx])
        for i in range(num_lstm_layer):
            # layeridx=2*i (and 2*i+1 below) keeps op names unique per layer and
            # direction; the constant 0/1 used originally collides for num_lstm_layer > 1
            next_state = _lstm(num_hidden, indata=hidden, prev_state=last_states[2 * i],
                               param=forward_param[i], seqidx=seqidx, layeridx=2 * i)
            hidden = next_state.h
            last_states[2 * i] = next_state
        forward_hidden.append(hidden)

    backward_hidden = []
    for seqidx in range(seq_length):
        k = seq_length - seqidx - 1
        hidden = mx.sym.flatten(data=slices_net[k])
        for i in range(num_lstm_layer):
            next_state = _lstm(num_hidden, indata=hidden, prev_state=last_states[2 * i + 1],
                               param=backward_param[i], seqidx=k, layeridx=2 * i + 1)
            hidden = next_state.h
            last_states[2 * i + 1] = next_state
        backward_hidden.insert(0, hidden)

    hidden_all = []
    for i in range(seq_length):
        hidden_all.append(mx.sym.concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    # concatenate along dim 0 -> time-major (seq_length * bz, 2 * num_hidden)
    hidden_concat = mx.sym.concat(*hidden_all, dim=0)
    return hidden_concat
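
# Shape walk-through (a sketch, using the defaults in hyperparams.py): net enters
# as (batch, features, 1, 35); each of the 35 time slices is flattened to
# (batch, features) and run through the stacked forward/backward LSTMs; the
# per-step forward and backward outputs are concatenated to (batch, 2*num_hidden)
# and finally stacked along dim 0 into (35 * batch, 2*num_hidden), the
# time-major layout the CTC losses in fit/ctc_loss.py expect.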

# hyperparams/hyperparams.py (path inferred from "from hyperparams.hyperparams import Hyperparams")
from __future__ import print_function


class Hyperparams(object):
    """
    Hyperparameters for the CRNN network
    """
    def __init__(self):
        # training hyperparameters
        self._train_epoch_size = 30000
        self._eval_epoch_size = 3000
        self._num_epoch = 20
        self._learning_rate = 0.001
        self._momentum = 0.9
        self._bn_mom = 0.9
        self._workspace = 512
        self._loss_type = "warpctc"  # one of ["warpctc", "ctc"]
        self._batch_size = 128
        self._num_classes = 5990
        self._img_width = 280
        self._img_height = 32
        # DenseNet hyperparameters (not referenced by the CRNN symbols in this commit)
        self._depth = 161
        self._growrate = 32
        self._reduction = 0.5
        # LSTM hyperparameters
        self._num_hidden = 100
        self._num_lstm_layer = 2
        self._seq_length = 35
        self._num_label = 10
        self._drop_out = 0.5

    @property
    def train_epoch_size(self):
        return self._train_epoch_size

    @property
    def eval_epoch_size(self):
        return self._eval_epoch_size

    @property
    def num_epoch(self):
        return self._num_epoch

    @property
    def learning_rate(self):
        return self._learning_rate

    @property
    def momentum(self):
        return self._momentum

    @property
    def bn_mom(self):
        return self._bn_mom

    @property
    def workspace(self):
        return self._workspace

    @property
    def loss_type(self):
        return self._loss_type

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def img_width(self):
        return self._img_width

    @property
    def img_height(self):
        return self._img_height

    @property
    def depth(self):
        return self._depth

    @property
    def growrate(self):
        return self._growrate

    @property
    def reduction(self):
        return self._reduction

    @property
    def num_hidden(self):
        return self._num_hidden

    @property
    def num_lstm_layer(self):
        return self._num_lstm_layer

    @property
    def seq_length(self):
        return self._seq_length

    @property
    def num_label(self):
        return self._num_label

    @property
    def dropout(self):
        return self._drop_out
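
# Consistency sketch (assumed relationship, matching the pooling in symbols/crnn.py):
# three 2x2 poolings shrink the 280-pixel width by 8x, so seq_length should equal
# img_width // 8 = 35, with each LSTM step seeing one of those 35 columns.
#   hp = Hyperparams()
#   assert hp.seq_length == hp.img_width // 8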
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick Haffner.
Gradient-based learning applied to document recognition.
Proceedings of the IEEE (1998)
"""
import mxnet as mx
from fit.ctc_loss import add_ctc_loss
from fit.lstm import lstm
def crnn_no_lstm(hp):
# input
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
kernel_size = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
padding_size = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
layer_size = [min(32*2**(i+1), 512) for i in range(len(kernel_size))]
def convRelu(i, input_data, bn=True):
layer = mx.symbol.Convolution(name='conv-%d' % i, data=input_data, kernel=kernel_size[i], pad=padding_size[i],
num_filter=layer_size[i])
if bn:
layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
layer = mx.sym.LeakyReLU(data=layer,name='leakyrelu-%d' % i)
return layer
net = convRelu(0, data) # bz x f x 32 x 200
max = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2))
avg = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
net = max - avg # 16 x 100
net = convRelu(1, net)
net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2)) # bz x f x 8 x 50
net = convRelu(2, net, True)
net = convRelu(3, net)
net = mx.sym.Pooling(data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2)) # bz x f x 4 x 25
net = convRelu(4, net, True)
net = convRelu(5, net)
net = mx.symbol.Pooling(data=net, kernel=(4, 1), pool_type='avg', name='pool1') # bz x f x 1 x 25
if hp.dropout > 0:
net = mx.symbol.Dropout(data=net, p=hp.dropout)
net = mx.sym.transpose(data=net, axes=[1,0,2,3]) # f x bz x 1 x 25
net = mx.sym.flatten(data=net) # f x (bz x 25)
hidden_concat = mx.sym.transpose(data=net, axes=[1,0]) # (bz x 25) x f
# mx.sym.transpose(net, [])
pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes) # (bz x 25) x num_classes
if hp.loss_type:
# Training mode, add loss
return add_ctc_loss(pred, hp.seq_length, hp.num_label, hp.loss_type)
else:
# Inference mode, add softmax
return mx.sym.softmax(data=pred, name='softmax')

def crnn_lstm(hp):
    # input
    data = mx.sym.Variable('data')

    kernel_size = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
    padding_size = [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
    layer_size = [min(32 * 2 ** (i + 1), 512) for i in range(len(kernel_size))]  # 64, 128, 256, 512, 512, 512

    def convRelu(i, input_data, bn=True):
        layer = mx.symbol.Convolution(name='conv-%d' % i, data=input_data, kernel=kernel_size[i],
                                      pad=padding_size[i], num_filter=layer_size[i])
        if bn:
            layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
        layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d' % i)

        # in this variant each 3x3 conv is followed by a 1x1 conv with the same filter count
        layer = mx.symbol.Convolution(name='conv-%d-1x1' % i, data=layer, kernel=(1, 1), pad=(0, 0),
                                      num_filter=layer_size[i])
        if bn:
            layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d-1x1' % i)
        layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-1x1' % i)
        return layer

    net = convRelu(0, data)  # bz x 64 x 32 x 280
    max_pool = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2))
    avg_pool = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
    net = max_pool - avg_pool  # bz x 64 x 16 x 140
    net = convRelu(1, net)
    net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2))  # bz x 128 x 8 x 70
    net = convRelu(2, net, True)
    net = convRelu(3, net)
    net = mx.sym.Pooling(data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2))  # bz x 512 x 4 x 35
    net = convRelu(4, net, True)
    net = convRelu(5, net)
    net = mx.symbol.Pooling(data=net, kernel=(4, 1), pool_type='avg', name='pool1')  # bz x 512 x 1 x 35
    if hp.dropout > 0:
        net = mx.symbol.Dropout(data=net, p=hp.dropout)

    hidden_concat = lstm(net, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden,
                         seq_length=hp.seq_length)  # (seq_len x bz) x (2 x num_hidden), time-major

    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes)  # (seq_len x bz) x num_classes
    if hp.loss_type:
        # training mode: add the CTC loss
        return add_ctc_loss(pred, hp.seq_length, hp.num_label, hp.loss_type)
    else:
        # inference mode: add softmax
        return mx.sym.softmax(data=pred, name='softmax')

from hyperparams.hyperparams import Hyperparams

if __name__ == '__main__':
    hp = Hyperparams()

    init_states = {}
    init_states['data'] = (hp.batch_size, 1, hp.img_height, hp.img_width)
    init_states['label'] = (hp.batch_size, hp.num_label)
    # for crnn_lstm, the LSTM initial states would be needed as well:
    # init_c = {('l%d_init_c' % l): (hp.batch_size, hp.num_hidden) for l in range(hp.num_lstm_layer * 2)}
    # init_h = {('l%d_init_h' % l): (hp.batch_size, hp.num_hidden) for l in range(hp.num_lstm_layer * 2)}
    # for item in init_c:
    #     init_states[item] = init_c[item]
    # for item in init_h:
    #     init_states[item] = init_h[item]

    # print the shape of every internal node as a sanity check
    symbol = crnn_no_lstm(hp)
    internals = symbol.get_internals()
    _, out_shapes, _ = internals.infer_shape(**init_states)
    shape_dict = dict(zip(internals.list_outputs(), out_shapes))
    for item in shape_dict:
        print(item, shape_dict[item])

# train.py (training entry point; filename assumed)
from __future__ import print_function
import argparse
import logging

import mxnet as mx

from hyperparams.hyperparams import Hyperparams
from data_utils.data_iter import ImageIter, ImageIterLstm
from symbols.crnn import crnn_no_lstm, crnn_lstm
from fit.ctc_metrics import CtcMetrics
from fit.fit import fit


def parse_args():
    # parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", help="Path to image files", type=str,
                        default='/home/richard/data/Synthetic_Chinese_String_Dataset/images')
    parser.add_argument("--train_file", help="Path to train txt file", type=str,
                        default='/home/richard/data/Synthetic_Chinese_String_Dataset/train.txt')
    parser.add_argument("--test_file", help="Path to test txt file", type=str,
                        default='/home/richard/data/Synthetic_Chinese_String_Dataset/test.txt')
    parser.add_argument("--cpu",
                        help="Number of CPUs for training [Default 4]. Ignored if --gpu is specified.",
                        type=int, default=4)
    parser.add_argument("--gpu", help="Number of GPUs for training [Default 1]", type=int, default=1)
    parser.add_argument('--load_epoch', type=int,
                        help='load the model on an epoch using the model-load-prefix')
    parser.add_argument("--prefix", help="Checkpoint prefix [Default './check_points/model']",
                        default='./check_points/model')
    return parser.parse_args()
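
# Example invocation (hypothetical paths; the defaults above point at the author's setup):
#   python train.py --data_root /data/Synthetic_Chinese_String_Dataset/images \
#       --train_file /data/Synthetic_Chinese_String_Dataset/train.txt \
#       --test_file /data/Synthetic_Chinese_String_Dataset/test.txt --gpu 1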

def main():
    args = parse_args()
    hp = Hyperparams()

    init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    init_states = init_c + init_h
    data_names = ['data'] + [x[0] for x in init_states]

    data_train = ImageIterLstm(
        args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height),
        hp.num_label, init_states, name="train")
    data_val = ImageIterLstm(
        args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height),
        hp.num_label, init_states, name="val")

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    network = crnn_lstm(hp)
    metrics = CtcMetrics(hp.seq_length)
    fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics,
        args=args, hp=hp, data_names=data_names)

def main2():
    """Alternative entry point that builds and fits the Module directly, without fit/fit.py."""
    args = parse_args()
    hp = Hyperparams()

    if args.gpu:
        contexts = [mx.context.gpu(i) for i in range(args.gpu)]
    else:
        contexts = [mx.context.cpu(i) for i in range(args.cpu)]

    init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
    init_states = init_c + init_h
    data_names = ['data'] + [x[0] for x in init_states]

    data_train = ImageIterLstm(
        args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height),
        hp.num_label, init_states, name="train")
    data_val = ImageIterLstm(
        args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height),
        hp.num_label, init_states, name="val")

    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    symbol = crnn_lstm(hp)
    module = mx.mod.Module(
        symbol,
        data_names=data_names,
        label_names=['label'],
        context=contexts)
    module.bind(data_shapes=data_train.provide_data, label_shapes=data_train.provide_label)

    metrics = CtcMetrics(hp.seq_length)
    module.fit(train_data=data_train,
               eval_data=data_val,
               # use metrics.accuracy or metrics.accuracy_lcs
               eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
               optimizer='AdaDelta',
               optimizer_params={'learning_rate': hp.learning_rate,
                                 # 'momentum': hp.momentum,
                                 'wd': 0.00001,
                                 },
               initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
               num_epoch=hp.num_epoch,
               batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
               epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
               )

if __name__ == '__main__':
    main()