Commit 3df8fc31 authored by xiaohang

add -lang .module. max_locs alpha emition

Parent fd54339d
@@ -28,3 +28,18 @@ for line in lines:
     print >> test_fp, output
 test_fp.close()
+
+with open('data/mnt/ramdisk/max/90kDICT32px/annotation_test.txt') as fp:
+    lines = fp.readlines()
+
+val_fp = open('data/val_list.txt', 'w')
+for line in lines:
+    imgpath = line.strip().split(' ')[0]
+    label = imgpath.split('/')[-1].split('_')[1].lower()
+    label = label + '$'
+    label = ':'.join(label)
+    imgpath = 'data/mnt/ramdisk/max/90kDICT32px/%s' % imgpath
+    output = ' '.join([imgpath, label])
+    print >> val_fp, output
+val_fp.close()
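For reference, this is how the new loop transforms one annotation line into a val_list entry. The ground-truth word is pulled from the filename, a '$' is appended (apparently as an end-of-word marker), and the characters are colon-separated. The sample line below is an assumption based on the usual Synth90k naming convention (`<idx>_<WORD>_<lexicon-id>.jpg`), not taken from the dataset:

    # Hedged walk-through of the transformation above.
    line = './2697/6/466_MONIKER_49537.jpg 49537'
    imgpath = line.strip().split(' ')[0]                   # './2697/6/466_MONIKER_49537.jpg'
    label = imgpath.split('/')[-1].split('_')[1].lower()   # 'moniker'
    label = ':'.join(label + '$')                          # 'm:o:n:i:k:e:r:$'
    # Line written to data/val_list.txt:
    # data/mnt/ramdisk/max/90kDICT32px/./2697/6/466_MONIKER_49537.jpg m:o:n:i:k:e:r:$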
@@ -14,6 +14,7 @@ import dataset
 import time
 import models.crnn as crnn
+print(crnn.__name__)

 parser = argparse.ArgumentParser()
 parser.add_argument('--trainlist', required=True, help='path to train_list')
@@ -39,6 +40,7 @@ parser.add_argument('--saveInterval', type=int, default=10000, help='Interval to
 parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
 parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
 parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
+parser.add_argument('--lang', action='store_true', help='whether to use char language model')
 parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
 opt = parser.parse_args()
 print(opt)
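Since the new flag is a plain store_true switch, enabling the character language model is just a matter of passing it on the command line. A hypothetical invocation (the script name and list path are assumptions, not shown in this diff):

    python crnn_main.py --trainlist data/train_list.txt --lang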
@@ -150,7 +152,10 @@ def val(net, dataset, criterion, max_iter=100):
         utils.loadData(text, t)
         utils.loadData(length, l)
-        preds = crnn(image, length)
+        if opt.lang:
+            preds = crnn(image, length, text)
+        else:
+            preds = crnn(image, length)
         cost = criterion(preds, text)
         loss_avg.add(cost)
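In both val and trainBatch (below), the --lang branch passes the target text into the forward call, presumably so the attention decoder can condition on the gold characters (teacher forcing). A minimal sketch of what such a dual-mode signature looks like; this illustrates the calling convention only and is not the repository's actual crnn module:

    import torch
    import torch.nn as nn

    class DualModeNet(nn.Module):
        # Illustrative stub: `text` is optional, mirroring the two call sites.
        def __init__(self):
            super(DualModeNet, self).__init__()
            self.proj = nn.Linear(4, 4)

        def forward(self, image, length, text=None):
            feats = self.proj(image)
            if text is None:
                return feats                 # plain path: crnn(image, length)
            return feats + text.float()      # --lang path: crnn(image, length, text)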
@@ -179,7 +184,10 @@ def trainBatch(net, criterion, optimizer):
     utils.loadData(text, t)
     utils.loadData(length, l)
-    preds = crnn(image, length)
+    if opt.lang:
+        preds = crnn(image, length, text)
+    else:
+        preds = crnn(image, length)
     cost = criterion(preds, text)
     crnn.zero_grad()
     cost.backward()
@@ -214,4 +222,4 @@ for epoch in range(opt.niter):
         # do checkpointing
         if i % opt.saveInterval == 0:
             torch.save(
-                crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
+                crnn.module.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
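The switch to crnn.module.state_dict() only makes sense if crnn is wrapped in nn.DataParallel earlier in the script (an assumption this diff does not show). The wrapper registers the real network under the attribute `module`, so its state_dict keys all carry a 'module.' prefix; saving the inner module keeps checkpoints loadable into an unwrapped model:

    import torch.nn as nn

    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)
    print(list(wrapped.state_dict().keys()))         # ['module.weight', 'module.bias']
    print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']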
@@ -31,8 +31,10 @@ class AttentionCell(nn.Module):
         self.rnn = nn.GRUCell(input_size, hidden_size)
         self.hidden_size = hidden_size
         self.input_size = input_size
+        self.processed_batches = 0

     def forward(self, prev_hidden, feats):
+        self.processed_batches = self.processed_batches + 1
         nT = feats.size(0)
         nB = feats.size(1)
         nC = feats.size(2)
@@ -43,6 +45,11 @@ class AttentionCell(nn.Module):
         prev_hidden_proj = self.h2h(prev_hidden).view(1,nB, hidden_size).expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size)
         emition = self.score(F.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT,nB).transpose(0,1)
         alpha = F.softmax(emition) # nB * nT
+
+        if self.processed_batches % 10000 == 0:
+            print('emition ', list(emition.data[0]))
+            print('alpha ', list(alpha.data[0]))
+
         context = (feats * alpha.transpose(0,1).contiguous().view(nT,nB,1).expand(nT, nB, nC)).sum(0).squeeze(0)
         cur_hidden = self.rnn(context, prev_hidden)
         return cur_hidden, alpha
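The new debug print samples the attention scores for the first batch element. As a reminder of the shapes involved: emition is nB x nT (one additive-attention score per encoder timestep), and alpha is its row-wise softmax, so each row sums to 1. A standalone sketch, written with an explicit dim since a bare F.softmax(input) as in the diff is deprecated in later PyTorch versions:

    import torch
    import torch.nn.functional as F

    nT, nB = 5, 2
    emition = torch.randn(nB, nT)       # one score per (sample, timestep)
    alpha = F.softmax(emition, dim=1)   # normalize over the nT timesteps
    print(alpha.sum(1))                 # each row sums to 1.0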
@@ -54,8 +61,10 @@ class Attention(nn.Module):
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.generator = nn.Linear(hidden_size, num_classes)
+        self.processed_batches = 0

     def forward(self, feats, text_length):
+        self.processed_batches = self.processed_batches + 1
         nT = feats.size(0)
         nB = feats.size(1)
         nC = feats.size(2)
@@ -69,9 +78,18 @@ class Attention(nn.Module):
         output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
         hidden = Variable(torch.zeros(nB,hidden_size).type_as(feats.data))
+        max_locs = torch.zeros(num_steps, nB)
+        max_vals = torch.zeros(num_steps, nB)
         for i in range(num_steps):
             hidden, alpha = self.attention_cell(hidden, feats)
             output_hiddens[i] = hidden
+            if self.processed_batches % 500 == 0:
+                max_val, max_loc = alpha.data.max(1)
+                max_locs[i] = max_loc.cpu()
+                max_vals[i] = max_val.cpu()
+        if self.processed_batches % 500 == 0:
+            print('max_locs', list(max_locs[0:text_length.data[0],0]))
+            print('max_vals', list(max_vals[0:text_length.data[0],0]))
         new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
         b = 0
         start = 0
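The max_locs/max_vals printout is a cheap alignment diagnostic: for each decoding step, alpha.data.max(1) returns the strongest-attended encoder timestep and its weight for every sample, and for left-to-right text max_locs should grow roughly monotonically over the first text_length steps. A small sketch of the max call (shapes shown follow current PyTorch; some very old versions returned the values with a kept dim):

    import torch

    alpha = torch.tensor([[0.1, 0.7, 0.2],
                          [0.6, 0.3, 0.1]])   # nB=2 samples over nT=3 timesteps
    max_val, max_loc = alpha.max(1)
    print(max_loc)   # tensor([1, 0]): sample 0 attends timestep 1, sample 1 timestep 0
    print(max_val)   # tensor([0.7000, 0.6000])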
@@ -88,40 +106,6 @@ class CRNN(nn.Module):
         super(CRNN, self).__init__()
         assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
-        ks = [3, 3, 3, 3, 3, 3, 2]
-        ps = [1, 1, 1, 1, 1, 1, 0]
-        ss = [1, 1, 1, 1, 1, 1, 1]
-        nm = [64, 128, 256, 256, 512, 512, 512]
-
-        cnn = nn.Sequential()
-
-        def convRelu(i, batchNormalization=False):
-            nIn = nc if i == 0 else nm[i - 1]
-            nOut = nm[i]
-            cnn.add_module('conv{0}'.format(i),
-                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
-            if batchNormalization:
-                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
-            if leakyRelu:
-                cnn.add_module('relu{0}'.format(i),
-                               nn.LeakyReLU(0.2, inplace=True))
-            else:
-                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
-
-        convRelu(0)
-        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
-        convRelu(1)
-        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
-        convRelu(2, True)
-        convRelu(3)
-        cnn.add_module('pooling{0}'.format(2),
-                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
-        convRelu(4, True)
-        convRelu(5)
-        cnn.add_module('pooling{0}'.format(3),
-                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
-        convRelu(6, True)  # 512x1x16
+        self.cnn = nn.Sequential(
+            nn.Conv2d(nc, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),  # 64x16x50
+            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2), # 128x8x25
......
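The rewritten self.cnn spells out explicitly what the deleted convRelu helper used to build (the rest of the Sequential is collapsed in this diff view). The new trailing shape comments imply 1-channel 32x100 inputs, which is an assumption here since nc and the image size are set elsewhere. A quick check of the two blocks visible above:

    import torch
    import torch.nn as nn

    cnn_prefix = nn.Sequential(
        nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),    # 64x16x50
        nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),  # 128x8x25
    )
    x = torch.zeros(1, 1, 32, 100)
    print(cnn_prefix(x).shape)   # torch.Size([1, 128, 8, 25])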