提交 76d0139f 编写于 作者: X xiaohang

add README

上级 dbe73da0
......@@ -31,3 +31,8 @@ Train a new model
-----------------
1. Construct the dataset following the original guide. For training with variable-length text, please sort the images according to text length.
2. ``python crnn_main.py [--param val]``. Explore ``crnn_main.py`` for details.
Stable commits
--------------
dbe73da0dd7efb8bd76dbd7f0ac3856e742b98d4: support image list with label and alphabet
### Train on VGG recognition text
- download mjsynth.tar.gz and unzip to current folder
- copy annotation_train.txt annotation_test.txt annotation_val.txt to the current folder
- correct path info
- create imagelist: cat annotation_train.imgs | awk -F / '{print $NF}' | awk -F _ '{print $2}' | tr [:upper:] [:lower:]
- python create_dataset.py
require('table')
require('torch')
require('os')
--- Deep-copy a value.
-- Non-table values are returned unchanged. Tables are copied recursively
-- (values only; keys are reused as-is) and the copy shares the original's
-- metatable.
-- @param t     value to copy
-- @param seen  (optional, internal) map from already-visited source tables to
--              their copies; lets self-referential tables terminate and keeps
--              a subtable that is referenced twice as one shared copy
-- @return the copy (or `t` itself when it is not a table)
function clone(t, seen)
    if type(t) ~= "table" then return t end
    seen = seen or {}
    if seen[t] ~= nil then return seen[t] end
    local meta = getmetatable(t)
    local target = {}
    -- Register the copy before descending so that cycles resolve to it.
    seen[t] = target
    for k, v in pairs(t) do
        if type(v) == "table" then
            target[k] = clone(v, seen)
        else
            target[k] = v
        end
    end
    setmetatable(target, meta)
    return target
end
--- Return a deep copy of `lhs` with the elements of `rhs` appended.
-- `rhs` is treated as a sequence and its elements are appended in index
-- order, because callers consume the result positionally. Neither argument
-- is modified.
-- @param lhs  table to copy (copied with `clone`)
-- @param rhs  sequence of values to append
-- @return the new table
function tableMerge(lhs, rhs)
    -- `local` keeps the result from leaking into the global environment.
    local output = clone(lhs)
    for _, v in ipairs(rhs) do
        table.insert(output, v)
    end
    return output
end
--- Tell whether `val` equals (`==`) any value stored in `val_list`.
-- All entries of the table are examined, whatever their keys.
-- @param val       value to look for
-- @param val_list  table whose values are searched
-- @return true when a match exists, false otherwise
function isInTable(val, val_list)
    local found = false
    for _, candidate in pairs(val_list) do
        if val == candidate then
            found = true
            break
        end
    end
    return found
end
--- Flatten a container module into a list of `{typeName, param}` pairs.
-- Walks `model.modules` in index order. Layers whose type is in the ignore
-- list are skipped (a "pass" line is printed); nested `nn.Sequential` /
-- `nn.ConcatTable` containers are converted recursively; batch-norm layers
-- get their running mean/variance appended after `layer:parameters()`;
-- pooling and ReLU layers contribute an empty parameter table.
-- An unrecognised layer type is reported and the process exits with a
-- non-zero status.
-- @param model  container exposing a `modules` sequence
-- @return list of `{typeName, param}` entries, one per kept layer
function modelToList(model)
    local ignoreList = {
        'nn.Copy',
        'nn.AddConstant',
        'nn.MulConstant',
        'nn.View',
        'nn.Transpose',
        'nn.SplitTable',
        'nn.SharedParallelTable',
        'nn.JoinTable',
    }
    local state = {}
    -- ipairs: the consumer matches layers by position, so order must be fixed.
    for _, layer in ipairs(model.modules) do
        local typeName = torch.type(layer)
        if isInTable(typeName, ignoreList) then
            print(string.format('pass %s', typeName))
        else
            local param
            if typeName == 'nn.Sequential' or typeName == 'nn.ConcatTable' then
                param = modelToList(layer)
            elseif typeName == 'cudnn.SpatialConvolution' or typeName == 'nn.SpatialConvolution' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialBatchNormalization' or typeName == 'nn.SpatialBatchNormalization' then
                -- Running statistics are not returned by parameters(); append them.
                local bn_vars = {layer.running_mean, layer.running_var}
                param = tableMerge(layer:parameters(), bn_vars)
            elseif typeName == 'nn.LstmLayer' then
                param = layer:parameters()
            elseif typeName == 'nn.BiRnnJoin' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialMaxPooling' or typeName == 'nn.SpatialMaxPooling' then
                param = {}
            elseif typeName == 'cudnn.ReLU' or typeName == 'nn.ReLU' then
                param = {}
            else
                print(string.format('Unknown class %s', typeName))
                -- Failure path: signal the error to the calling shell.
                os.exit(1)
            end
            table.insert(state, {typeName, param})
        end
    end
    return state
