Commit 748f3330 authored by: J jrzaurin

removing attention layer. Will be implemented in future releases

Parent a658304e
@@ -10,15 +10,12 @@ from ..wdtypes import *
 class DeepText(nn.Module):
     def __init__(self, vocab_size:int, embedding_dim:int, hidden_dim:int, n_layers:int,
                  rnn_dropout:float, spatial_dropout:float, padding_idx:int, output_dim:int,
-                 attention:bool=False, bidirectional:bool=False,
-                 embedding_matrix:Optional[np.ndarray]=None):
+                 bidirectional:bool=False, embedding_matrix:Optional[np.ndarray]=None):
         super(DeepText, self).__init__()
         """
         Standard Text Classifier/Regressor with a stack of RNNs.
         """
         self.bidirectional = bidirectional
-        self.attention = attention
         self.spatial_dropout = spatial_dropout
         self.embedding_dropout = nn.Dropout2d(spatial_dropout)
         self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = padding_idx)
@@ -33,20 +30,6 @@ class DeepText(nn.Module):
         input_dim = hidden_dim*2 if bidirectional else hidden_dim
         self.dtlinear = nn.Linear(input_dim, output_dim)
 
-    def attention_net(self, output:Tensor, hidden:Tensor)->Tensor:
-        """
-        Attention through Soft alignment Score between output and last hidden.
-        Read here (and references therein) for more details:
-        https://machinelearningmastery.com/how-does-attention-work-in-encoder-decoder-recurrent-neural-networks/
-        code from here (there are more sophisticated approaches but these will do):
-        https://github.com/prakashpandey9/Text-Classification-Pytorch/blob/master/models/LSTM_Attn.py
-        """
-        attn_weights = torch.bmm(output, hidden.unsqueeze(2)).squeeze(2)
-        attn_weights = F.softmax(attn_weights, 1)
-        new_hidden = torch.bmm(output.transpose(1, 2), attn_weights.unsqueeze(2)).squeeze(2)
-        return new_hidden
-
     def forward(self, X:Tensor)->Tensor:
         embedded = self.embedding(X)
@@ -62,7 +45,5 @@ class DeepText(nn.Module):
             last_h = torch.cat((h[-2], h[-1]), dim = 1)
         else:
             last_h = h[-1]
-        if self.attention:
-            last_h = self.attention_net(o, last_h)
         out = self.dtlinear(last_h)
         return out
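
For reference, the logic being removed is a plain soft-alignment attention: each RNN output step is scored against the last hidden state, the scores are softmax-normalized over the sequence, and the outputs are summed with those weights. The sketch below reproduces that computation as a standalone function; the function name and the toy tensor sizes are illustrative only and are not part of the library.

import torch
import torch.nn.functional as F

def soft_alignment_attention(output: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
    """Re-weight RNN outputs by their similarity to the last hidden state.

    output: (batch, seq_len, hidden_dim)  -- full RNN output sequence
    hidden: (batch, hidden_dim)           -- last hidden state
    returns: (batch, hidden_dim)          -- attention-weighted context vector
    """
    # Alignment scores: dot product of every time step with the last hidden state
    attn_weights = torch.bmm(output, hidden.unsqueeze(2)).squeeze(2)  # (batch, seq_len)
    attn_weights = F.softmax(attn_weights, dim=1)
    # Weighted sum of the outputs over the sequence dimension
    return torch.bmm(output.transpose(1, 2), attn_weights.unsqueeze(2)).squeeze(2)

# Toy shapes, purely for illustration
batch, seq_len, hidden_dim = 4, 10, 64
output = torch.randn(batch, seq_len, hidden_dim)
hidden = torch.randn(batch, hidden_dim)
print(soft_alignment_attention(output, hidden).shape)  # torch.Size([4, 64])

The result has the same shape as the last hidden state, which is why the removed branch could substitute it directly for last_h before the final linear layer.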