rnn_design.md 5.5 KB
Newer Older
S
Superjom 已提交
1
# RNN 变长输入设计
S
Superjom 已提交
2 3
对变长序列的学习,现有主流框架比如 tensorflow, pytorch, caffe2, mxnet 等均使用了padding的方式,
即将一个mini-batch内不同长度的序列补0到固定长度参与计算。
S
Superjom 已提交
4

S
Superjom 已提交
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
现有Paddle的 `RecurrentLayerGroup` 实现了无padding的变长序列支持,本文也将基于该模块的思路,设计重构后的变长序列支持。

## 非padding 变长序列的意义
由于tensor必须有明确的shape,因此基于tensor 的主流框架在存储变长序列时,
必须用zero-padding的方式将变长序列补全为固定shape的tensor。

由于padding是一种框架实现变长序列的妥协, 从用户角度,在使用RNN类模型时自然会比较介意padding的存在,
因此会有pytorch中对非padding方式变长序列支持长篇的讨论[3]。

由于padding对内存和计算会有额外的消耗,tensorflow和mxnet均使用了bucketing来就行优化[1][2]
但不管是padding还是bucket,对于用户都是额外的使用负担。

因此,**paddle原生支持变长序列的方式,能直接满足用户对变长序列的最直接的需求,在当前主流平台中可以算是一大优势**

但对变长序列的支持,需要对目前框架做一些修改,下面讨论如何在最小修改下支持变长序列。
S
Superjom 已提交
20 21

## 变长数据格式
S
Superjom 已提交
22 23 24 25 26
目前 Paddle 会将一个mini-batch内的数据存储在一维的内存上,
额外使用 `Argument.sequenceStartPositions` 来存储每个句子的信息。

基于当前重构现状,我们使用如下设计来存储变长数据格式

S
update  
Superjom 已提交
27
- 扩充 Tensor 以支持存储变长序列的信息(这部分信息后续用SeqPosVar表示)
S
Superjom 已提交
28 29 30 31 32 33 34 35 36 37 38
- Op 的 `InferShape` 会更新outputs 的`SeqPosVar`
- 为了兼容序列Op(比如RNN)和传统Op(比如FC),序列的所有元素均flatten追加存储到一个mini-batch中
  - 比如,长度分别为2,3,4的三个句子会存储为一个size为9的`mini-batch`
  - 额外会有一个`SeqPosVar`,存储句子的结构,比如offest:`0,2,5,9`
  
为了支持sub-sequence,Paddle里使用 `Argument.subSequenceStartPositions` 来存储2维的序列信息,更高维度的序列无法支持;
这里为了扩展性,将SeqPosVar定义成如下数据结构来支持N维的序列信息的存储:

```c++
struct SeqPos {
  int dim{1};
S
update  
Superjom 已提交
39
  std::vector<std::shared_ptr<std::vector<int>> startPoses;
S
Superjom 已提交
40 41
};
```
S
Superjom 已提交
42

S
update  
Superjom 已提交
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
其中,startPoses可以用于存储多维的子序列,具体如下:

- 如果为1维序列,则 `dim=1``startPoses.size() = 1` 
- 如果为 2 维序列,则 `dim=2`, `startPoses[0]` 存储第一维序列信息,`startPoses[1:]` 存储第二维序列信息
- 如果为 n 维序列,则 `dim=n`, `startPoses[0]` 存储第一维序列,后续追加第 `2.. n` 维序列
  - 当有完整的 n 维序列的 `SeqPos` 信息时,可以从前往后,粒度从粗到细解析序列
  - 当拆解成 n-1 维序列时, `dim=n-1`,startPoses 去除第 1 维序列信息,为每个次级序列单独抽取出对应的信息组成新的 `SeqPos`

Tensor 扩展为
```c++
struct TensorWithSequence {
  Tensor* tensor;
  std::shared_ptr<SeqPos> seq_pos;
}
```

S
Superjom 已提交
59
## 框架支持方法
S
Superjom 已提交
60 61 62
类似Paddle现在的做法,为了支持每个参与inputs/outputs的variable必须有对应的SeqPosVar,
**这里需要框架就行一些修改,有一些trick的成分**

S
update  
Superjom 已提交
63
现有框架可以在 `Context` 里添加一个与 `Input` 平行的接口 `InputSeq` 来获取序列信息,具体定义如下
S
Superjom 已提交
64

S
update  
Superjom 已提交
65 66 67
```
std::shared_ptr<SeqPos> InputSeq(const std::string& name);
```
S
Superjom 已提交
68

S
update  
Superjom 已提交
69 70 71
为了能够将SeqPos在Op的调用关系中传递下去,考虑到一些不支持序列的Op(比如FC)可能丢失SeqPos,
框架需要强制所有的OP的InferShape都必须感知并传递SeqPos,
目前最简单的方式是直接在 OperatorBase的InferShape里设置
S
Superjom 已提交
72

S
update  
Superjom 已提交
73 74 75 76 77
```c++
void InferShape(const std::shared_ptr<Scope<>& scope) {
  CopyInSeqToOut();
  // ...
}
S
Superjom 已提交
78

S
update  
Superjom 已提交
79 80 81
// if inputs has SeqPos, copy to output.
void CopyInSeqToOut();
```
S
Superjom 已提交
82

S
update2  
Superjom 已提交
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
## 根据长度排序
按照长度排序后,从前往后的时间步的batch size会自然地递减,这是 Net 支持的

比如:

```
origin:
xxxx
xx
xxx

-> sorted:
xx
xxx
xxxx
```

经过 `SegmentInputs` 之后,每个会有4个时间步,每个时间步的输入如下(纵向排列)

```
0    1    2    3
x    x    x    x
x    x    x
x    x
```

为了追踪排序前后序列的变化,这里用
```c++
struct SortedSeqItem {
   void *start{nullptr};
   void *end{nullptr};
};

std::vector<SortedSeqItem> sorted_seqs;
```
来追踪序列排序后的位置。

对比现有设计,只需要修改 `SegmentInputs``ConcatOutputs` 两个接口,此外添加一个 `SortBySeqLen` 的接口,
就可以支持上述变长序列,下面详细介绍。
## SegmentInputs
`SegmentInputs` 会依赖 `sorted_seqs` 的信息,将原始的序列按照排序后的序列顺序,从横向切割,转为每个step中的inputs。

即下面的转变:
```
origin:
xxxx
xx
xxx

   |
   |
  \ /
   *
0    1    2    3
x    x    x    x
x    x    x
x    x
```
## ConcatOutputs
`ConcatOutputs` 需要

- 将每个时间步的输出重新还原为原始输入的序列顺序(以防止Infer阶段顺序打乱)
- 将序列折叠,在batch维度上展开

S
Superjom 已提交
147 148 149 150
## 参考文献
1. [Tensorflow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing)
2. [mxnet Bucketing](http://mxnet.io/how_to/bucketing.html)
3. [variable length input in RNN scenario](https://discuss.pytorch.org/t/about-the-variable-length-input-in-rnn-scenario/345/5)