提交 f31b54e2 编写于 作者: X xiaohang

The first attention version, which seems to work but is very slow

上级 11ae8de2
......@@ -5,6 +5,7 @@ train_fp = open('data/train_list.txt', 'w')
for line in lines:
imgpath = line.strip().split(' ')[0]
label = imgpath.split('/')[-1].split('_')[1].lower()
label = label + '$'
label = ':'.join(label)
imgpath = 'data/mnt/ramdisk/max/90kDICT32px/%s' % imgpath
output = ' '.join([imgpath, label])
......@@ -20,6 +21,7 @@ test_fp = open('data/test_list.txt', 'w')
for line in lines:
imgpath = line.strip().split(' ')[0]
label = imgpath.split('/')[-1].split('_')[1].lower()
label = label + '$'
label = ':'.join(label)
imgpath = 'data/mnt/ramdisk/max/90kDICT32px/%s' % imgpath
output = ' '.join([imgpath, label])
......
......@@ -29,7 +29,7 @@ parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. de
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
parser.add_argument('--alphabet', type=str, default='0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z')
parser.add_argument('--alphabet', type=str, default='0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$')
parser.add_argument('--sep', type=str, default=':')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
......@@ -75,8 +75,8 @@ test_dataset = dataset.listDataset(list_file =opt.vallist, transform=dataset.res
nclass = len(opt.alphabet.split(opt.sep)) + 1
nc = 1
converter = utils.strLabelConverter(opt.alphabet, opt.sep)
criterion = CTCLoss()
converter = utils.strLabelConverterForAttention(opt.alphabet, opt.sep)
criterion = torch.nn.CrossEntropyLoss()
# custom weights initialization called on crnn
......@@ -97,13 +97,14 @@ if opt.crnn != '':
print(crnn)
image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.IntTensor(opt.batchSize * 5)
text = torch.LongTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)
if opt.cuda:
crnn.cuda()
crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
image = image.cuda()
text = text.cuda()
criterion = criterion.cuda()
image = Variable(image)
......@@ -149,24 +150,21 @@ def val(net, dataset, criterion, max_iter=100):
utils.loadData(text, t)
utils.loadData(length, l)
preds = crnn(image)
preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
cost = criterion(preds, text, preds_size, length) / batch_size
preds = crnn(image, length)
cost = criterion(preds, text)
loss_avg.add(cost)
_, preds = preds.max(2)
#preds = preds.squeeze(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
preds = preds.view(-1)
sim_preds = converter.decode(preds.data, length.data)
for pred, target in zip(sim_preds, cpu_texts):
target = ''.join(target.split(opt.sep))
if pred == target:
n_correct += 1
raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
for pred, gt in zip(sim_preds, cpu_texts):
gt = ''.join(gt.split(opt.sep))
print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
print('%-20s, gt: %-20s' % (pred, gt))
accuracy = n_correct / float(max_iter * opt.batchSize)
print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
......@@ -181,9 +179,8 @@ def trainBatch(net, criterion, optimizer):
utils.loadData(text, t)
utils.loadData(length, l)
preds = crnn(image)
preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
cost = criterion(preds, text, preds_size, length) / batch_size
preds = crnn(image, length)
cost = criterion(preds, text)
crnn.zero_grad()
cost.backward()
optimizer.step()
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class BidirectionalLSTM(nn.Module):
......@@ -19,6 +22,62 @@ class BidirectionalLSTM(nn.Module):
return output
class AttentionCell(nn.Module):
    """One attention step: score the features, pool them, update a GRU state.

    The features are scored against the previous hidden state with an
    additive ``tanh`` scoring function, normalised over the time axis with a
    softmax, and the resulting weighted sum of features (the context) is fed
    to a ``GRUCell`` together with the previous hidden state.

    Args:
        input_size (int): channel size of the incoming features.
        hidden_size (int): size of the GRU hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.GRUCell(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.input_size = input_size

    def forward(self, prev_hidden, feats):
        """Run a single attention + GRU step.

        Args:
            prev_hidden: tensor of shape [nB, hidden_size].
            feats: tensor of shape [nT, nB, input_size].

        Returns:
            tuple: ``(cur_hidden, alpha)`` where ``cur_hidden`` has shape
            [nB, hidden_size] and ``alpha`` (the attention weights, summing
            to 1 over the time axis) has shape [nB, nT].
        """
        nT = feats.size(0)
        nB = feats.size(1)
        nC = feats.size(2)
        hidden_size = self.hidden_size

        # Project features and previous state into the same space; the state
        # projection is broadcast over the time axis when the two are added.
        feats_proj = self.i2h(feats.contiguous().view(-1, nC)).view(nT, nB, hidden_size)
        prev_hidden_proj = self.h2h(prev_hidden).view(1, nB, hidden_size)

        mixed = torch.tanh(feats_proj + prev_hidden_proj)
        emition = self.score(mixed.view(-1, hidden_size)).view(nT, nB).transpose(0, 1)

        # Normalise explicitly over the time axis (dim 1 of [nB, nT]).
        alpha = F.softmax(emition, dim=1)

        # Weighted sum over time -> [nB, nC]; no squeeze, so the batch axis
        # is kept even when nB == 1 and GRUCell always gets 2-D inputs.
        weights = alpha.transpose(0, 1).contiguous().view(nT, nB, 1)
        context = (feats * weights).sum(0)

        cur_hidden = self.rnn(context, prev_hidden)
        return cur_hidden, alpha
class Attention(nn.Module):
    """Attention decoder that unrolls an ``AttentionCell`` per sample.

    For each sample in the batch the cell is stepped as many times as that
    sample's declared text length, starting from a zero hidden state. Every
    hidden state produced is stored in one flat buffer (samples concatenated
    in batch order) and passed through a linear ``generator``.

    Args:
        input_size (int): channel size of the incoming features.
        hidden_size (int): size of the attention cell's hidden state.
        num_classes (int): output size of the generator layer.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)

    def forward(self, feats, text_length):
        """Decode every sample for its own number of steps.

        Args:
            feats: features of shape [nT, nB, input_size].
            text_length: per-sample step counts; must hold nB elements.

        Returns:
            Generator outputs of shape [sum(text_length), num_classes]
            (raw linear outputs; no softmax is applied here).
        """
        nT = feats.size(0)
        nB = feats.size(1)
        nC = feats.size(2)
        assert(self.input_size == nC)
        assert(nB == text_length.numel())

        total_steps = text_length.data.sum()
        output_hiddens = Variable(
            torch.zeros(total_steps, self.hidden_size).type_as(feats.data))

        row = 0
        for b in range(nB):
            # Slice out one sample, keeping a batch axis of size 1.
            sample_feats = feats[:, b, :].contiguous().view(nT, 1, nC)
            state = Variable(
                torch.zeros(1, self.hidden_size).type_as(feats.data))
            for _ in range(text_length.data[b]):
                # The attention weights are not needed here.
                state, _alpha = self.attention_cell(state, sample_feats)
                output_hiddens[row] = state.view(-1)
                row += 1

        return self.generator(output_hiddens)