end
--- Serialise a model's flattened layer list to disk.
-- The list produced by `modelToList(model)` is written with `torch.save`.
-- @param model        container module to export
-- @param output_path  destination file path
function saveModel(model, output_path)
    torch.save(output_path, modelToList(model))
end
import torchfile
import argparse
import torch
from torch.nn.parameter import Parameter
import numpy as np
import models.crnn as crnn
# Maps a Torch7 layer class name (without the package prefix) to the class
# name of the PyTorch module it is compared against, i.e. the value of
# type(module).__name__. None means there is no PyTorch counterpart.
layer_map = {
    'SpatialConvolution': 'Conv2d',
    'SpatialBatchNormalization': 'BatchNorm2d',
    'ReLU': 'ReLU',
    'SpatialMaxPooling': 'MaxPool2d',
    'SpatialAveragePooling': 'AvgPool2d',
    'SpatialUpSamplingNearest': 'UpsamplingNearest2d',
    'View': None,
    # Must match the class name exactly ('Linear', not 'linear').
    'Linear': 'Linear',
    'Dropout': 'Dropout',
    'SoftMax': 'Softmax',
    'Identity': None,
    'SpatialFullConvolution': 'ConvTranspose2d',
    'SpatialReplicationPadding': None,
    'SpatialReflectionPadding': None,
    'Copy': None,
    'Narrow': None,
    'SpatialCrossMapLRN': None,
    'Sequential': None,
    'ConcatTable': None,  # output is list
    'CAddTable': None,  # input is list
    'Concat': None,
    'TorchObject': None,
    'LstmLayer': 'LSTM',
    'BiRnnJoin': 'Linear'
}
def torch_layer_serial(layer, layers):
    """Flatten an exported Torch7 layer entry into ``layers``.

    ``layer`` is indexed as ``[type_name, payload]``. When the type name is
    'nn.Sequential' or 'nn.ConcatTable' the payload is iterated as child
    entries and flattened recursively; any other entry is appended as-is.
    ``layers`` is extended in place; nothing is returned.
    """
    name = layer[0]
    is_container = name == 'nn.Sequential' or name == 'nn.ConcatTable'
    if not is_container:
        layers.append(layer)
        return
    flattened = []
    for child in layer[1]:
        torch_layer_serial(child, flattened)
    layers.extend(flattened)
def py_layer_serial(layer, layers):
    """Append the leaf modules of ``layer`` to ``layers``, depth first.

    A module with at least one entry in ``_modules`` is treated as a
    container and its ``children()`` are flattened recursively; a module
    without sub-modules is appended itself. This assumes sub-modules were
    defined in the order in which they are executed.
    """
    if len(layer._modules) < 1:
        layers.append(layer)
        return
    collected = []
    for child in layer.children():
        py_layer_serial(child, collected)
    layers.extend(collected)
def trans_pos(param, part_indexes, dim=0):
    """Reorder equal-sized chunks of ``param`` along axis ``dim``.

    ``param`` is split into ``len(part_indexes)`` equal parts along ``dim``
    and the parts are concatenated back in the order given by
    ``part_indexes``.
    """
    chunks = np.split(param, len(part_indexes), dim)
    return np.concatenate([chunks[idx] for idx in part_indexes], dim)
