Commit 2f84edb4 authored by xiaohang

add crnn_lang

Parent 0385b322
@@ -13,7 +13,7 @@ import utils
import dataset
import time
-import models.crnn as crnn
+import models.crnn_lang as crnn
print(crnn.__name__)
parser = argparse.ArgumentParser()
models/crnn_lang.py (new file)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


# One decoding step: additive attention over the encoder time steps, followed by a
# GRU cell update on the attended context concatenated with the current character embedding.
class AttentionCell(nn.Module):

    def __init__(self, input_size, hidden_size, num_embeddings=128):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.GRUCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.num_embeddings = num_embeddings
        self.processed_batches = 0

    def forward(self, prev_hidden, feats, cur_embeddings):
        nT = feats.size(0)
        nB = feats.size(1)
        nC = feats.size(2)
        hidden_size = self.hidden_size
        input_size = self.input_size

        # additive attention scores: score_t = v^T tanh(W_i * feat_t + W_h * prev_hidden)
        feats_proj = self.i2h(feats.view(-1, nC))
        prev_hidden_proj = self.h2h(prev_hidden).view(1, nB, hidden_size).expand(nT, nB, hidden_size).contiguous().view(-1, hidden_size)
        emition = self.score(F.tanh(feats_proj + prev_hidden_proj).view(-1, hidden_size)).view(nT, nB).transpose(0, 1)

        self.processed_batches = self.processed_batches + 1
        if self.processed_batches % 10000 == 0:
            print('processed_batches = %d' % (self.processed_batches))

        alpha = F.softmax(emition)  # nB * nT
        if self.processed_batches % 10000 == 0:
            print('emition ', list(emition.data[0]))
            print('alpha ', list(alpha.data[0]))

        # attention-weighted sum of encoder features, then one GRU step
        context = (feats * alpha.transpose(0, 1).contiguous().view(nT, nB, 1).expand(nT, nB, nC)).sum(0).squeeze(0)  # nB * nC
        context = torch.cat([context, cur_embeddings], 1)
        cur_hidden = self.rnn(context, prev_hidden)

        return cur_hidden, alpha


# Attention decoder: unrolls AttentionCell over the target sequence with teacher forcing
# (the embedding of the previous ground-truth character is fed at each step).
class Attention(nn.Module):

    def __init__(self, input_size, hidden_size, num_classes, num_embeddings=128):
        super(Attention, self).__init__()
        self.attention_cell = AttentionCell(input_size, hidden_size, num_embeddings)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.generator = nn.Linear(hidden_size, num_classes)
        self.char_embeddings = Parameter(torch.randn(num_classes + 1, num_embeddings))
        self.num_embeddings = num_embeddings
        self.processed_batches = 0

    # targets is nT * nB
    def forward(self, feats, text_length, text):
        self.processed_batches = self.processed_batches + 1

        nT = feats.size(0)
        nB = feats.size(1)
        nC = feats.size(2)
        hidden_size = self.hidden_size
        input_size = self.input_size
        assert(input_size == nC)
        assert(nB == text_length.numel())

        num_steps = text_length.data.max()
        num_labels = text_length.data.sum()

        # build the (shifted) target matrix: index 0 is the start token, labels are offset by +1
        targets = torch.zeros(nB, num_steps + 1).long().cuda()
        start_id = 0
        for i in range(nB):
            targets[i][1:1 + text_length.data[i]] = text.data[start_id:start_id + text_length.data[i]] + 1
            start_id = start_id + text_length.data[i]
        targets = Variable(targets.transpose(0, 1).contiguous())

        output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
        hidden = Variable(torch.zeros(nB, hidden_size).type_as(feats.data))
        max_locs = torch.zeros(num_steps, nB)
        max_vals = torch.zeros(num_steps, nB)

        for i in range(num_steps):
            cur_embeddings = self.char_embeddings.index_select(0, targets[i])
            hidden, alpha = self.attention_cell(hidden, feats, cur_embeddings)
            output_hiddens[i] = hidden
            if self.processed_batches % 500 == 0:
                max_val, max_loc = alpha.data.max(1)
                max_locs[i] = max_loc.cpu()
                max_vals[i] = max_val.cpu()
        if self.processed_batches % 500 == 0:
            print('max_locs', list(max_locs[0:text_length.data[0], 0]))
            print('max_vals', list(max_vals[0:text_length.data[0], 0]))

        # pack each sample's per-step hidden states into one flat [sum(lengths), hidden] matrix
        new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
        b = 0
        start = 0
        for length in text_length.data:
            new_hiddens[start:start + length] = output_hiddens[0:length, b, :]
            start = start + length
            b = b + 1

        probs = self.generator(new_hiddens)
        return probs


# CNN encoder + two stacked BiLSTMs + attention decoder.
class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        self.cnn = nn.Sequential(
            nn.Conv2d(nc, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),  # 64x16x50
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),  # 128x8x25
            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),  # 256x8x25
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # 256x4x26
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),  # 512x4x26
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2, 2), (2, 1), (0, 1)),  # 512x2x27
            nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), nn.ReLU(True))  # 512x1x26
        # self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nh))
        self.attention = Attention(nh, nh, nclass, 256)

    def forward(self, input, length, text):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        rnn = self.rnn(conv)

        output = self.attention(rnn, length, text)
        return output
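
For orientation, here is a minimal smoke-test sketch of how this model is called; it is not part of the commit. It assumes a CUDA device (Attention.forward builds its targets tensor with .cuda()), imgH=32, nh=256, and the flat label encoding suggested by the forward(input, length, text) signature: one concatenated LongTensor of label indices plus a per-sample length tensor.

# Smoke-test sketch (not in this commit); shapes, nh=256 and the label encoding are assumptions.
import torch
from torch.autograd import Variable
import models.crnn_lang as crnn

nclass = 37  # e.g. 36 characters + 1 extra slot; depends on the alphabet actually used
model = crnn.CRNN(imgH=32, nc=1, nclass=nclass, nh=256).cuda()

images = Variable(torch.randn(4, 1, 32, 100).cuda())             # nB=4 grayscale crops, 32x100
text = Variable(torch.LongTensor([1, 2, 3, 4, 5, 6, 7]).cuda())  # concatenated label indices
length = Variable(torch.LongTensor([3, 2, 1, 1]))                # per-sample lengths, sum == text.numel()

preds = model(images, length, text)
print(preds.size())  # [sum(lengths), nclass], here [7, 37]

Unlike the CTC head in models.crnn, the output here has one row per ground-truth character, so it would presumably be scored with a cross-entropy style loss on the flattened labels rather than CTCLoss.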
nohup python main.py --trainlist data/train_list.txt --vallist data/val_list.txt --cuda --adam --lr=0.001 > log.txt &
#CUDA_VISIBLE_DEVICES=0 nohup python main.py --experiment expr_basic --trainlist data/train_list.txt --vallist data/val_list.txt --cuda --adam --lr=0.001 > log.txt &
CUDA_VISIBLE_DEVICES=1 nohup python main.py --lang --experiment expr_basic_lang --trainlist data/train_list.txt --vallist data/val_list.txt --cuda --adam --lr=0.001 > log_lang.txt &
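
The second launch line passes a --lang flag that is not visible in the hunk above; the commit itself simply swaps the import in the training script. Purely as a hypothetical sketch (the flag name is taken from the command line, everything else is assumed), the two models could be selected at runtime like this:

# Hypothetical wiring, not part of this commit: pick the decoder via --lang.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lang', action='store_true',
                    help='use the attention decoder from models.crnn_lang')
opt, _ = parser.parse_known_args()

if opt.lang:
    import models.crnn_lang as crnn  # attention decoder added by this commit
else:
    import models.crnn as crnn       # original CTC model
print(crnn.__name__)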