class CRNN(nn.Module):
......@@ -71,9 +130,10 @@ class CRNN(nn.Module):
#self.cnn = cnn
self.rnn = nn.Sequential(
BidirectionalLSTM(512, nh, nh),
BidirectionalLSTM(nh, nh, nclass))
BidirectionalLSTM(nh, nh, nh))
self.attention = Attention(nh, nh/2, nclass)
def forward(self, input):
def forward(self, input, length):
# conv features
conv = self.cnn(input)
b, c, h, w = conv.size()
......@@ -82,6 +142,7 @@ class CRNN(nn.Module):
conv = conv.permute(2, 0, 1) # [w, b, c]
# rnn features
output = self.rnn(conv)
rnn = self.rnn(conv)
output = self.attention(rnn, length)
return output
......@@ -6,8 +6,78 @@ import torch.nn as nn
from torch.autograd import Variable
import collections
class strLabelConverterForAttention(object):
    """Convert between separator-delimited strings and label indices.

    NOTE:
        The end-of-sequence symbol must already be part of ``alphabet``;
        this class does not insert it.

    Args:
        alphabet (str): the possible symbols, joined by ``sep``.
        sep (str): separator used both in ``alphabet`` and in the texts
            passed to ``encode``.
    """

    def __init__(self, alphabet, sep):
        self.sep = sep
        self.alphabet = alphabet.split(sep)

        # Labels are the plain positions in the alphabet, starting at 0
        # (no index is reserved for a CTC-style blank).
        self.dict = {}
        for i, item in enumerate(self.alphabet):
            self.dict[item] = i

    def encode(self, text):
        """Encode a single text or a batch of texts.

        Args:
            text (str or iterable of str): texts to convert; symbols inside
                each text are separated by ``self.sep``.

        Returns:
            torch.LongTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.LongTensor [n]: length of each text.

        Raises:
            KeyError: when a symbol is not part of the alphabet.
        """
        if isinstance(text, str):
            labels = [self.dict[item] for item in text.split(self.sep)]
            length = [len(labels)]
        else:
            # Split each text once so that any iterable (including a
            # generator or an empty batch) is handled in a single pass.
            tokens = [s.split(self.sep) for s in text]
            length = [len(tok) for tok in tokens]
            labels = [self.dict[item] for tok in tokens for item in tok]
        return (torch.LongTensor(labels), torch.LongTensor(length))

    def decode(self, t, length):
        """Decode encoded texts back into strs.

        Args:
            t (tensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (tensor [n]): length of each text.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): a single str when ``length`` holds
            exactly one element, otherwise a list with one str per text.
        """
        if length.numel() == 1:
            n = int(length[0])
            assert t.numel() == n, "text with length: {} does not match declared length: {}".format(t.numel(), n)
            return ''.join([self.alphabet[int(i)] for i in t])

        # batch mode
        total = int(length.sum())
        assert t.numel() == total, "texts with length: {} does not match declared length: {}".format(t.numel(), total)
        texts = []
        index = 0
        for i in range(length.numel()):
            l = int(length[i])
            texts.append(
                self.decode(
                    t[index:index + l], torch.LongTensor([l])))
            index += l
        return texts
class strLabelConverter(object):
class strLabelConverterForCTC(object):
"""Convert between str and label.
NOTE:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册