Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7163dd04
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7163dd04
编写于
10月 01, 2017
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revert code
上级
8db3afad
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
60 addition
and
3 deletion
+60
-3
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+41
-0
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+19
-0
python/paddle/v2/framework/tests/test_recurrent_op.py
python/paddle/v2/framework/tests/test_recurrent_op.py
+0
-3
未找到文件。
paddle/operators/recurrent_op.cc
浏览文件 @
7163dd04
...
@@ -28,6 +28,29 @@ using Variable = framework::Variable;
...
@@ -28,6 +28,29 @@ using Variable = framework::Variable;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
void
RecurrentAlgorithm
::
InferShape
(
const
Scope
&
scope
)
const
{
auto
*
input0
=
scope
.
FindVar
(
arg_
->
inlinks
[
0
]);
PADDLE_ENFORCE_NOT_NULL
(
input0
);
seq_len_
=
input0
->
GetMutable
<
LoDTensor
>
()
->
dims
()[
0
];
PADDLE_ENFORCE_GT
(
seq_len_
,
0
);
CreateScopes
(
scope
);
auto
&
step_scopes
=
GetStepScopes
(
scope
);
rnn
::
SegmentInputs
(
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
InitMemories
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
for
(
size_t
i
=
0
;
i
<
seq_len_
;
i
++
)
{
if
(
i
>
0
)
{
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
i
,
-
1
,
true
/*infer_shape_mode*/
);
}
(
*
stepnet_
)
->
InferShape
(
*
step_scopes
[
i
]);
}
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
}
void
RecurrentAlgorithm
::
Run
(
const
Scope
&
scope
,
void
RecurrentAlgorithm
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
auto
step_scopes
=
GetStepScopes
(
scope
);
auto
step_scopes
=
GetStepScopes
(
scope
);
...
@@ -179,6 +202,24 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
...
@@ -179,6 +202,24 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
}
}
}
}
void
RecurrentGradientAlgorithm
::
InferShape
(
const
Scope
&
scope
)
const
{
seq_len_
=
scope
.
FindVar
(
arg_
->
inlinks
[
0
])
->
GetMutable
<
LoDTensor
>
()
->
dims
()[
0
];
auto
step_scopes
=
GetStepScopes
(
scope
);
rnn
::
SegmentInputs
(
step_scopes
,
arg_
->
inlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
for
(
int
step_id
=
seq_len_
-
1
;
step_id
>=
0
;
--
step_id
)
{
if
(
static_cast
<
size_t
>
(
step_id
)
!=
seq_len_
-
1
)
{
rnn
::
LinkMemories
(
step_scopes
,
arg_
->
memories
,
step_id
,
1
,
true
/*infer_shape_mode*/
);
}
(
*
stepnet_
)
->
InferShape
(
*
step_scopes
[
step_id
]);
}
rnn
::
ConcatOutputs
(
step_scopes
,
arg_
->
outlinks
,
seq_len_
,
true
/*infer_shape_mode*/
);
LinkBootMemoryGradients
(
step_scopes
[
0
],
true
/*infer_shape_mode*/
);
}
RecurrentGradientOp
::
RecurrentGradientOp
(
RecurrentGradientOp
::
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
...
...
paddle/operators/recurrent_op.h
浏览文件 @
7163dd04
...
@@ -41,6 +41,11 @@ class RecurrentAlgorithm {
...
@@ -41,6 +41,11 @@ class RecurrentAlgorithm {
stepnet_
=
stepnet
;
stepnet_
=
stepnet
;
}
}
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
protected:
protected:
/*
/*
* The step scopes will be stored in the father scope as a variable.
* The step scopes will be stored in the father scope as a variable.
...
@@ -89,6 +94,11 @@ class RecurrentGradientAlgorithm {
...
@@ -89,6 +94,11 @@ class RecurrentGradientAlgorithm {
void
LinkBootMemoryGradients
(
framework
::
Scope
*
step_scopes
,
void
LinkBootMemoryGradients
(
framework
::
Scope
*
step_scopes
,
bool
infer_shape_mode
)
const
;
bool
infer_shape_mode
)
const
;
/**
* InferShape must be called before Run.
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
;
protected:
protected:
inline
const
std
::
vector
<
framework
::
Scope
*>&
GetStepScopes
(
inline
const
std
::
vector
<
framework
::
Scope
*>&
GetStepScopes
(
const
framework
::
Scope
&
scope
)
const
{
const
framework
::
Scope
&
scope
)
const
{
...
@@ -123,8 +133,13 @@ class RecurrentOp : public framework::OperatorBase {
...
@@ -123,8 +133,13 @@ class RecurrentOp : public framework::OperatorBase {
void
set_stepnet
(
std
::
unique_ptr
<
OperatorBase
>
net
)
{
void
set_stepnet
(
std
::
unique_ptr
<
OperatorBase
>
net
)
{
stepnet_
=
std
::
move
(
net
);
stepnet_
=
std
::
move
(
net
);
}
}
const
OperatorBase
&
stepnet
()
const
{
return
*
stepnet_
;
}
const
OperatorBase
&
stepnet
()
const
{
return
*
stepnet_
;
}
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
{
alg_
.
InferShape
(
scope
);
}
static
const
rnn
::
ArgumentName
kArgName
;
static
const
rnn
::
ArgumentName
kArgName
;
private:
private:
...
@@ -147,6 +162,10 @@ class RecurrentGradientOp : public framework::OperatorBase {
...
@@ -147,6 +162,10 @@ class RecurrentGradientOp : public framework::OperatorBase {
PADDLE_THROW
(
"Not Implemented"
);
PADDLE_THROW
(
"Not Implemented"
);
}
}
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
{
alg_
.
InferShape
(
scope
);
}
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
alg_
.
Run
(
scope
,
dev_ctx
);
alg_
.
Run
(
scope
,
dev_ctx
);
...
...
python/paddle/v2/framework/tests/test_recurrent_op.py
浏览文件 @
7163dd04
...
@@ -197,7 +197,4 @@ class RecurrentGradientOpTest(unittest.TestCase):
...
@@ -197,7 +197,4 @@ class RecurrentGradientOpTest(unittest.TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
exit
(
0
)
# FIXME(yuyang18): InferShape has been removed, this unittest may error
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录