diff --git a/data_utils/__init__.py b/data_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data_utils/data_iter.py b/data_utils/data_iter.py
new file mode 100644
index 0000000000000000000000000000000000000000..72dd876f87564336175af811633ee669a4358680
--- /dev/null
+++ b/data_utils/data_iter.py
@@ -0,0 +1,231 @@
+"""Dataset list helpers and MXNet data iterators for the OCR training scripts."""
+from __future__ import print_function
+
+import os
+import random
+
+import mxnet as mx
+import numpy as np
+from PIL import Image
+
+
+def write_txt_file(root_path="D:/Data/VOCtrainval_11-May-2012/test/"):
+    """Split the images under root_path/images into train/test lists.
+
+    Every sub-directory of ``images`` is treated as one class: each file in it
+    yields one ``"DIR/FILE DIR"`` line.  The lines are shuffled; those whose
+    index is below 80% of the total go to ``train.txt`` and the rest to
+    ``test.txt``, both written inside ``root_path``.
+    """
+    image_root = os.path.join(root_path, "images")
+    content = []
+    for d in os.listdir(image_root):
+        for f in os.listdir(os.path.join(image_root, d)):
+            content.append(d + "/" + f + " " + d + "\n")
+
+    random.shuffle(content)
+
+    # Context managers close both files even if a write fails.
+    with open(os.path.join(root_path, "train.txt"), "w") as train_f, \
+            open(os.path.join(root_path, "test.txt"), "w") as test_f:
+        for i, c in enumerate(content):
+            if i < 0.8 * len(content):
+                train_f.write(c)
+            else:
+                test_f.write(c)
+
+
+def write_mx_lst(data_type="train",
+                 root_path="D:/BaiduNetdiskDownload/Synthetic_Chinese_String_Dataset/"):
+    """Convert DATA_TYPE.txt into a shuffled, tab-separated DATA_TYPE.lst file.
+
+    Each input line holds an image name followed by space-separated labels.
+    Every output line consists of the line index, the labels and finally
+    ``images/`` plus the image name, joined by tabs.
+    """
+    with open(os.path.join(root_path, data_type + ".txt"), "r") as in_f:
+        lines = in_f.readlines()
+    random.shuffle(lines)
+    with open(os.path.join(root_path, data_type + ".lst"), "w") as out_f:
+        for idx, line in enumerate(lines):
+            fields = line.strip().split(" ")
+            row = [str(idx)] + fields[1:] + ["images/" + fields[0]]
+            out_f.write("\t".join(row) + "\n")
+
+
+def _load_sample(data_root, line, data_shape, num_label):
+    """Turn one list-file line into ``(image, label)`` numpy arrays.
+
+    The image is resized to ``data_shape`` (width, height), converted to
+    grayscale and reshaped to ``(1, height, width)``.  The label vector has
+    ``num_label`` entries; positions without a label stay 0.
+    """
+    fields = line.strip().split(' ')
+    img_path = os.path.join(data_root, fields[0])
+    img = Image.open(img_path).resize(data_shape, Image.BILINEAR).convert('L')
+    img = np.array(img).reshape((1, data_shape[1], data_shape[0]))
+
+    label = np.zeros(num_label, int)
+    for idx in range(1, len(fields)):
+        label[idx - 1] = int(fields[idx])
+    return img, label
+
+
+class SimpleBatch(object):
+    """Plain batch container pairing data/label arrays with their names."""
+
+    def __init__(self, data_names, data, label_names=None, label=None):
+        # ``None`` stands in for an empty list so that instances never share
+        # one mutable default object.
+        self._data = data
+        self._label = [] if label is None else label
+        self._data_names = data_names
+        self._label_names = [] if label_names is None else label_names
+
+        self.pad = 0
+        self.index = None  # TODO: what is index?
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def label(self):
+        return self._label
+
+    @property
+    def data_names(self):
+        return self._data_names
+
+    @property
+    def label_names(self):
+        return self._label_names
+
+    @property
+    def provide_data(self):
+        return [(n, x.shape) for n, x in zip(self._data_names, self._data)]
+
+    @property
+    def provide_label(self):
+        return [(n, x.shape) for n, x in zip(self._label_names, self._label)]
+
+
+class ImageIter(mx.io.DataIter):
+    """Iterator yielding batches of grayscale images and label vectors.
+
+    Samples are streamed from the list file; a trailing group of fewer than
+    ``batch_size`` samples is not yielded.
+    """
+
+    def __init__(self, data_root, data_list, batch_size, data_shape, num_label, name=None):
+        """
+        Parameters
+        ----------
+        data_root: str
+            root directory of images
+        data_list: str
+            a .txt file storing, per line, an image name followed by its
+            space-separated integer labels
+        batch_size: int
+        data_shape: tuple of int
+            (width, height) every image is resized to
+        num_label: int
+            length of the label vector of each sample
+        name: str
+        """
+        super(ImageIter, self).__init__()
+        self.batch_size = batch_size
+        self.data_shape = data_shape
+        self.num_label = num_label
+
+        self.data_root = data_root
+        # Kept open for the iterator's lifetime: reset() rewinds this handle.
+        self.dataset_lst_file = open(data_list)
+
+        self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))]
+        self.provide_label = [('label', (self.batch_size, self.num_label))]
+        self.name = name
+
+    def __iter__(self):
+        data = []
+        label = []
+        for m_line in self.dataset_lst_file:
+            img, ret = _load_sample(self.data_root, m_line, self.data_shape, self.num_label)
+            data.append(img)
+            label.append(ret)
+            if len(data) == self.batch_size:
+                data_all = [mx.nd.array(data)]
+                label_all = [mx.nd.array(label)]
+                # Fresh lists (as in ImageIterLstm) instead of list.clear().
+                data = []
+                label = []
+                yield SimpleBatch(['data'], data_all, ['label'], label_all)
+
+    def reset(self):
+        if self.dataset_lst_file.seekable():
+            self.dataset_lst_file.seek(0)
+
+
+class ImageIterLstm(mx.io.DataIter):
+    """Like ImageIter, but every batch also carries the LSTM initial states.
+
+    The list file is read into memory once and reshuffled by reset().  A
+    trailing group of fewer than ``batch_size`` samples is not yielded.
+    """
+
+    def __init__(self, data_root, data_list, batch_size, data_shape, num_label, lstm_init_states, name=None):
+        """
+        Parameters
+        ----------
+        data_root: str
+            root directory of images
+        data_list: str
+            a .txt file storing, per line, an image name followed by its
+            space-separated integer labels
+        batch_size: int
+        data_shape: tuple of int
+            (width, height) every image is resized to
+        num_label: int
+            length of the label vector of each sample
+        lstm_init_states: list of (str, tuple)
+            (name, shape) of every initial LSTM state; fed as all-zero arrays
+        name: str
+        """
+        super(ImageIterLstm, self).__init__()
+        self.batch_size = batch_size
+        self.data_shape = data_shape
+        self.num_label = num_label
+
+        self.init_states = lstm_init_states
+        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states]
+
+        self.data_root = data_root
+        with open(data_list) as list_file:
+            self.dataset_lines = list_file.readlines()
+
+        self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + lstm_init_states
+        self.provide_label = [('label', (self.batch_size, self.num_label))]
+        self.name = name
+
+    def __iter__(self):
+        init_state_names = [x[0] for x in self.init_states]
+        data = []
+        label = []
+        for m_line in self.dataset_lines:
+            img, ret = _load_sample(self.data_root, m_line, self.data_shape, self.num_label)
+            data.append(img)
+            label.append(ret)
+            if len(data) == self.batch_size:
+                data_all = [mx.nd.array(data)] + self.init_state_arrays
+                label_all = [mx.nd.array(label)]
+                data = []
+                label = []
+                yield SimpleBatch(['data'] + init_state_names, data_all, ['label'], label_all)
+
+    def reset(self):
+        # Reshuffle so that every epoch sees the samples in a new order.
+        random.shuffle(self.dataset_lines)
+
+
+if __name__ == "__main__":
+    write_mx_lst("test")
diff --git a/fit/__init__.py b/fit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fit/ctc_loss.py b/fit/ctc_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5161dfa44c74f33334514aa367ca0ca2be44e8
--- /dev/null
+++ b/fit/ctc_loss.py
@@ -0,0 +1,41 @@
+"""Helpers attaching a CTC loss to a prediction symbol."""
+import mxnet as mx
+
+
+def _add_warp_ctc_loss(pred, seq_len, num_label, label):
+    """ Adds Symbol.WarpCTC on top of pred symbol and returns the resulting symbol """
+    label = mx.sym.Reshape(data=label, shape=(-1,))
+    label = mx.sym.Cast(data=label, dtype='int32')
+    return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)
+
+
+def _add_mxnet_ctc_loss(pred, seq_len, label):
+    """ Adds Symbol.contrib.ctc_loss on top of pred symbol and returns the resulting symbol
+
+    The result groups a gradient-blocked softmax of pred with the CTC loss.
+    """
+    pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))
+
+    loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
+    ctc_loss = mx.sym.MakeLoss(loss)
+
+    softmax_class = mx.symbol.SoftmaxActivation(data=pred)
+    softmax_loss = mx.sym.MakeLoss(softmax_class)
+    softmax_loss = mx.sym.BlockGrad(softmax_loss)
+    return mx.sym.Group([softmax_loss, ctc_loss])
+
+
+def add_ctc_loss(pred, seq_len, num_label, loss_type):
+    """ Adds CTC loss on top of pred symbol and returns the resulting symbol
+
+    loss_type selects the implementation: 'warpctc' or 'ctc'.
+    """
+    label = mx.sym.Variable('label')
+    if loss_type == 'warpctc':
+        print("Using WarpCTC Loss")
+        return _add_warp_ctc_loss(pred, seq_len, num_label, label)
+    # Validate before announcing the fallback so an unknown loss_type is not
+    # reported as "MXNet CTC".
+    assert loss_type == 'ctc', "unknown loss_type: %r" % (loss_type,)
+    print("Using MXNet CTC Loss")
+    return _add_mxnet_ctc_loss(pred, seq_len, label)
diff --git a/fit/ctc_metrics.py b/fit/ctc_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..0db680af18d708133400ee2443992c5c5d4447b7
--- /dev/null
+++ b/fit/ctc_metrics.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Contains a class for calculating CTC eval metrics"""
+
+from __future__ import print_function
+
+import numpy as np
+
+
+class CtcMetrics(object):
+    """Accuracy measures for CTC output laid out time-major.
+
+    Row ``k * batch_size + i`` of ``pred`` holds the class scores of sample
+    ``i`` at time step ``k``; class 0 is the CTC blank.
+    """
+
+    def __init__(self, seq_len):
+        self.seq_len = seq_len
+
+    @staticmethod
+    def ctc_label(p):
+        """
+        Iterates through p, identifying non-zero and non-repeating values, and returns them in a list
+        Parameters
+        ----------
+        p: list of int
+
+        Returns
+        -------
+        list of int
+        """
+        ret = []
+        prev = 0  # a leading blank, so a first non-zero value is always kept
+        for cur in p:
+            if cur != 0 and cur != prev:
+                ret.append(cur)
+            prev = cur
+        return ret
+
+    @staticmethod
+    def _remove_blank(l):
+        """ Returns the values of l that precede its first zero (the label padding) as a new list """
+        ret = []
+        for value in l:
+            if value == 0:
+                break
+            ret.append(value)
+        return ret
+
+    @staticmethod
+    def _lcs(p, l):
+        """ Calculates the Longest Common Subsequence between p and l (both list of int) and returns its length"""
+        # Dynamic Programming Finding LCS
+        if len(p) == 0 or len(l) == 0:
+            # An empty match matrix has no maximum; the LCS is simply 0.
+            return 0
+        P = np.array(list(p)).reshape((1, len(p)))
+        L = np.array(list(l)).reshape((len(l), 1))
+        M = np.int32(P == L)
+        for i in range(M.shape[0]):
+            for j in range(M.shape[1]):
+                up = 0 if i == 0 else M[i-1, j]
+                left = 0 if j == 0 else M[i, j-1]
+                M[i, j] = max(up, left, M[i, j] if (i == 0 or j == 0) else M[i, j] + M[i-1, j-1])
+        return M.max()
+
+    def _decode(self, pred, i, batch_size):
+        """ Greedy-decodes sample i: per-step argmax, then repeats and blanks collapsed """
+        p = [np.argmax(pred[k * batch_size + i]) for k in range(self.seq_len)]
+        return self.ctc_label(p)
+
+    def accuracy(self, label, pred):
+        """ Simple accuracy measure: number of 100% accurate predictions divided by total number """
+        hit = 0.
+        total = 0.
+        batch_size = label.shape[0]
+        for i in range(batch_size):
+            l = self._remove_blank(label[i])
+            p = self._decode(pred, i, batch_size)
+            if len(p) == len(l) and all(pk == int(lk) for pk, lk in zip(p, l)):
+                hit += 1.0
+            total += 1.0
+        assert total == batch_size
+        return hit / total
+
+    def accuracy_lcs(self, label, pred):
+        """ Longest Common Subsequence accuracy measure: calculate accuracy of each prediction as LCS/length
+
+        A sample whose label is empty scores 1 when nothing was predicted and 0 otherwise.
+        """
+        hit = 0.
+        total = 0.
+        batch_size = label.shape[0]
+        for i in range(batch_size):
+            l = self._remove_blank(label[i])
+            p = self._decode(pred, i, batch_size)
+            if len(l) == 0:
+                # Avoid dividing by a zero label length.
+                hit += 0.0 if p else 1.0
+            else:
+                hit += self._lcs(p, l) * 1.0 / len(l)
+            total += 1.0
+        assert total == batch_size
+        return hit / total
diff --git a/fit/fit.py b/fit/fit.py
new file mode 100644
index 0000000000000000000000000000000000000000..47ce4e441e119713ec38309bdc72087b2ff20228
--- /dev/null
+++ b/fit/fit.py
@@ -0,0 +1,82 @@
+"""Checkpoint loading and the shared Module.fit training loop."""
+import logging
+import os
+
+import mxnet as mx
+
+
+def _load_model(args, rank=0):
+    """Load the checkpoint named by args.prefix / args.load_epoch.
+
+    Returns (symbol, arg_params, aux_params), or three Nones when
+    args.load_epoch is absent or None.
+    """
+    if 'load_epoch' not in args or args.load_epoch is None:
+        return (None, None, None)
+    assert args.prefix is not None
+    model_prefix = args.prefix
+    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
+        model_prefix += "-%d" % (rank)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(
+        model_prefix, args.load_epoch)
+    # Checkpoint files use a dash between prefix and epoch, not an underscore.
+    logging.info('Loaded model %s-%04d.params', model_prefix, args.load_epoch)
+    return (sym, arg_params, aux_params)
+
+
+def fit(network, data_train, data_val, metrics, args, hp, data_names=None):
+    """Train network with mx.mod.Module, resuming from a checkpoint if requested.
+
+    Parameters
+    ----------
+    network: mx.sym.Symbol
+        symbol to train; must equal the checkpoint's symbol when resuming
+    data_train, data_val:
+        training and evaluation data iterators
+    metrics:
+        object whose accuracy(label, pred) is used as evaluation metric
+    args:
+        parsed arguments providing gpu, cpu, load_epoch and prefix
+    hp:
+        hyper-parameters providing num_epoch, learning_rate and batch_size
+    data_names: list of str, optional
+        input names of the module; defaults to ["data"]
+    """
+    if args.gpu:
+        contexts = [mx.context.gpu(i) for i in range(args.gpu)]
+    else:
+        contexts = [mx.context.cpu(i) for i in range(args.cpu)]
+
+    sym, arg_params, aux_params = _load_model(args)
+    if sym is not None:
+        assert sym.tojson() == network.tojson()
+
+    # The epoch-end callback saves files next to args.prefix; make sure that
+    # directory exists before the first epoch ends.
+    checkpoint_dir = os.path.dirname(args.prefix) if args.prefix else ""
+    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+
+    module = mx.mod.Module(
+        symbol=network,
+        data_names=["data"] if data_names is None else data_names,
+        label_names=['label'],
+        context=contexts)
+
+    module.fit(train_data=data_train,
+               eval_data=data_val,
+               begin_epoch=args.load_epoch if args.load_epoch else 0,
+               num_epoch=hp.num_epoch,
+               # use metrics.accuracy or metrics.accuracy_lcs
+               eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
+               optimizer='AdaDelta',
+               optimizer_params={'learning_rate': hp.learning_rate,
+                                 # 'momentum': hp.momentum,
+                                 'wd': 0.00001,
+                                 },
+               initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
+               arg_params=arg_params,
+               aux_params=aux_params,
+               batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
+               epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
+               )
diff --git a/fit/lstm.py b/fit/lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6402daeac2cf54e2377363bcae63da4da175648a
--- /dev/null
+++ b/fit/lstm.py
@@ -0,0 +1,101 @@
+"""Unrolled bidirectional LSTM built from MXNet symbols."""
+from __future__ import print_function
+
+from collections import namedtuple
+
+import mxnet as mx
+
+LSTMState = namedtuple("LSTMState", ["c", "h"])
+LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
+                                     "h2h_weight", "h2h_bias"])
+LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
+                                     "init_states", "last_states", "forward_state", "backward_state",
+                                     "seq_data", "seq_labels", "seq_outputs",
+                                     "param_blocks"])
+
+
+def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx):
+    """LSTM Cell symbol
+
+    Computes one step from indata and prev_state with the shared weights in
+    param and returns the next LSTMState.  seqidx and layeridx only appear in
+    the names of the created operators.
+    """
+    i2h = mx.sym.FullyConnected(data=indata,
+                                weight=param.i2h_weight,
+                                bias=param.i2h_bias,
+                                num_hidden=num_hidden * 4,
+                                name="t%d_l%d_i2h" % (seqidx, layeridx))
+    h2h = mx.sym.FullyConnected(data=prev_state.h,
+                                weight=param.h2h_weight,
+                                bias=param.h2h_bias,
+                                num_hidden=num_hidden * 4,
+                                name="t%d_l%d_h2h" % (seqidx, layeridx))
+    gates = i2h + h2h
+    slice_gates = mx.sym.split(gates, num_outputs=4,
+                               name="t%d_l%d_slice" % (seqidx, layeridx))
+    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
+    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
+    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
+    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
+    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
+    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
+    return LSTMState(c=next_c, h=next_h)
+
+
+def lstm(net, num_lstm_layer, num_hidden, seq_length):
+    """Run a bidirectional LSTM over the last axis of net.
+
+    net is split into seq_length slices along axis 3.  Even-numbered state and
+    parameter sets (l0, l2, ...) serve the forward direction, odd-numbered
+    ones (l1, l3, ...) the backward direction.  Per time step the forward and
+    backward outputs are concatenated along dim 1, and all steps are then
+    concatenated along dim 0, i.e. time step by time step.
+    """
+    last_states = []
+    forward_param = []
+    backward_param = []
+
+    for i in range(num_lstm_layer * 2):
+        last_states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % i), h=mx.sym.Variable("l%d_init_h" % i)))
+        param = LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
+                          i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
+                          h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
+                          h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))
+        if i % 2 == 0:
+            forward_param.append(param)
+        else:
+            backward_param.append(param)
+
+    slices_net = mx.sym.split(data=net, axis=3, num_outputs=seq_length, squeeze_axis=1)  # bz x features x 1 x time_step
+
+    forward_hidden = []
+    for seqidx in range(seq_length):
+        hidden = mx.sym.flatten(data=slices_net[seqidx])
+        for i in range(num_lstm_layer):
+            # NOTE(review): layeridx carries the direction (0 = forward), not
+            # the layer, so stacked layers reuse operator names.  Left as is:
+            # renaming would presumably change the JSON fit() compares.
+            next_state = _lstm(num_hidden, indata=hidden, prev_state=last_states[2 * i],
+                               param=forward_param[i], seqidx=seqidx, layeridx=0)
+            hidden = next_state.h
+            last_states[2 * i] = next_state
+        forward_hidden.append(hidden)
+
+    backward_hidden = []
+    for seqidx in range(seq_length):
+        k = seq_length - seqidx - 1
+        hidden = mx.sym.flatten(data=slices_net[k])
+        for i in range(num_lstm_layer):
+            next_state = _lstm(num_hidden, indata=hidden, prev_state=last_states[2 * i + 1],
+                               param=backward_param[i], seqidx=k, layeridx=1)
+            hidden = next_state.h
+            last_states[2 * i + 1] = next_state
+        backward_hidden.insert(0, hidden)
+
+    hidden_all = []
+    for i in range(seq_length):
+        hidden_all.append(mx.sym.concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
+
+    hidden_concat = mx.sym.concat(*hidden_all, dim=0)
+    return hidden_concat
diff --git a/hyperparams/__init__.py b/hyperparams/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hyperparams/hyperparams.py b/hyperparams/hyperparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..12ba00e3d8b81d80dca4ddfad61f4f248e69651f
--- /dev/null
+++ b/hyperparams/hyperparams.py
@@ -0,0 +1,114 @@
+from __future__ import print_function
+
+
+class Hyperparams(object):
+    """
+    Hyperparameters for the CRNN / LSTM OCR network, exposed as read-only properties
+    """
+    def __init__(self):
+        # Training hyper parameters
+        self._train_epoch_size = 30000
+        self._eval_epoch_size = 3000
+        self._num_epoch = 20
+        self._learning_rate = 0.001
+        self._momentum = 0.9
+        self._bn_mom = 0.9
+        self._workspace = 512
+        self._loss_type = "warpctc"  # ["warpctc" "ctc"]
+
+        self._batch_size = 128
+        self._num_classes = 5990
+        self._img_width = 280
+        self._img_height = 32
+
+        # DenseNet hyper parameters
+        self._depth = 161
+        self._growrate = 32
+        self._reduction = 0.5
+
+        # LSTM hyper parameters
+        self._num_hidden = 100
+        self._num_lstm_layer = 2
+        self._seq_length = 35  # img_width / 8: symbols/crnn.py halves the width three times
+        self._num_label = 10
+        self._drop_out = 0.5
+
+    @property
+    def train_epoch_size(self):
+        return self._train_epoch_size
+
+    @property
+    def eval_epoch_size(self):
+        return self._eval_epoch_size
+
+    @property
+    def num_epoch(self):
+        return self._num_epoch
+
+    @property
+    def learning_rate(self):
+        return self._learning_rate
+
+    @property
+    def momentum(self):
+        return self._momentum
+
+    @property
+    def bn_mom(self):
+        return self._bn_mom
+
+    @property
+    def workspace(self):
+        return self._workspace
+
+    @property
+    def loss_type(self):
+        return self._loss_type
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @property
+    def num_classes(self):
+        return self._num_classes
+
+    @property
+    def img_width(self):
+        return self._img_width
+
+    @property
+    def img_height(self):
+        return self._img_height
+
+    @property
+    def depth(self):
+        return self._depth
+
+    @property
+    def growrate(self):
+        return self._growrate
+
+    @property
+    def reduction(self):
+        return self._reduction
+
+    @property
+    def num_hidden(self):
+        return self._num_hidden
+
+    @property
+    def num_lstm_layer(self):
+        return self._num_lstm_layer
+
+    @property
+    def seq_length(self):
+        return self._seq_length
+
+    @property
+    def num_label(self):
+        return self._num_label
+
+    @property
+    def dropout(self):
+        return self._drop_out
diff --git a/symbols/__init__.py b/symbols/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/symbols/__pycache__/__init__.cpython-36.pyc b/symbols/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c174c0b4be5f8dea086608a64bdbfe177f8d18f
Binary files /dev/null and b/symbols/__pycache__/__init__.cpython-36.pyc differ
diff --git a/symbols/__pycache__/crnn_no_lstm.cpython-36.pyc b/symbols/__pycache__/crnn_no_lstm.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98400aa6d2fc25b2c6a167f8a62017742877a9cf
Binary files /dev/null and b/symbols/__pycache__/crnn_no_lstm.cpython-36.pyc differ
diff --git a/symbols/__pycache__/ctc_loss.cpython-36.pyc b/symbols/__pycache__/ctc_loss.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1af88ee9592cd5f920cf6bbaec9ddfc3c405faab
Binary files /dev/null and b/symbols/__pycache__/ctc_loss.cpython-36.pyc differ
diff --git a/symbols/__pycache__/ctc_metrics.cpython-36.pyc b/symbols/__pycache__/ctc_metrics.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2b088fadb1b813d993d049ffc5c31a7241bf42a
Binary files /dev/null and b/symbols/__pycache__/ctc_metrics.cpython-36.pyc differ
diff --git a/symbols/crnn.py b/symbols/crnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5692bc84857f7a2fe5637a56260a7805326ba9b8
--- /dev/null
+++ b/symbols/crnn.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+CRNN-style symbols for text-line recognition: a convolutional feature
+extractor followed by a per-time-step classifier, either directly
+(crnn_no_lstm) or after a bidirectional LSTM (crnn_lstm).
+"""
+import mxnet as mx
+
+from fit.ctc_loss import add_ctc_loss
+from fit.lstm import lstm
+from hyperparams.hyperparams import Hyperparams
+
+
+def _conv_block(i, input_data, num_filter, with_1x1):
+    """3x3 conv -> BatchNorm -> LeakyReLU, optionally followed by the same with a 1x1 conv."""
+    layer = mx.symbol.Convolution(name='conv-%d' % i, data=input_data, kernel=(3, 3), pad=(1, 1),
+                                  num_filter=num_filter)
+    layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d' % i)
+    layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d' % i)
+    if with_1x1:
+        layer = mx.symbol.Convolution(name='conv-%d-1x1' % i, data=layer, kernel=(1, 1), pad=(0, 0),
+                                      num_filter=num_filter)
+        layer = mx.sym.BatchNorm(data=layer, name='batchnorm-%d-1x1' % i)
+        layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-1x1' % i)
+    return layer
+
+
+def _features(data, hp, with_1x1):
+    """Convolutional feature extractor shared by both networks.
+
+    Three 2x2 poolings shrink height and width by 8; the final (4, 1) average
+    pooling then collapses a height of 4 (32-pixel-high input) to 1, giving
+    bz x f x 1 x (width / 8).  Dropout is appended when hp.dropout > 0.
+    """
+    num_filter = [min(32 * 2 ** (i + 1), 512) for i in range(6)]
+
+    net = _conv_block(0, data, num_filter[0], with_1x1)  # bz x f x H x W
+    # max_pool / avg_pool instead of "max" / "avg": do not shadow the builtin.
+    max_pool = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2))
+    avg_pool = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2))
+    net = max_pool - avg_pool  # H/2 x W/2
+    net = _conv_block(1, net, num_filter[1], with_1x1)
+    net = mx.sym.Pooling(data=net, name='pool-1', pool_type='max', kernel=(2, 2), stride=(2, 2))  # H/4 x W/4
+    net = _conv_block(2, net, num_filter[2], with_1x1)
+    net = _conv_block(3, net, num_filter[3], with_1x1)
+    net = mx.sym.Pooling(data=net, name='pool-2', pool_type='max', kernel=(2, 2), stride=(2, 2))  # H/8 x W/8
+    net = _conv_block(4, net, num_filter[4], with_1x1)
+    net = _conv_block(5, net, num_filter[5], with_1x1)
+    net = mx.symbol.Pooling(data=net, kernel=(4, 1), pool_type='avg', name='pool1')  # bz x f x 1 x W/8
+
+    if hp.dropout > 0:
+        net = mx.symbol.Dropout(data=net, p=hp.dropout)
+    return net
+
+
+def _add_head(pred, hp):
+    """Attach the CTC loss when hp.loss_type is set, else a softmax output."""
+    if hp.loss_type:
+        # Training mode, add loss
+        return add_ctc_loss(pred, hp.seq_length, hp.num_label, hp.loss_type)
+    # Inference mode, add softmax
+    return mx.sym.softmax(data=pred, name='softmax')
+
+
+def crnn_no_lstm(hp):
+    """Build the CNN-only network: features -> per-time-step classifier."""
+    data = mx.sym.Variable('data')
+    net = _features(data, hp, with_1x1=False)  # bz x f x 1 x T
+
+    # Rows must be time-major (row = t * bz + b): that is the layout produced
+    # by crnn_lstm and indexed by CtcMetrics (pred[k * batch_size + i]).  The
+    # former axes=[1, 0, 2, 3] yielded batch-major rows instead.
+    net = mx.sym.transpose(data=net, axes=[1, 3, 0, 2])  # f x T x bz x 1
+    net = mx.sym.flatten(data=net)  # f x (T x bz)
+    hidden_concat = mx.sym.transpose(data=net, axes=[1, 0])  # (T x bz) x f
+
+    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes)  # (T x bz) x num_classes
+    return _add_head(pred, hp)
+
+
+def crnn_lstm(hp):
+    """Build the CNN + bidirectional LSTM network."""
+    data = mx.sym.Variable('data')
+    net = _features(data, hp, with_1x1=True)  # bz x f x 1 x T
+
+    hidden_concat = lstm(net, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden, seq_length=hp.seq_length)
+
+    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes)  # (T x bz) x num_classes
+    return _add_head(pred, hp)
+
+
+if __name__ == '__main__':
+    hp = Hyperparams()
+
+    init_states = {}
+    init_states['data'] = (hp.batch_size, 1, hp.img_height, hp.img_width)
+    init_states['label'] = (hp.batch_size, hp.num_label)
+
+    # To inspect crnn_lstm instead, also add every 'l%d_init_c' and
+    # 'l%d_init_h' (l in range(hp.num_lstm_layer * 2)) with shape
+    # (hp.batch_size, hp.num_hidden) to init_states.
+    symbol = crnn_no_lstm(hp)
+    internals = symbol.get_internals()
+    _, out_shapes, _ = internals.infer_shape(**init_states)
+    shape_dict = dict(zip(internals.list_outputs(), out_shapes))
+
+    for item in shape_dict:
+        print(item, shape_dict[item])
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ac070cbe12a54198a8e508c564a7794f1daa2fe
--- /dev/null
+++ b/train.py
@@ -0,0 +1,112 @@
+from __future__ import print_function
+
+import argparse
+import logging
+import os
+
+import mxnet as mx
+
+from data_utils.data_iter import ImageIter, ImageIterLstm
+from fit.ctc_metrics import CtcMetrics
+from fit.fit import fit
+from hyperparams.hyperparams import Hyperparams
+from symbols.crnn import crnn_no_lstm, crnn_lstm
+
+
+def parse_args():
+    """Parse the command line arguments of the training script."""
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--data_root", help="Path to image files", type=str,
+                        default='/home/richard/data/Synthetic_Chinese_String_Dataset/images')
+    parser.add_argument("--train_file", help="Path to train txt file", type=str,
+                        default='/home/richard/data/Synthetic_Chinese_String_Dataset/train.txt')
+    parser.add_argument("--test_file", help="Path to test txt file", type=str,
+                        default='/home/richard/data/Synthetic_Chinese_String_Dataset/test.txt')
+    # The bracketed defaults in the help texts match the real defaults.
+    parser.add_argument("--cpu",
+                        help="Number of CPUs for training [Default 4]. Only used with --gpu 0.",
+                        type=int, default=4)
+    parser.add_argument("--gpu", help="Number of GPUs for training [Default 1]", type=int, default=1)
+    parser.add_argument('--load_epoch', type=int,
+                        help='load the model on an epoch using the model-load-prefix')
+    parser.add_argument("--prefix", help="Checkpoint prefix [Default './check_points/model']",
+                        default='./check_points/model')
+    return parser.parse_args()
+
+
+def _setup(args, hp):
+    """Configure logging and build the inputs shared by main() and main2().
+
+    Returns (data_train, data_val, data_names): two ImageIterLstm iterators
+    and the module input names, i.e. 'data' plus every LSTM init state.
+    """
+    init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
+    init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)]
+    init_states = init_c + init_h
+    data_names = ['data'] + [x[0] for x in init_states]
+
+    data_train = ImageIterLstm(
+        args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train")
+    data_val = ImageIterLstm(
+        args.data_root, args.test_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="val")
+
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.DEBUG, format=head)
+    return data_train, data_val, data_names
+
+
+def main():
+    """Train crnn_lstm through fit.fit, which can resume via --load_epoch."""
+    args = parse_args()
+    hp = Hyperparams()
+    data_train, data_val, data_names = _setup(args, hp)
+
+    network = crnn_lstm(hp)
+
+    metrics = CtcMetrics(hp.seq_length)
+
+    fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names)
+
+
+def main2():
+    """Train crnn_lstm with a Module bound by hand; never loads a checkpoint."""
+    args = parse_args()
+    hp = Hyperparams()
+
+    if args.gpu:
+        contexts = [mx.context.gpu(i) for i in range(args.gpu)]
+    else:
+        contexts = [mx.context.cpu(i) for i in range(args.cpu)]
+
+    data_train, data_val, data_names = _setup(args, hp)
+
+    symbol = crnn_lstm(hp)
+    module = mx.mod.Module(
+        symbol,
+        data_names=data_names,
+        label_names=['label'],
+        context=contexts)
+
+    module.bind(data_shapes=data_train.provide_data, label_shapes=data_train.provide_label)
+
+    metrics = CtcMetrics(hp.seq_length)
+
+    module.fit(train_data=data_train,
+               eval_data=data_val,
+               # use metrics.accuracy or metrics.accuracy_lcs
+               eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
+               optimizer='AdaDelta',
+               optimizer_params={'learning_rate': hp.learning_rate,
+                                 # 'momentum': hp.momentum,
+                                 'wd': 0.00001,
+                                 },
+               initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
+               num_epoch=hp.num_epoch,
+               batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
+               epoch_end_callback=mx.callback.do_checkpoint(args.prefix),
+               )
+
+
+if __name__ == '__main__':
+    main()