Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
a0a2f1bf
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0a2f1bf
编写于
7月 25, 2017
作者:
S
Superjom
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
45072ed2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
33 addition
and
9 deletion
+33
-9
paddle/operators/rnn_design.md
paddle/operators/rnn_design.md
+33
-9
未找到文件。
paddle/operators/rnn_design.md
浏览文件 @
a0a2f1bf
...
...
@@ -24,7 +24,7 @@
基于当前重构现状,我们使用如下设计来存储变长数据格式
-
每个参与到 Op 的
`inputs/outputs`
的variable 均有一个对应的variable用来存储序列信息(下面我们称此类variable 为
`SeqPosVar`
)
-
扩充 Tensor 以支持存储变长序列的信息(这部分信息后续用SeqPosVar表示
)
-
Op 的
`InferShape`
会更新outputs 的
`SeqPosVar`
-
为了兼容序列Op(比如RNN)和传统Op(比如FC),序列的所有元素均flatten追加存储到一个mini-batch中
-
比如,长度分别为2,3,4的三个句子会存储为一个size为9的
`mini-batch`
...
...
@@ -36,25 +36,49 @@
```
c++
struct
SeqPos
{
int
dim
{
1
};
std
::
vector
<
SeqPos
>
seq_offset
s
;
std
::
vector
<
std
::
shared_ptr
<
std
::
vector
<
int
>>
startPose
s
;
};
```
其中,startPoses可以用于存储多维的子序列,具体如下:
-
如果为1维序列,则
`dim=1`
,
`startPoses.size() = 1`
-
如果为 2 维序列,则
`dim=2`
,
`startPoses[0]`
存储第一维序列信息,
`startPoses[1:]`
存储第二维序列信息
-
如果为 n 维序列,则
`dim=n`
,
`startPoses[0]`
存储第一维序列,后续追加第
`2.. n`
维序列
-
当有完整的 n 维序列的
`SeqPos`
信息时,可以从前往后,粒度从粗到细解析序列
-
当拆解成 n-1 维序列时,
`dim=n-1`
,startPoses 去除第 1 维序列信息,为每个次级序列单独抽取出对应的信息组成新的
`SeqPos`
Tensor 扩展为
```
c++
struct
TensorWithSequence
{
Tensor
*
tensor
;
std
::
shared_ptr
<
SeqPos
>
seq_pos
;
}
```
## 框架支持方法
类似Paddle现在的做法,为了支持每个参与inputs/outputs的variable必须有对应的SeqPosVar,
**这里需要框架就行一些修改,有一些trick的成分**
。
框架需要保证每个参与计算的 variable 均有一个对应的
`SeqPosVar`
,初步设想在 AddOp 时增量创建
`SeqPosVar`
,
在scope里对应的key可以为对应variable的加一个固定的后缀,比如
`@seq-pos`
现有框架可以在
`Context`
里添加一个与
`Input`
平行的接口
`InputSeq`
来获取序列信息,具体定义如下
```
std::shared_ptr<SeqPos> InputSeq(const std::string& name);
```
### 在OP间传递SeqPos
每个Op的
`InferShape`
需要额外更新outputs的SeqPosVar,即使不修改序列信息,也要显式从inputs的SeqPosVar复制给outputs的。
为了能够将SeqPos在Op的调用关系中传递下去,考虑到一些不支持序列的Op(比如FC)可能丢失SeqPos,
框架需要强制所有的OP的InferShape都必须感知并传递SeqPos,
目前最简单的方式是直接在 OperatorBase的InferShape里设置
如果当前Op (比如RNN)需要用到序列信息,则对input添加后缀
`@seq-pos`
获取其对应的 SeqPosVar,操作之。
```
c++
void
InferShape
(
const
std
::
shared_ptr
<
Scope
<>&
scope
)
{
CopyInSeqToOut
();
// ...
}
### 内存复用
由于当计算图固定时,Op是否修改序列信息是确定的,因此SeqPosVar可以用
`shared_ptr`
支持无内存的复制操作来节约这部分内存消耗。
// if inputs has SeqPos, copy to output.
void
CopyInSeqToOut
();
```
## 参考文献
1.
[
Tensorflow Bucketing
](
https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录