提交 ac22c225 编写于 作者: A Aston Zhang

revise beam search, mt

上级 e0e6b96b
......@@ -2,15 +2,15 @@
上一节介绍了如何训练输入输出均为不定长序列的编码器—解码器,这一节我们介绍如何使用编码器—解码器来预测不定长的序列。
上一节里已经提到,在准备训练数据集时,我们通常会在样本的输入序列和输出序列后面分别附上一个特殊符号“<eos>”表示序列的终止。我们在接下来的讨论中也将沿用上一节的数学符号。为了便于讨论,假设解码器的输出是一段文本序列。设输出文本词典$\mathcal{Y}$(包含特殊符号“<eos>”)的大小为$|\mathcal{Y}|$,输出序列的最大长度为$T'$。所有可能的输出序列一共有$\mathcal{O}(|\mathcal{Y}|^{T'})$种。这些输出序列中所有特殊符号“<eos>”及其后面的子序列将被舍弃。
上一节里已经提到,在准备训练数据集时,我们通常会在样本的输入序列和输出序列后面分别附上一个特殊符号“<eos>”表示序列的终止。我们在接下来的讨论中也将沿用上一节的数学符号。为了便于讨论,假设解码器的输出是一段文本序列。设输出文本词典$\mathcal{Y}$(包含特殊符号“<eos>”)的大小为$\left|\mathcal{Y}\right|$,输出序列的最大长度为$T'$。所有可能的输出序列一共有$\mathcal{O}(\left|\mathcal{Y}\right|^{T'})$种。这些输出序列中所有特殊符号“<eos>后面的子序列将被舍弃。
## 穷举搜索
我们在描述解码器时提到,输出序列基于输入序列的条件概率是$\prod_{t'=1}^{T'} \mathbb{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c})$。为了搜索该概率最大的输出序列,一种方法是穷举所有可能序列的概率,并输出概率最大的序列。我们将该序列称为最优序列,并将这种搜索方法称为穷举搜索(exhaustive search)。
我们在上一节描述解码器时提到,输出序列基于输入序列的条件概率是$\prod_{t'=1}^{T'} \mathbb{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c})$。为了搜索该条件概率最大的输出序列,一种方法是穷举所有可能输出序列的条件概率,并输出条件概率最大的序列。我们将该序列称为最优序列,并将这种搜索方法称为穷举搜索(exhaustive search)。
虽然穷举搜索可以得到最优的预测序列,但它的计算开销$\mathcal{O}(|\mathcal{Y}|^{T'})$很容易过大。例如,当$|\mathcal{Y}|=10000$且$T'=10$时,我们将评估$10000^{10} = 10^{40}$个序列:这几乎不可能完成。
虽然穷举搜索可以得到最优序列,但它的计算开销$\mathcal{O}(\left|\mathcal{Y}\right|^{T'})$很容易过大。例如,当$|\mathcal{Y}|=10000$且$T'=10$时,我们将评估$10000^{10} = 10^{40}$个序列:这几乎不可能完成。
......@@ -20,37 +20,37 @@
$$y_{t'} = \text{argmax}_{y_{t'} \in \mathcal{Y}} \mathbb{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c}),$$
且一旦搜索出“<eos>”符号即完成输出。
且一旦搜索出“<eos>”符号即完成输出序列。贪婪搜索的计算开销是$\mathcal{O}(\left|\mathcal{Y}\right|T')$。它比起穷举搜索的计算开销显著下降。例如,当$|\mathcal{Y}|=10000$且$T'=10$时,我们只需评估$10000\times10=1\times10^5$个序列
下面我们来看一个例子。假设输出词典里面有“A”,“B”,“C”和“<eos>”这四个词。图10.3中每个时间步下的四个数字分别代表了该时间步生成“A”,“B”,“C”和“<eos>”这四个词的条件概率。在每个时间步,贪婪搜索选取生成的条件概率最高的词。因此,图10.3中将生成序列“A”、“B”和“C”(舍弃特殊符号“<eos>”)。该输出序列的条件概率是$0.5\times0.4\times0.4 = 0.08$。
下面我们来看一个例子。假设输出词典里面有“A”、“B”、“C”和“<eos>”这四个词。图10.3中每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在每个时间步,贪婪搜索选取生成条件概率最大的词。因此,图10.3中将生成序列“ABC<eos>”。该输出序列的条件概率是$0.5\times0.4\times0.4\times0.6 = 0.048$。
![每个时间步下的四个数字分别代表了该时间步生成“A”,“B”,“C”和“<eos>”这四个词的条件概率。在每个时间步,贪婪搜索选取生成的条件概率最高的词。](../img/s2s_prob1.svg)
![每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在每个时间步,贪婪搜索选取生成条件概率最大的词。](../img/s2s_prob1.svg)
正如绝大部分贪婪算法不能保证最优解一样,贪婪搜索也无法保证找出条件概率最高的输出序列。图10.4演示了这样的一个例子。与图10.3中不同,图10.4在时间步2中选取了条件概率第二大的“C”。由于时间步3所基于的时间步1和2的输出子序列由图10.3的“A”和“B”变为了图10.4的“A”和“C”,图10.4中时间步3生成各个词的条件概率发生了变化。我们选取条件概率最大的“B”。此时时间步4所基于的前三个时间步的输出子序列为“A”、“C”和“B”。图10.4中时间步4生成各个词的条件概率也与图10.3中的不同。我们发现,此时的输出序列“A”、“C”和“B”的条件概率是$0.5\times0.3\times0.6=0.09$,高于贪婪搜索得到的输出序列的条件概率。
![每个时间步下的四个数字分别代表了该时间步生成“A”,“B”,“C”和“<eos>”这四个词的条件概率。在时间步2选取条件概率第二大的“C”。](../img/s2s_prob2.svg)
正如绝大部分贪婪算法不能保证最优解一样,贪婪搜索也无法保证找出条件概率最大的最优序列。图10.4演示了这样的一个例子。与图10.3中不同,图10.4在时间步2中选取了条件概率第二大的“C”。由于时间步3所基于的时间步1和2的输出子序列由图10.3中的“AB”变为了图10.4中的“AC”,图10.4中时间步3生成各个词的条件概率发生了变化。我们选取条件概率最大的“B”。此时时间步4所基于的前三个时间步的输出子序列为“ACB”,与图10.3中的“ABC”不同。因此图10.4中时间步4生成各个词的条件概率也与图10.3中的不同。我们发现,此时的输出序列“ACB<eos>”的条件概率是$0.5\times0.3\times0.6\times0.6=0.054$,大于贪婪搜索得到的输出序列的条件概率。因此,贪婪搜索得到的输出序列“ABC<eos>”并非最优序列。
![每个时间步下的四个数字分别代表了该时间步生成“A”、“B”、“C”和“<eos>”这四个词的条件概率。在时间步2选取条件概率第二大的“C”。](../img/s2s_prob2.svg)
## 束搜索
让我们再来回顾下贪婪搜索和穷举搜索。它们可以概况成如下的算法。假设在每个时间步$t'$我们现在有$n$条长度为$t'$的候选输出序列(包括了开始符)。例如当$t'=0$时就是一条只含有开始符的序列。然后对输出字典里每个候选词计算条件概率,这样我们可以得到$n|\mathcal{O}|$条长度为$t'+1$的候选输出序列和它们的条件概率,然后我们在其中筛选出进入下一个时间步的候选序列。
束搜索(beam search)是比贪婪搜索更加广义的搜索算法。它有一个束宽(beam size)超参数。我们将它设为$k$。在时间步1时,选取当前时间步生成条件概率最大的$k$个词,分别组成$k$个候选输出序列的首词。在之后的每个时间步,基于上个时间步的$k$个候选输出序列,从$k\left|\mathcal{Y}\right|$个可能的输出序列中选取生成条件概率最大的$k$个,作为该时间步的候选输出序列。
最终,我们在各个时间步的候选输出序列中筛选出包含特殊符号“<eos>”的序列,并将它们中所有特殊符号“<eos>”后面的子序列舍弃,得到最终候选输出序列。在这些最终候选输出序列中,取以下分数最高的序列作为输出序列:
前面两个搜索算法不同的地方就在于筛选这一步。贪婪搜索只保留概率最高的序列进入下一时间步,而穷举搜索则保留所有。前者计算简单,其计算复杂度为$\mathcal{O}(|T'|)$,但难以保证输出质量。后者保证最优输出,但计算复杂度高达$\mathcal{O}(|\mathcal{Y}|^{T'})$,其实际中几乎不可能完成。
$$ \frac{1}{L^\alpha} \log \mathbb{P}(y_1, \ldots, y_{L}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log \mathbb{P}(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c}),$$
束搜索(beam search)介于两者之间,在筛选时它只保留条件概率最高的$k$条序列,这里$k$叫做束宽(beam size),是一个超参数。当$k=1$时其等价于贪婪搜索,而$k=\infty$时其等价于穷举搜索。束搜索的算法复杂度为$\mathcal{O}(kT'|\mathcal{Y}|)$。实际使用中,我们通过$k$来权衡输出序列质量和计算复杂度。下图演示了束搜索的工作原理
其中$L$为最终候选序列长度,$\alpha$一般可选为0.75。分母上的$L^\alpha$是为了惩罚较长序列在以上分数中较多的对数相加项。分析可得,束搜索的计算开销为$\mathcal{O}(k\left|\mathcal{Y}\right|T')$。这介于穷举搜索和贪婪搜索的计算开销之间
![束宽为2的束搜索,每个时间步选取条件概率最高的两个序列作为候选进入下一时间。](../img/beam_search.svg)
束搜索的停止条件有多种,例如找到一条有终止符的序列,或者找到一条条件概率高于某个阈值的序列,或者到达了最大输出长度。停止时束搜索输出最多$k$个候选输出序列。如果某个序列中含有终止符,那么去掉终止符后面的子序列。通常我们保留最佳的一个或数个序列作为最终输出。
![束搜索的过程。束宽为2,输出序列最大长度为3。候选输出序列有$A$、$C$、$AB$、$CE$、$ABD$和$CED$。](../img/beam_search.svg)
由于输出序列可能长度不一样,较短的序列通常条件概率比较大。在比较时经常将长度信息考虑在内。例如对于长为$L$的序列$y_1,\ldots,y_L$,我们将其对数条件概率除以$L^\alpha$作为分数,然后选取分数高的作为最终输出。这里分数的计算为
$$ \frac{1}{L^\alpha} \log \mathbb{P}(y_1, \ldots, y_{L}\mid \mathrm{context}) = \frac{1}{L^\alpha} \sum_{t^\prime=1}^L \log \mathbb{P}(y_{t^\prime} \mid y_1, \ldots, y_{t^\prime-1}, \mathrm{context}),$$
图10.5通过一个例子演示了束搜索的过程。假设输出序列的词典中只包含五个元素:$\mathcal{Y} = \{A, B, C, D, E\}$,且其中一个为特殊符号“<eos>”。设束搜索的束宽等于2,输出序列最大长度为3。在输出序列的时间步1时,假设条件概率$\mathbb{P}(y_1 \mid \boldsymbol{c})$最大的两个词为$A$和$C$。我们在时间步2时将对所有的$y_2 \in \mathcal{Y}$都分别计算$\mathbb{P}(y_2 \mid A, \boldsymbol{c})$和$\mathbb{P}(y_2 \mid C, \boldsymbol{c})$,并从计算出的10个条件概率中取最大的两个:假设为$\mathbb{P}(B \mid A, \boldsymbol{c})$和$\mathbb{P}(E \mid C, \boldsymbol{c})$。那么,我们在时间步3时将对所有的$y_3 \in \mathcal{Y}$都分别计算$\mathbb{P}(y_3 \mid A, B, \boldsymbol{c})$和$\mathbb{P}(y_3 \mid C, E, \boldsymbol{c})$,并从计算出的10个条件概率中取最大的两个:假设为$\mathbb{P}(D \mid A, B, \boldsymbol{c})$和$\mathbb{P}(D \mid C, E, \boldsymbol{c})$。接下来,我们可以在6个候选输出序列:$A$、$C$、$AB$、$CE$、$ABD$和$CED$中筛选出包含特殊符号“<eos>”的序列,并将它们中所有特殊符号“<eos>”后面的子序列舍弃,得到最终候选输出序列。我们可以在最终候选输出序列中取分数最高的序列作为输出序列。
常数$\alpha$一般可选为0.75
贪婪搜索可看作是束宽为1的束搜索。束搜索通过更灵活的束宽$k$来权衡计算开销和搜索质量
## 小结
......@@ -61,13 +61,10 @@ $$ \frac{1}{L^\alpha} \log \mathbb{P}(y_1, \ldots, y_{L}\mid \mathrm{context}) =
## 练习
* 穷举搜索可否看作是特殊束宽的束搜索?为什么?
*[“循环神经网络”](../chapter_recurrent-neural-networks/rnn.md)一节中,我们使用语言模型创作歌词。它的输出属于哪种搜索?你能改进它吗?
## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/6817)
![](../img/qr_beam-search.svg)
## 参考文献
[1] Sutskever, I., Vinyals, O., & Le, Q. V. (2014). Sequence to sequence learning with neural networks. In Advances in neural information processing systems (pp. 3104-3112).
......@@ -201,7 +201,7 @@ class DecoderInitState(nn.Block):
### 训练模型并输出不定长序列
我们定义`translate`函数应用训练好的模型,并通过贪婪搜索输出不定长的翻译文本序列。解码器的最初时间步输入来自“<bos>”符号。对于一个输出中的序列,当解码器在某一时间步搜索出“<eos>”符号时,即完成该输出序列。
Sutskever等人发现贪婪搜索也可以在机器翻译中也可以取得不错的结果 [1]。我们定义`translate`函数应用训练好的模型,并通过贪婪搜索输出不定长的翻译文本序列。解码器的最初时间步输入来自“<bos>”符号。对于一个输出中的序列,当解码器在某一时间步搜索出“<eos>”符号时,即完成该输出序列。
```{.python .input}
def translate(encoder, decoder, decoder_init_state, fr_ens, ctx, max_seq_len):
......@@ -325,7 +325,7 @@ train(encoder, decoder, decoder_init_state, max_seq_len, ctx, eval_fr_ens)
## 评价翻译结果
2002年,IBM团队提出了一种评价翻译结果的指标,叫BLEU(Bilingual Evaluation Understudy)[1]。
2002年,IBM团队提出了一种评价翻译结果的指标,叫BLEU(Bilingual Evaluation Understudy)[2]。
设$k$为我们希望评价的$n$个连续词的最大长度,例如$k=4$。设$n$个连续词的精度为$p_n$。它是模型预测序列与样本标签序列匹配$n$个连续词的数量与模型预测序列中$n$个连续词数量之比。举个例子,假设标签序列为$ABCDEF$,预测序列为$ABBCD$。那么$p_1 = 4/5, p_2 = 3/4, p_3 = 1/3, p_4 = 0$。设$len_{\text{label}}$和$len_{\text{pred}}$分别为标签序列和模型预测序列的词数。那么,BLEU的定义为
......@@ -341,7 +341,7 @@ $$ \exp(\min(0, 1 - \frac{len_{\text{label}}}{len_{\text{pred}}})) \prod_{i=1}^k
## 练习
* 试着使用更大的翻译数据集来训练模型,例如WMT [2] 和Tatoeba Project [3]。
* 试着使用更大的翻译数据集来训练模型,例如WMT [3] 和Tatoeba Project [4]。
* 在解码器中使用强制教学,观察实现现象。
## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/4689)
......@@ -350,8 +350,10 @@ $$ \exp(\min(0, 1 - \frac{len_{\text{label}}}{len_{\text{pred}}})) \prod_{i=1}^k
## 参考文献
[1] Papineni, K., Roukos, S., Ward, T., & Zhu, W. J. (2002, July). BLEU: a method for automatic evaluation of machine translation. In Proceedings of the 40th annual meeting on association for computational linguistics (pp. 311-318). Association for Computational Linguistics.
[1] Sutskever, I., Vinyals, O., & Le, Q. V. (2014). Sequence to sequence learning with neural networks. In Advances in neural information processing systems (pp. 3104-3112).
[2] WMT. http://www.statmt.org/wmt14/translation-task.html
[2] Papineni, K., Roukos, S., Ward, T., & Zhu, W. J. (2002, July). BLEU: a method for automatic evaluation of machine translation. In Proceedings of the 40th annual meeting on association for computational linguistics (pp. 311-318). Association for Computational Linguistics.
[3] Tatoeba Project. http://www.manythings.org/anki/
[3] WMT. http://www.statmt.org/wmt14/translation-task.html
[4] Tatoeba Project. http://www.manythings.org/anki/
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册