提交 43d7d03c 编写于 作者: 片刻小哥哥's avatar 片刻小哥哥

更新部分文档

上级 aea3d604
......@@ -102,3 +102,58 @@ u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']
4. 打印输出 pairs
5. voc.trim, 过滤掉词频数据(有利于让训练更快收敛的策略是去除词汇表中很少使用的单词。减少特征空间也会降低模型学习目标函数的难度)
```
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']
```
### 为模型格式化数据
1. 加速训练,利用GPU并行计算能力,则需要使用小批量 `mini-batches`
2. 为了保证数据长短一致,设置 `(max_length,batch_size)`, 短于 max_length 的句子在 EOS_token 之后进行零填充 `(zero padded)`
3. 矩阵转置(以便跨第一维的索引返回批处理中所有句子的时间步长)
![](https://pytorch.apachecn.org/docs/1.0/img/b2f1969c698070d055c23fc81ab07b1b.jpg)
## 定义模型
Seq2seq模型的目标是将可变长度序列作为输入,并使用固定大小的模型将可变长度序列作为输出返回。
* Seq2Seq模型:
1. 编码器,其将可变长度输入序列编码为固定长度上下文向量。
2. 解码器,它接收输入文字和上下文矢量,并返回序列中下一句文字的概率和在下一次迭代中使用的隐藏状态。
![](https://pytorch.apachecn.org/docs/1.0/img/32a87cf8d0353ceb0037776f833b92a7.jpg)
* 编码器:
如果将填充的一批序列传递给RNN模块,我们必须分别使用torch.nn.utils.rnn.pack_padded_sequence和torch.nn.utils.rnn.pad_packed_sequence在RNN传递时分别进行填充和反填充。
```py
def forward(self, input_seq, input_lengths, hidden=None):
# Convert word indexes to embeddings
embedded = self.embedding(input_seq)
# Pack padded batch of sequences for RNN module
packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
# Forward pass through GRU
outputs, hidden = self.gru(packed, hidden)
# Unpack padding
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
# Sum bidirectional GRU outputs
outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
# Return output and final hidden state
return outputs, hidden
```
![](https://pytorch.apachecn.org/docs/1.0/img/c653271eb5fb762482bceb5e2464e680.jpg)
* 解码器:
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册