提交 1328060a 编写于 作者: S Superjom

update2

上级 a0a2f1bf
......@@ -80,6 +80,70 @@ void InferShape(const std::shared_ptr<Scope<>& scope) {
void CopyInSeqToOut();
```
## 根据长度排序
按照长度排序后,从前往后的时间步的batch size会自然地递减,这是 Net 支持的
比如:
```
origin:
xxxx
xx
xxx
-> sorted:
xx
xxx
xxxx
```
经过 `SegmentInputs` 之后,每个会有4个时间步,每个时间步的输入如下(纵向排列)
```
0 1 2 3
x x x x
x x x
x x
```
为了追踪排序前后序列的变化,这里用
```c++
struct SortedSeqItem {
void *start{nullptr};
void *end{nullptr};
};
std::vector<SortedSeqItem> sorted_seqs;
```
来追踪序列排序后的位置。
对比现有设计,只需要修改 `SegmentInputs``ConcatOutputs` 两个接口,此外添加一个 `SortBySeqLen` 的接口,
就可以支持上述变长序列,下面详细介绍。
## SegmentInputs
`SegmentInputs` 会依赖 `sorted_seqs` 的信息,将原始的序列按照排序后的序列顺序,从横向切割,转为每个step中的inputs。
即下面的转变:
```
origin:
xxxx
xx
xxx
|
|
\ /
*
0 1 2 3
x x x x
x x x
x x
```
## ConcatOutputs
`ConcatOutputs` 需要
- 将每个时间步的输出重新还原为原始输入的序列顺序(以防止Infer阶段顺序打乱)
- 将序列折叠,在batch维度上展开
## 参考文献
1. [Tensorflow Bucketing](https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing)
2. [mxnet Bucketing](http://mxnet.io/how_to/bucketing.html)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册