提交 d3213e4c 编写于 作者: S Superjom

fix pr

上级 a74e7981
......@@ -2,7 +2,7 @@
对变长序列的学习,现有主流框架比如 tensorflow, pytorch, caffe2, mxnet 等均使用了padding的方式,
即将一个mini-batch内不同长度的序列补0到固定长度参与计算。
现有Paddle`RecurrentLayerGroup` 实现了无padding的变长序列支持,本文也将基于该模块的思路,设计重构后的变长序列支持。
现有Paddle包括 `RecurrentLayerGroup` 在内的RNN均实现了无padding的变长序列支持,本文也将基于该模块的思路,设计重构后的变长序列支持。
## 非padding 变长序列的意义
由于tensor必须有明确的shape,因此基于tensor 的主流框架在存储变长序列时,
......@@ -11,7 +11,7 @@
由于padding是一种框架实现变长序列的妥协, 从用户角度,在使用RNN类模型时自然会比较介意padding的存在,
因此会有pytorch中对非padding方式变长序列支持长篇的讨论[3]。
由于padding对内存和计算会有额外的消耗,tensorflow和mxnet均使用了bucketing来行优化[1][2]
由于padding对内存和计算会有额外的消耗,tensorflow和mxnet均使用了bucketing来行优化[1][2]
但不管是padding还是bucket,对于用户都是额外的使用负担。
因此,**paddle原生支持变长序列的方式,能直接满足用户对变长序列的最直接的需求,在当前主流平台中可以算是一大优势**
......@@ -143,9 +143,9 @@ xx
xxx
-> sorted:
xx
xxx
xxxx
xxx
xx
```
经过 `SegmentInputs` 之后,每个会有4个时间步,每个时间步的输入如下(纵向排列)
......@@ -168,8 +168,11 @@ std::vector<SortedSeqItem> sorted_seqs;
```
来追踪序列排序后的位置。
对比现有设计,只需要修改 `SegmentInputs``ConcatOutputs` 两个接口,此外添加一个 `SortBySeqLen` 的接口,
对比现有设计,只需要修改 `InitMemories`, `SegmentInputs``ConcatOutputs` 两个接口,此外添加一个 `SortBySeqLen` 的接口,
就可以支持上述变长序列,下面详细介绍。
## InitMemories
由于序列顺序的变化,`boot_memories` 的batch上的element的顺序也需要对应重新排列。
## SegmentInputs
`SegmentInputs` 会依赖 `sorted_seqs` 的信息,将原始的序列按照排序后的序列顺序,从横向切割,转为每个step中的inputs。
......@@ -183,7 +186,7 @@ xxx
|
|
\ /
*
!
0 1 2 3
x x x x
x x x
......@@ -193,7 +196,7 @@ x x
`ConcatOutputs` 需要
- 将每个时间步的输出重新还原为原始输入的序列顺序(以防止Infer阶段顺序打乱)
-序列折叠,在batch维度上展开
-每个序列concat 为规则的mini-batch表示
## 附录
这里演示多level的变长序列的存储方法,本设计会用两层的`vector` 来存储所有序列的信息,具体数据格式如下
......@@ -243,7 +246,7 @@ std::vector<element_t> seq_start_positions_;
- 紧接着`seq_start_positions_[1]` 存储了第0个paragraph 的信息,表明有3个sentence,其在paragraph 0在tensor中对应部分的偏移分别为0,3 和7
- 紧接着`seq_start_positions_[2]` 存储了第1个paragraph 的信息,表明有2个sentence,其在paragraph 0在tensor中对应部分的偏移分别为0和 5
如上证明了`seq_start_positions_`的数据结构适用于 level 为 1(也就是Paddle中subseq),通过归纳法可以证明其适用于 N level 的序列,这里暂不赘述
如上证明了`seq_start_positions_`的数据结构适用于 level 为 1(也就是Paddle中subseq), **通过归纳法可以证明其适用于 N level 的序列,这里暂不赘述**
## 参考文献
1. [Tensorflow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册