def load_params(py_layer, t7_layer):
    """Copy exported Torch7 parameters into a PyTorch module in place.

    Args:
        py_layer: destination module; its class name selects the branch.
        t7_layer: for an 'LSTM' module, a list of ``[name, params]`` entries
            whose parameter lists are concatenated in order; otherwise a
            single ``[name, params]`` entry. The argument is not modified.

    For LSTM every parameter is split into four chunks along axis 0 and the
    last two chunks are swapped (presumably to convert between the gate
    orderings of the two frameworks -- TODO confirm). For 'BiRnnJoin' the two
    weight matrices are concatenated along axis 1 and the two biases summed.

    Raises:
        AssertionError: if the number of source parameters differs from the
            number of destination attributes.

    A size mismatch reported by ``copy_`` as RuntimeError is printed and
    skipped rather than raised.
    """
    if type(py_layer).__name__ == 'LSTM':
        num_directions = 2 if py_layer.bidirectional else 1
        all_weights = []
        for i in range(py_layer.num_layers):
            for j in range(num_directions):
                suffix = '_reverse' if j == 1 else ''
                all_weights += [
                    template.format(i, suffix)
                    for template in ('weight_ih_l{}{}', 'bias_ih_l{}{}',
                                     'weight_hh_l{}{}', 'bias_hh_l{}{}')
                ]

        params = []
        for sub_layer in t7_layer:
            params.extend(sub_layer[1])
        params = [trans_pos(p, [0, 1, 3, 2], dim=0) for p in params]
    else:
        all_weights = []
        name = t7_layer[0].split('.')[-1]
        params = t7_layer[1]
        if name == 'BiRnnJoin':
            weight_0, bias_0, weight_1, bias_1 = params
            # Build a new list instead of overwriting t7_layer[1]: mutating
            # the caller's data made a second call on the same entry fail.
            params = [np.concatenate((weight_0, weight_1), axis=1),
                      bias_0 + bias_1]
            all_weights += ['weight', 'bias']
        elif name == 'SpatialConvolution' or name == 'Linear':
            all_weights += ['weight', 'bias']
        elif name == 'SpatialBatchNormalization':
            all_weights += ['weight', 'bias', 'running_mean', 'running_var']

    params = [torch.from_numpy(item) for item in params]
    assert len(all_weights) == len(params), "params' number not match"
    for py_param_name, t7_param in zip(all_weights, params):
        item = getattr(py_layer, py_param_name)
        if isinstance(item, Parameter):
            # Write into the underlying tensor of the Parameter.
            item = item.data
        try:
            item.copy_(t7_param)
        except RuntimeError:
            print('Size not match between %s and %s' %
                  (item.size(), t7_param.size()))
def torch_to_pytorch(model, t7_file, output):
    """Load exported Torch7 weights into ``model`` and save its state dict.

    The leaf modules of ``model`` and the layer entries read from ``t7_file``
    are both flattened, then walked in parallel. Each PyTorch module must
    have the class name that ``layer_map`` gives for the current Torch7
    entry, otherwise RuntimeError is raised. An 'LSTM' module consumes
    ``num_layers`` entries (doubled when bidirectional); every other module
    consumes one. Finally ``model.state_dict()`` is written to ``output``.
    """
    py_layers = []
    for child in model.children():
        py_layer_serial(child, py_layers)

    t7_layers = []
    for entry in torchfile.load(t7_file):
        torch_layer_serial(entry, t7_layers)

    cursor = 0  # index of the next unconsumed Torch7 entry
    for py_layer in py_layers:
        py_name = type(py_layer).__name__
        head = t7_layers[cursor]
        t7_name = head[0].split('.')[-1]
        if layer_map[t7_name] != py_name:
            raise RuntimeError('%s does not match %s' % (py_name, t7_name))

        if py_name != 'LSTM':
            span = 1
            t7_layer = head
        else:
            directions = 2 if py_layer.bidirectional else 1
            span = directions * py_layer.num_layers
            t7_layer = t7_layers[cursor:cursor + span]
        cursor += span

        load_params(py_layer, t7_layer)

    torch.save(model.state_dict(), output)
