提交 fd2fa2bc 编写于 作者: W wilfChen

GPU Lstm network

上级 cf4d317d
...@@ -88,6 +88,6 @@ class SentimentNet(nn.Cell): ...@@ -88,6 +88,6 @@ class SentimentNet(nn.Cell):
embeddings = self.trans(embeddings, self.perm) embeddings = self.trans(embeddings, self.perm)
output, _ = self.encoder(embeddings, (self.h, self.c)) output, _ = self.encoder(embeddings, (self.h, self.c))
# states[i] size(64,200) -> encoding.size(64,400) # states[i] size(64,200) -> encoding.size(64,400)
encoding = self.concat((output[0], output[-1])) encoding = self.concat((output[0], output[499]))
outputs = self.decoder(encoding) outputs = self.decoder(encoding)
return outputs return outputs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册