Commit 68f97869 authored by xiaohang

The optimized Attention reduces the forward time from 900 to 65 by replacing the per-sample decoding loops with a single batched loop over time steps.

Parent f31b54e2
@@ -35,7 +35,6 @@ class AttentionCell(nn.Module):
     def forward(self, prev_hidden, feats):
         nT = feats.size(0)
         nB = feats.size(1)
-        assert(nB == 1)
         nC = feats.size(2)
         hidden_size = self.hidden_size
         input_size = self.input_size
@@ -65,18 +64,22 @@ class Attention(nn.Module):
         assert(input_size == nC)
         assert(nB == text_length.numel())
         num_steps = text_length.data.max()
         num_labels = text_length.data.sum()
-        output_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
-        k = 0
-        for j in range(nB):
-            sub_feats = feats[:,j,:].contiguous().view(nT,1,nC) #feats.index_select(1, Variable(torch.LongTensor([j]).type_as(feats.data)))
-            sub_hidden = Variable(torch.zeros(1,hidden_size).type_as(feats.data))
-            for i in range(text_length.data[j]):
-                sub_hidden, sub_alpha = self.attention_cell(sub_hidden, sub_feats)
-                output_hiddens[k] = sub_hidden.view(-1)
-                k = k + 1
-        probs = self.generator(output_hiddens)
+        output_hiddens = Variable(torch.zeros(num_steps, nB, hidden_size).type_as(feats.data))
+        hidden = Variable(torch.zeros(nB,hidden_size).type_as(feats.data))
+        for i in range(num_steps):
+            hidden, alpha = self.attention_cell(hidden, feats)
+            output_hiddens[i] = hidden
+        new_hiddens = Variable(torch.zeros(num_labels, hidden_size).type_as(feats.data))
+        b = 0
+        start = 0
+        for length in text_length.data:
+            new_hiddens[start:start+length] = output_hiddens[0:length,b,:]
+            start = start + length
+            b = b + 1
+        probs = self.generator(new_hiddens)
         return probs
 class CRNN(nn.Module):
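The change above replaces the per-sample decoding loop (one attention_cell call per output character of each sample) with a single batched loop that advances every sample's hidden state together, then copies each sample's valid prefix into the flat tensor that the generator expects. Below is a minimal, self-contained sketch of that pattern in current PyTorch (where Variable is no longer needed); ToyCell is a hypothetical stand-in for this repo's AttentionCell that mean-pools the features instead of attending over them, so only the batching structure matches the diff.

import torch
import torch.nn as nn

class ToyCell(nn.Module):
    # Hypothetical stand-in for AttentionCell: one decode step for a whole batch.
    def __init__(self, nC, hidden_size):
        super().__init__()
        self.rnn = nn.GRUCell(nC, hidden_size)

    def forward(self, prev_hidden, feats):
        # feats: (nT, nB, nC). A real attention cell would weight the nT
        # time steps by learned scores; mean-pooling keeps the sketch short.
        context = feats.mean(dim=0)            # (nB, nC)
        return self.rnn(context, prev_hidden)  # (nB, hidden_size)

nT, nB, nC, hidden_size = 26, 4, 32, 64
feats = torch.randn(nT, nB, nC)
text_length = torch.tensor([5, 3, 7, 2])       # label length per sample

cell = ToyCell(nC, hidden_size)
num_steps = int(text_length.max())             # longest label in the batch
num_labels = int(text_length.sum())            # total characters to emit

# Batched loop: num_steps cell calls for the whole batch, instead of
# sum(text_length) calls in the old per-sample version.
output_hiddens = torch.zeros(num_steps, nB, hidden_size)
hidden = torch.zeros(nB, hidden_size)
for i in range(num_steps):
    hidden = cell(hidden, feats)
    output_hiddens[i] = hidden

# Gather each sample's valid prefix into a flat (num_labels, hidden_size)
# tensor, mirroring the new_hiddens loop in the diff above.
new_hiddens = torch.zeros(num_labels, hidden_size)
start = 0
for b, length in enumerate(text_length.tolist()):
    new_hiddens[start:start + length] = output_hiddens[:length, b]
    start += length

print(new_hiddens.shape)  # torch.Size([17, 64])

The old version processed a batch of one on every cell call; the batched version issues num_steps larger matrix multiplies instead, which is the kind of restructuring consistent with the reported 900 → 65 timing.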