# Command-line entry point: convert a Torch7 export into a PyTorch state dict.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Convert torch t7 model to pytorch'
    )
    parser.add_argument(
        '--model_file',
        '-m',
        type=str,
        required=True,
        help='torch model file in t7 format'
    )
    # Required: with the old default of None the script only failed inside
    # torch.save(), after the whole conversion had already run.
    parser.add_argument(
        '--output',
        '-o',
        type=str,
        required=True,
        help='output file for the converted state dict, e.g. xxx.pth'
    )
    args = parser.parse_args()
    py_model = crnn.CRNN(32, 1, 37, 256, 1)
    torch_to_pytorch(py_model, args.model_file, args.output)
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to an image with non-zero area.

    Args:
        imageBin: raw encoded image bytes, or None.

    Returns:
        False for None, for data that cannot be decoded, for a decoded image
        whose height times width is zero, and whenever decoding raises;
        True otherwise.
    """
    if imageBin is None:
        return False
    try:
        # frombuffer replaces the deprecated binary mode of np.fromstring.
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return False
        imgH, imgW = img.shape[0], img.shape[1]
        return imgH * imgW != 0
    except Exception:
        # Best effort: anything that goes wrong means "not a valid image".
        return False
def writeCache(env, cache):
    """Write every key/value pair of ``cache`` to ``env`` in one transaction.

    Args:
        env: object whose ``begin(write=True)`` returns a context manager
            yielding a transaction with a ``put(key, value)`` method.
        cache: dict of entries to store.

    Text keys and values are encoded as UTF-8 before being stored; objects
    that are already ``bytes`` (which includes ``str`` on Python 2) are
    passed through untouched.
    """
    def _as_bytes(value):
        if isinstance(value, bytes) or not hasattr(value, 'encode'):
            return value
        return value.encode('utf-8')

    with env.begin(write=True) as txn:
        # items() works on both Python 2 and 3, unlike iteritems().
        for k, v in cache.items():
            txn.put(_as_bytes(k), _as_bytes(v))
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image

    Missing files and (when checkValid is set) empty or undecodable images
    are reported and skipped. Kept samples are numbered from 1 and stored
    under 'image-%09d' / 'label-%09d' (and 'lexicon-%09d' when lexicons are
    given); the final count is stored under 'num-samples'. Entries are
    flushed to the database every 1000 samples.
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1
        for i in range(nSamples):
            imagePath = imagePathList[i]
            label = labelList[i]
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            # Images are binary: text mode can corrupt or fail to decode them.
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid:
                if len(imageBin) == 0 or (not checkImageIsValid(imageBin)):
                    print('%s is not a valid image' % imagePath)
                    continue
            imageKey = 'image-%09d' % cnt
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                cache[lexiconKey] = ' '.join(lexiconList[i])
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        # cnt is one past the last stored sample.
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples)
        writeCache(env, cache)
    finally:
        # Release the environment even if writing fails part-way.
        env.close()
    print('Created dataset with %d samples' % nSamples)
# Script entry point: build the training LMDB from the two annotation lists.
if __name__ == '__main__':
    def _read_lines(path):
        # One entry per line, without the line terminator. Unlike
        # read().split('\n') this adds no empty entry after a trailing newline.
        with open(path) as handle:
            return [line.rstrip('\r\n') for line in handle]

    imagePathList = _read_lines('annotation_train.imgs')
    labelList = _read_lines('annotation_train.labels')
    outputPath = 'data/train_lmdb'
    createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True)
import lmdb
import six
from PIL import Image
# Dump every sample of the LMDB at 'lmdb/test_lmdb': image N is saved as
# out/<N>.png and its label is appended as one line to out/labels.txt.
env = lmdb.open('lmdb/test_lmdb',
                max_readers=1,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False)

# The with-statement closes the label file even if the dump fails part-way.
with open('out/labels.txt', 'w') as label_fp, env.begin(write=False) as txn:
    # Keys are encoded so the lookups also work where LMDB requires bytes.
    nSamples = int(txn.get('num-samples'.encode('utf-8')))
    for index in range(nSamples):
        sample_no = index + 1  # keys and file names are 1-based
        image_key = ('image-%09d' % sample_no).encode('utf-8')
        label_key = ('label-%09d' % sample_no).encode('utf-8')

        # BytesIO(None) yields an empty buffer, so a missing image is
        # reported as corrupted below instead of crashing on write().
        buf = six.BytesIO(txn.get(image_key))
        try:
            img = Image.open(buf)
            savename = "out/%06d.png" % sample_no
            img.save(savename)
            print("save %s" % savename)
        except IOError:
            print('Corrupted image for %d' % sample_no)

        label = txn.get(label_key)
        if isinstance(label, bytes) and not isinstance(label, str):
            # Python 3: stored values come back as bytes.
            label = label.decode('utf-8')
        # Portable replacement for the Python 2-only `print >> file` syntax.
        label_fp.write('%s\n' % label)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册