Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDocCN
d2l-zh
提交
8c8cfa93
D
d2l-zh
项目概览
OpenDocCN
/
d2l-zh
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
d2l-zh
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
8c8cfa93
编写于
2月 09, 2018
作者:
A
Aston Zhang
提交者:
GitHub
2月 09, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #206 from astonzhang/seq
add beam search
上级
681c30f4
db733518
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
31 addition
and
2 deletion
+31
-2
chapter_natural-language-processing/nmt.md
chapter_natural-language-processing/nmt.md
+31
-2
未找到文件。
chapter_natural-language-processing/nmt.md
浏览文件 @
8c8cfa93
...
...
@@ -263,7 +263,7 @@ def translate(encoder, decoder, decoder_init_state, fr_ens, ctx, max_seq_len):
print('Expect:', fr_en[1], '\n')
```
下面定义模型训练函数。为了初始化解码器的隐含状态,我们通过一层全连接网络来转化编码器最早时刻的输出隐含状态。
下面定义模型训练函数。为了初始化解码器的隐含状态,我们通过一层全连接网络来转化编码器最早时刻的输出隐含状态。
这里的解码器使用当前时刻的预测结果作为下一时刻的输入。
```
{.python .input}
def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
...
...
@@ -302,7 +302,7 @@ def train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens):
for i in range(max_seq_len):
decoder_output, decoder_state = decoder(
decoder_input, decoder_state, encoder_outputs)
#
使用当前时刻的预测结果作为下一时刻的编码器
输入。
#
解码器使用当前时刻的预测结果作为下一时刻的
输入。
decoder_input = nd.array(
[decoder_output.argmax(axis=1).asscalar()], ctx=ctx)
loss = loss + softmax_cross_entropy(decoder_output, y[0][i])
...
...
@@ -349,6 +349,32 @@ eval_fr_ens =[['elle est japonaise .', 'she is japanese .'],
train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens)
```
## 束搜索
在上一节里,我们提到编码器最终输出了一个背景向量$
\m
athbf{c}$,该背景向量编码了输入序列$x_1, x_2,
\l
dots, x_T$的信息。假设训练数据中的输出序列是$y_1, y_2,
\l
dots, y_{T^
\p
rime}$,输出序列的生成概率是
$$
\m
athbb{P}(y_1,
\l
dots, y_{T^
\p
rime}) =
\p
rod_{t^
\p
rime=1}^{T^
\p
rime}
\m
athbb{P}(y_{t^
\p
rime}
\m
id y_1,
\l
dots, y_{t^
\p
rime-1},
\m
athbf{c})$$
对于机器翻译的输出来说,如果输出语言的词汇集合$
\m
athcal{Y}$的大小为$|
\m
athcal{Y}|$,输出序列的长度为$T^
\p
rime$,那么可能的输出序列种类是$
\m
athcal{O}(|
\m
athcal{Y}|^{T^
\p
rime})$。为了找到生成概率最大的输出序列,一种方法是计算所有$
\m
athcal{O}(|
\m
athcal{Y}|^{T^
\p
rime})$种可能序列的生成概率,并输出概率最大的序列。我们将该序列称为最优序列。但是这种方法的计算开销过高(例如,$10000^{10} = 1
\t
imes 10^{40}$)。
我们目前所介绍的解码器在每个时刻只输出生成概率最大的一个词汇。对于任一时刻$t^
\p
rime$,我们从$|
\m
athcal{Y}|$个词中搜索出输出词
$$y_{t^
\p
rime} =
\t
ext{argmax}_{y_{t^
\p
rime}
\i
n
\m
athcal{Y}}
\m
athbb{P}(y_{t^
\p
rime}
\m
id y_1,
\l
dots, y_{t^
\p
rime-1},
\m
athbf{c})$$
因此,搜索计算开销($
\m
athcal{O}(|
\m
athcal{Y}|
\t
imes {T^
\p
rime})$)显著下降(例如,$10000
\t
imes 10 = 1
\t
imes 10^5$),但这并不能保证一定搜索到最优序列。
束搜索(beam search)介于上面二者之间。我们来看一个例子。
假设输出序列的词典中只包含五个词:$
\m
athcal{Y} =
\{
A, B, C, D, E
\}
$。束搜索的一个超参数叫做束宽(beam width)。以束宽等于2为例,假设输出序列长度为3,假如时刻1生成概率$
\m
athbb{P}(y_{t^
\p
rime}
\m
id
\m
athbf{c})$最大的两个词为$A$和$C$,我们在时刻2对于所有的$y_2
\i
n
\m
athcal{Y}$都分别计算$
\m
athbb{P}(y_2
\m
id A,
\m
athbf{c})$和$
\m
athbb{P}(y_2
\m
id C,
\m
athbf{c})$,从计算出的10个概率中取最大的两个,假设为$
\m
athbb{P}(B
\m
id A,
\m
athbf{c})$和$
\m
athbb{P}(E
\m
id C,
\m
athbf{c})$。那么,我们在时刻3对于所有的$y_3
\i
n
\m
athcal{Y}$都分别计算$
\m
athbb{P}(y_3
\m
id A, B,
\m
athbf{c})$和$
\m
athbb{P}(y_3
\m
id C, E,
\m
athbf{c})$,从计算出的10个概率中取最大的两个,假设为$
\m
athbb{P}(D
\m
id A, B,
\m
athbf{c})$和$
\m
athbb{P}(D
\m
id C, E,
\m
athbf{c})$。
接下来,我们可以在输出序列:$A$、$C$、$AB$、$CE$、$ABD$、$CED$中筛选出以特殊字符EOS结尾的候选序列。再在候选序列中取以下分数最高的序列作为最终候选序列:
$$
\f
rac{1}{L^
\a
lpha}
\l
og
\m
athbb{P}(y_1,
\l
dots, y_{L})$$
其中$L$为候选序列长度,$
\a
lpha$一般可选为0.75。分母上的$L^
\a
lpha$是为了惩罚较长序列的分数中的相加项。
## 结论
*
我们可以将编码器—解码器和注意力机制应用于神经机器翻译中。
...
...
@@ -357,6 +383,9 @@ train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens)
## 练习
*
试着使用更大的翻译数据集来训练模型,例如
[
WMT
](
http://www.statmt.org/wmt14/translation-task.html
)
和
[
Tatoeba Project
](
http://www.manythings.org/anki/
)
。调一调不同参数并观察实验结果。
*
Teacher forcing:在模型训练中,试着让解码器使用当前时刻的正确结果(而不是预测结果)作为下一时刻的输入。结果会怎么样?
**吐槽和讨论欢迎点**
[
这里
](
https://discuss.gluon.ai/t/topic/4689
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录